|
|
|
@ -40,6 +40,7 @@ class DictProtocol(asyncio.Protocol):
|
|
|
|
|
def connection_made(self, transport):
|
|
|
|
|
logging.info('Connect {}'.format(transport.get_extra_info('peername')))
|
|
|
|
|
self.transport = transport
|
|
|
|
|
self.transport_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
def data_received(self, data):
|
|
|
|
|
logging.debug("Received {}".format(data))
|
|
|
|
@ -94,9 +95,9 @@ class DictProtocol(asyncio.Protocol):
|
|
|
|
|
else:
|
|
|
|
|
response = json.dumps(result).encode("ascii")
|
|
|
|
|
logging.debug("Replying {}".format(key))
|
|
|
|
|
return self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response)
|
|
|
|
|
return await (self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response))
|
|
|
|
|
except KeyError:
|
|
|
|
|
return self.reply(b"N")
|
|
|
|
|
return await self.reply(b"N")
|
|
|
|
|
|
|
|
|
|
async def process_iterate(self, flags, max_rows, path, user=None):
|
|
|
|
|
""" Process an iterate command
|
|
|
|
@ -107,18 +108,24 @@ class DictProtocol(asyncio.Protocol):
|
|
|
|
|
max_rows = int(max_rows.decode("utf-8"))
|
|
|
|
|
flags = int(flags.decode("utf-8"))
|
|
|
|
|
if flags != 0: # not implemented
|
|
|
|
|
return self.reply(b"F")
|
|
|
|
|
return await self.reply(b"F")
|
|
|
|
|
rows = []
|
|
|
|
|
try:
|
|
|
|
|
result = await self.dict.iter(key)
|
|
|
|
|
logging.debug("Found {} entries: {}".format(len(result), result))
|
|
|
|
|
returned_results = 0
|
|
|
|
|
for k in result:
|
|
|
|
|
if max_rows == 0 or returned_results < max_rows:
|
|
|
|
|
await self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True)
|
|
|
|
|
rows.append(self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True))
|
|
|
|
|
returned_results += 1
|
|
|
|
|
return self.reply(b"\n") # ITER_FINISHED
|
|
|
|
|
await asyncio.gather(*rows)
|
|
|
|
|
return await self.reply(b"\n") # ITER_FINISHED
|
|
|
|
|
except KeyError:
|
|
|
|
|
return self.reply(b"F")
|
|
|
|
|
return await self.reply(b"F")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.error(f"Got {e}, cancelling remaining tasks")
|
|
|
|
|
for task in rows:
|
|
|
|
|
task.cancel()
|
|
|
|
|
|
|
|
|
|
def process_begin(self, transaction_id, user=None):
|
|
|
|
|
""" Process a dict begin message
|
|
|
|
@ -147,14 +154,15 @@ class DictProtocol(asyncio.Protocol):
|
|
|
|
|
# Remove stored transaction
|
|
|
|
|
del self.transactions[transaction_id]
|
|
|
|
|
del self.transactions_user[transaction_id]
|
|
|
|
|
return self.reply(b"O", transaction_id)
|
|
|
|
|
return await self.reply(b"O", transaction_id)
|
|
|
|
|
|
|
|
|
|
def reply(self, command, *args, end=True):
|
|
|
|
|
async def reply(self, command, *args, end=True):
|
|
|
|
|
logging.debug("Replying {} with {}".format(command, args))
|
|
|
|
|
self.transport.write(command)
|
|
|
|
|
self.transport.write(b"\t".join(map(tabescape, args)))
|
|
|
|
|
if end:
|
|
|
|
|
self.transport.write(b"\n")
|
|
|
|
|
async with self.transport_lock:
|
|
|
|
|
self.transport.write(command)
|
|
|
|
|
self.transport.write(b"\t".join(map(tabescape, args)))
|
|
|
|
|
if end:
|
|
|
|
|
self.transport.write(b"\n")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def factory(cls, table_map):
|
|
|
|
|