@ -33,6 +33,8 @@ class DictProtocol(asyncio.Protocol):
self . dict = None
self . dict = None
# Dictionary of active transaction lists per transaction id
# Dictionary of active transaction lists per transaction id
self . transactions = { }
self . transactions = { }
# Dictionary of user per transaction id
self . transactions_user = { }
super ( DictProtocol , self ) . __init__ ( )
super ( DictProtocol , self ) . __init__ ( )
def connection_made ( self , transport ) :
def connection_made ( self , transport ) :
@ -75,15 +77,15 @@ class DictProtocol(asyncio.Protocol):
logging . debug ( " Client {} . {} type {} , user {} , dict {} " . format (
logging . debug ( " Client {} . {} type {} , user {} , dict {} " . format (
self . major , self . minor , self . value_type , self . user , dict_name ) )
self . major , self . minor , self . value_type , self . user , dict_name ) )
async def process_lookup ( self , key ):
async def process_lookup ( self , key , user = None ):
""" Process a dict lookup message
""" Process a dict lookup message
"""
"""
logging . debug ( " Looking up {} ". format ( key ) )
logging . debug ( " Looking up {} for {} ". format ( key , user ) )
# Priv and shared keys are handled slighlty differently
# Priv and shared keys are handled slighlty differently
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
try :
try :
result = await self . dict . get (
result = await self . dict . get (
key , ns = ( self . user if key_type == " priv " else None )
key , ns = ( ( user . decode ( " utf8 " ) if user else self . user ) if key_type == " priv " else None )
)
)
if type ( result ) is str :
if type ( result ) is str :
response = result . encode ( " utf8 " )
response = result . encode ( " utf8 " )
@ -95,10 +97,11 @@ class DictProtocol(asyncio.Protocol):
except KeyError :
except KeyError :
return self . reply ( b " N " )
return self . reply ( b " N " )
def process_begin ( self , transaction_id ):
def process_begin ( self , transaction_id , user = None ):
""" Process a dict begin message
""" Process a dict begin message
"""
"""
self . transactions [ transaction_id ] = { }
self . transactions [ transaction_id ] = { }
self . transactions_user [ transaction_id ] = user . decode ( " utf8 " ) if user else self . user
def process_set ( self , transaction_id , key , value ) :
def process_set ( self , transaction_id , key , value ) :
""" Process a dict set message
""" Process a dict set message
@ -116,10 +119,11 @@ class DictProtocol(asyncio.Protocol):
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
key_type , key = key . decode ( " utf8 " ) . split ( " / " , 1 )
result = await self . dict . set (
result = await self . dict . set (
key , json . loads ( value ) ,
key , json . loads ( value ) ,
ns = ( self . user if key_type == " priv " else None )
ns = ( self . transactions_ user[ transaction_id ] if key_type == " priv " else None )
)
)
# Remove stored transaction
# Remove stored transaction
del self . transactions [ transaction_id ]
del self . transactions [ transaction_id ]
del self . transactions_user [ transaction_id ]
return self . reply ( b " O " , transaction_id )
return self . reply ( b " O " , transaction_id )
def reply ( self , command , * args ) :
def reply ( self , command , * args ) :