Make it async. I'm not sure it's a good idea

main
Florent Daigniere 2 years ago
parent cf34be967c
commit 5ec4277e1e

@ -40,6 +40,7 @@ class DictProtocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logging.info('Connect {}'.format(transport.get_extra_info('peername'))) logging.info('Connect {}'.format(transport.get_extra_info('peername')))
self.transport = transport self.transport = transport
self.transport_lock = asyncio.Lock()
def data_received(self, data): def data_received(self, data):
logging.debug("Received {}".format(data)) logging.debug("Received {}".format(data))
@ -94,9 +95,9 @@ class DictProtocol(asyncio.Protocol):
else: else:
response = json.dumps(result).encode("ascii") response = json.dumps(result).encode("ascii")
logging.debug("Replying {}".format(key)) logging.debug("Replying {}".format(key))
return self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response) return await (self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response))
except KeyError: except KeyError:
return self.reply(b"N") return await self.reply(b"N")
async def process_iterate(self, flags, max_rows, path, user=None): async def process_iterate(self, flags, max_rows, path, user=None):
""" Process an iterate command """ Process an iterate command
@ -107,18 +108,24 @@ class DictProtocol(asyncio.Protocol):
max_rows = int(max_rows.decode("utf-8")) max_rows = int(max_rows.decode("utf-8"))
flags = int(flags.decode("utf-8")) flags = int(flags.decode("utf-8"))
if flags != 0: # not implemented if flags != 0: # not implemented
return self.reply(b"F") return await self.reply(b"F")
rows = []
try: try:
result = await self.dict.iter(key) result = await self.dict.iter(key)
logging.debug("Found {} entries: {}".format(len(result), result)) logging.debug("Found {} entries: {}".format(len(result), result))
returned_results = 0 returned_results = 0
for k in result: for k in result:
if max_rows == 0 or returned_results < max_rows: if max_rows == 0 or returned_results < max_rows:
await self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True) rows.append(self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True))
returned_results += 1 returned_results += 1
return self.reply(b"\n") # ITER_FINISHED await asyncio.gather(*rows)
return await self.reply(b"\n") # ITER_FINISHED
except KeyError: except KeyError:
return self.reply(b"F") return await self.reply(b"F")
except Exception as e:
logging.error(f"Got {e}, cancelling remaining tasks")
for task in rows:
task.cancel()
def process_begin(self, transaction_id, user=None): def process_begin(self, transaction_id, user=None):
""" Process a dict begin message """ Process a dict begin message
@ -147,14 +154,15 @@ class DictProtocol(asyncio.Protocol):
# Remove stored transaction # Remove stored transaction
del self.transactions[transaction_id] del self.transactions[transaction_id]
del self.transactions_user[transaction_id] del self.transactions_user[transaction_id]
return self.reply(b"O", transaction_id) return await self.reply(b"O", transaction_id)
def reply(self, command, *args, end=True): async def reply(self, command, *args, end=True):
logging.debug("Replying {} with {}".format(command, args)) logging.debug("Replying {} with {}".format(command, args))
self.transport.write(command) async with self.transport_lock:
self.transport.write(b"\t".join(map(tabescape, args))) self.transport.write(command)
if end: self.transport.write(b"\t".join(map(tabescape, args)))
self.transport.write(b"\n") if end:
self.transport.write(b"\n")
@classmethod @classmethod
def factory(cls, table_map): def factory(cls, table_map):

Loading…
Cancel
Save