Make it async. I'm not sure it's a good idea

main
Florent Daigniere 2 years ago
parent cf34be967c
commit 5ec4277e1e

@ -40,6 +40,7 @@ class DictProtocol(asyncio.Protocol):
def connection_made(self, transport):
logging.info('Connect {}'.format(transport.get_extra_info('peername')))
self.transport = transport
self.transport_lock = asyncio.Lock()
def data_received(self, data):
logging.debug("Received {}".format(data))
@ -94,9 +95,9 @@ class DictProtocol(asyncio.Protocol):
else:
response = json.dumps(result).encode("ascii")
logging.debug("Replying {}".format(key))
return self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response)
return await (self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response))
except KeyError:
return self.reply(b"N")
return await self.reply(b"N")
async def process_iterate(self, flags, max_rows, path, user=None):
""" Process an iterate command
@ -107,18 +108,24 @@ class DictProtocol(asyncio.Protocol):
max_rows = int(max_rows.decode("utf-8"))
flags = int(flags.decode("utf-8"))
if flags != 0: # not implemented
return self.reply(b"F")
return await self.reply(b"F")
rows = []
try:
result = await self.dict.iter(key)
logging.debug("Found {} entries: {}".format(len(result), result))
returned_results = 0
for k in result:
if max_rows == 0 or returned_results < max_rows:
await self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True)
rows.append(self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True))
returned_results += 1
return self.reply(b"\n") # ITER_FINISHED
await asyncio.gather(*rows)
return await self.reply(b"\n") # ITER_FINISHED
except KeyError:
return self.reply(b"F")
return await self.reply(b"F")
except Exception as e:
logging.error(f"Got {e}, cancelling remaining tasks")
for task in rows:
task.cancel()
def process_begin(self, transaction_id, user=None):
""" Process a dict begin message
@ -147,10 +154,11 @@ class DictProtocol(asyncio.Protocol):
# Remove stored transaction
del self.transactions[transaction_id]
del self.transactions_user[transaction_id]
return self.reply(b"O", transaction_id)
return await self.reply(b"O", transaction_id)
def reply(self, command, *args, end=True):
async def reply(self, command, *args, end=True):
logging.debug("Replying {} with {}".format(command, args))
async with self.transport_lock:
self.transport.write(command)
self.transport.write(b"\t".join(map(tabescape, args)))
if end:

Loading…
Cancel
Save