diff --git a/core/base/libs/podop/podop/dovecot.py b/core/base/libs/podop/podop/dovecot.py index 18956dd4..6dc25af5 100644 --- a/core/base/libs/podop/podop/dovecot.py +++ b/core/base/libs/podop/podop/dovecot.py @@ -40,6 +40,7 @@ class DictProtocol(asyncio.Protocol): def connection_made(self, transport): logging.info('Connect {}'.format(transport.get_extra_info('peername'))) self.transport = transport + self.transport_lock = asyncio.Lock() def data_received(self, data): logging.debug("Received {}".format(data)) @@ -94,9 +95,9 @@ class DictProtocol(asyncio.Protocol): else: response = json.dumps(result).encode("ascii") logging.debug("Replying {}".format(key)) - return self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response) + return await (self.reply(b"O", (key_type+'/'+key).encode("utf8"), response, end=True) if is_iter else self.reply(b"O", response)) except KeyError: - return self.reply(b"N") + return await self.reply(b"N") async def process_iterate(self, flags, max_rows, path, user=None): """ Process an iterate command @@ -107,18 +108,24 @@ class DictProtocol(asyncio.Protocol): max_rows = int(max_rows.decode("utf-8")) flags = int(flags.decode("utf-8")) if flags != 0: # not implemented - return self.reply(b"F") + return await self.reply(b"F") + rows = [] try: result = await self.dict.iter(key) logging.debug("Found {} entries: {}".format(len(result), result)) returned_results = 0 for k in result: if max_rows == 0 or returned_results < max_rows: - await self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True) + rows.append(self.process_lookup((path.decode("utf8")+k).encode("utf8"), user, is_iter=True)) returned_results += 1 - return self.reply(b"\n") # ITER_FINISHED + await asyncio.gather(*rows) + return await self.reply(b"\n") # ITER_FINISHED except KeyError: - return self.reply(b"F") + return await self.reply(b"F") + except Exception as e: + logging.error(f"Got {e}, cancelling remaining tasks") + for task in rows: + task.cancel() def process_begin(self, transaction_id, user=None): """ Process a dict begin message @@ -147,14 +154,15 @@ class DictProtocol(asyncio.Protocol): # Remove stored transaction del self.transactions[transaction_id] del self.transactions_user[transaction_id] - return self.reply(b"O", transaction_id) + return await self.reply(b"O", transaction_id) - def reply(self, command, *args, end=True): + async def reply(self, command, *args, end=True): logging.debug("Replying {} with {}".format(command, args)) - self.transport.write(command) - self.transport.write(b"\t".join(map(tabescape, args))) - if end: - self.transport.write(b"\n") + async with self.transport_lock: + self.transport.write(command) + self.transport.write(b"\t".join(map(tabescape, args))) + if end: + self.transport.write(b"\n") @classmethod def factory(cls, table_map):