@ -40,6 +40,7 @@ class DictProtocol(asyncio.Protocol):
def connection_made ( self , transport ) :
def connection_made ( self , transport ) :
logging . info ( ' Connect {} ' . format ( transport . get_extra_info ( ' peername ' ) ) )
logging . info ( ' Connect {} ' . format ( transport . get_extra_info ( ' peername ' ) ) )
self . transport = transport
self . transport = transport
self . transport_lock = asyncio . Lock ( )
def data_received ( self , data ) :
def data_received ( self , data ) :
logging . debug ( " Received {} " . format ( data ) )
logging . debug ( " Received {} " . format ( data ) )
@ -77,10 +78,11 @@ class DictProtocol(asyncio.Protocol):
logging . debug ( " Client {} . {} type {} , user {} , dict {} " . format (
logging . debug ( " Client {} . {} type {} , user {} , dict {} " . format (
self . major , self . minor , self . value_type , self . user , dict_name ) )
self . major , self . minor , self . value_type , self . user , dict_name ) )
async def process_lookup ( self , key , user = None ):
async def process_lookup ( self , key , user = None , is_iter = False ):
""" Process a dict lookup message
""" Process a dict lookup message
"""
"""
logging . debug ( " Looking up {} for {} " . format ( key , user ) )
logging . debug ( " Looking up {} for {} " . format ( key , user ) )
orig_key = key
# Priv and shared keys are handled slighlty differently
# Priv and shared keys are handled slighlty differently
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
try :
try :
@ -93,9 +95,38 @@ class DictProtocol(asyncio.Protocol):
response = result
response = result
else :
else :
response = json . dumps ( result ) . encode ( " ascii " )
response = json . dumps ( result ) . encode ( " ascii " )
return self . reply ( b " O " , response )
return await ( self . reply ( b " O " , orig_key , response ) if is_iter else self . reply ( b " O " , response ) )
except KeyError :
except KeyError :
return self . reply ( b " N " )
return await self . reply ( b " N " )
async def process_iterate ( self , flags , max_rows , path , user = None ) :
""" Process an iterate command
"""
logging . debug ( " Iterate flags {} max_rows {} on {} for {} " . format ( flags , max_rows , path , user ) )
# Priv and shared keys are handled slighlty differently
key_type , key = path . decode ( " utf8 " ) . split ( " / " , 1 )
max_rows = int ( max_rows . decode ( " utf-8 " ) )
flags = int ( flags . decode ( " utf-8 " ) )
if flags != 0 : # not implemented
return await self . reply ( b " F " )
rows = [ ]
try :
result = await self . dict . iter ( key )
logging . debug ( " Found {} entries: {} " . format ( len ( result ) , result ) )
for i , k in enumerate ( result ) :
if max_rows > 0 and i > = max_rows :
break
rows . append ( self . process_lookup ( ( path . decode ( " utf8 " ) + k ) . encode ( " utf8 " ) , user , is_iter = True ) )
await asyncio . gather ( * rows )
async with self . transport_lock :
self . transport . write ( b " \n " ) # ITER_FINISHED
return
except KeyError :
return await self . reply ( b " F " )
except Exception as e :
for task in rows :
task . cancel ( )
raise e
def process_begin ( self , transaction_id , user = None ) :
def process_begin ( self , transaction_id , user = None ) :
""" Process a dict begin message
""" Process a dict begin message
@ -124,13 +155,14 @@ class DictProtocol(asyncio.Protocol):
# Remove stored transaction
# Remove stored transaction
del self . transactions [ transaction_id ]
del self . transactions [ transaction_id ]
del self . transactions_user [ transaction_id ]
del self . transactions_user [ transaction_id ]
return self . reply ( b " O " , transaction_id )
return await self . reply ( b " O " , transaction_id )
def reply ( self , command , * args ) :
async def reply ( self , command , * args ) :
logging . debug ( " Replying {} with {} " . format ( command , args ) )
async with self . transport_lock :
self . transport . write ( command )
logging . debug ( " Replying {} with {} " . format ( command , args ) )
self . transport . write ( b " \t " . join ( map ( tabescape , args ) ) )
self . transport . write ( command )
self . transport . write ( b " \n " )
self . transport . write ( b " \t " . join ( map ( tabescape , args ) ) )
self . transport . write ( b " \n " )
@classmethod
@classmethod
def factory ( cls , table_map ) :
def factory ( cls , table_map ) :
@ -141,6 +173,7 @@ class DictProtocol(asyncio.Protocol):
COMMANDS = {
COMMANDS = {
ord ( " H " ) : process_hello ,
ord ( " H " ) : process_hello ,
ord ( " L " ) : process_lookup ,
ord ( " L " ) : process_lookup ,
ord ( " I " ) : process_iterate ,
ord ( " B " ) : process_begin ,
ord ( " B " ) : process_begin ,
ord ( " C " ) : process_commit ,
ord ( " C " ) : process_commit ,
ord ( " S " ) : process_set
ord ( " S " ) : process_set