diff --git a/core/admin/mailu/manage.py b/core/admin/mailu/manage.py index 80a73230..e02d9ad4 100644 --- a/core/admin/mailu/manage.py +++ b/core/admin/mailu/manage.py @@ -4,9 +4,15 @@ import sys import os import socket +import json +import logging import uuid +from collections import Counter +from itertools import chain + import click +import sqlalchemy import yaml from flask import current_app as app @@ -14,7 +20,7 @@ from flask.cli import FlaskGroup, with_appcontext from marshmallow.exceptions import ValidationError from . import models -from .schemas import MailuSchema +from .schemas import MailuSchema, get_schema db = models.db @@ -322,60 +328,211 @@ SECTIONS = {'domains', 'relays', 'users', 'aliases'} @mailu.command() -@click.option('-v', '--verbose', is_flag=True, help='Increase verbosity') +@click.option('-v', '--verbose', count=True, help='Increase verbosity') +@click.option('-q', '--quiet', is_flag=True, help='Quiet mode - only show errors') +@click.option('-u', '--update', is_flag=True, help='Update mode - merge input with existing config') @click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made') @click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin) @with_appcontext -def config_import(verbose=False, dry_run=False, source=None): - """ Import configuration from YAML +def config_import(verbose=0, quiet=False, update=False, dry_run=False, source=None): + """ Import configuration as YAML or JSON from stdin or file """ - def log(**data): - caller = sys._getframe(1).f_code.co_name # pylint: disable=protected-access - if caller == '_track_import': - print(f'Handling {data["self"].opts.model.__table__} data: {data["data"]!r}') + # verbose + # 0 : show number of changes + # 1 : also show changes + # 2 : also show secrets + # 3 : also show input data + # 4 : also show sql queries + + if quiet: + verbose = -1 + + counter = Counter() + dumper = {} def format_errors(store, path=None): + + res = [] if path is None: path = [] for key in sorted(store): location = path + [str(key)] value = store[key] if isinstance(value, dict): - format_errors(value, location) + res.extend(format_errors(value, location)) else: for message in value: - print(f'[ERROR] {".".join(location)}: {message}') + res.append((".".join(location), message)) - context = { - 'callback': log if verbose else None, + if path: + return res + + fmt = f' - {{:<{max([len(loc) for loc, msg in res])}}} : {{}}' + res = [fmt.format(loc, msg) for loc, msg in res] + num = f'error{["s",""][len(res)==1]}' + res.insert(0, f'[ValidationError] {len(res)} {num} occured during input validation') + + return '\n'.join(res) + + def format_changes(*message): + if counter: + changes = [] + last = None + for (action, what), count in sorted(counter.items()): + if action != last: + if last: + changes.append('/') + changes.append(f'{action}:') + last = action + changes.append(f'{what}({count})') + else: + changes = 'no changes.' + return chain(message, changes) + + def log(action, target, message=None): + if message is None: + message = json.dumps(dumper[target.__class__].dump(target), ensure_ascii=False) + print(f'{action} {target.__table__}: {message}') + + def listen_insert(mapper, connection, target): # pylint: disable=unused-argument + """ callback function to track import """ + counter.update([('Added', target.__table__.name)]) + if verbose >= 1: + log('Added', target) + + def listen_update(mapper, connection, target): # pylint: disable=unused-argument + """ callback function to track import """ + + changed = {} + inspection = sqlalchemy.inspect(target) + for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs: + if getattr(inspection.attrs, attr.key).history.has_changes(): + if sqlalchemy.orm.attributes.get_history(target, attr.key)[2]: + before = sqlalchemy.orm.attributes.get_history(target, attr.key)[2].pop() + after = getattr(target, attr.key) + # only remember changed keys + if before != after and (before or after): + if verbose >= 1: + changed[str(attr.key)] = (before, after) + else: + break + + if verbose >= 1: + # use schema with dump_context to hide secrets and sort keys + primary = json.dumps(str(target), ensure_ascii=False) + dumped = get_schema(target)(only=changed.keys(), context=dump_context).dump(target) + for key, value in dumped.items(): + before, after = changed[key] + if value == '': + before = '' if before else before + after = '' if after else after + else: + # TODO: use schema to "convert" before value? + after = value + before = json.dumps(before, ensure_ascii=False) + after = json.dumps(after, ensure_ascii=False) + log('Modified', target, f'{primary} {key}: {before} -> {after}') + + if changed: + counter.update([('Modified', target.__table__.name)]) + + def listen_delete(mapper, connection, target): # pylint: disable=unused-argument + """ callback function to track import """ + counter.update([('Deleted', target.__table__.name)]) + if verbose >= 1: + log('Deleted', target) + + # this listener should not be necessary, when: + # dkim keys should be stored in database and it should be possible to store multiple + # keys per domain. the active key would be also stored on disk on commit. + def listen_dkim(session, flush_context): # pylint: disable=unused-argument + """ callback function to track import """ + for target in session.identity_map.values(): + if not isinstance(target, models.Domain): + continue + primary = json.dumps(str(target), ensure_ascii=False) + before = target._dkim_key_on_disk + after = target._dkim_key + if before != after and (before or after): + if verbose >= 2: + before = before.decode('ascii', 'ignore') + after = after.decode('ascii', 'ignore') + else: + before = '' if before else '' + after = '' if after else '' + before = json.dumps(before, ensure_ascii=False) + after = json.dumps(after, ensure_ascii=False) + log('Modified', target, f'{primary} dkim_key: {before} -> {after}') + counter.update([('Modified', target.__table__.name)]) + + def track_serialize(self, item): + """ callback function to track import """ + log('Handling', self.opts.model, item) + + # configure contexts + dump_context = { + 'secrets': verbose >= 2, + } + load_context = { + 'callback': track_serialize if verbose >= 3 else None, + 'clear': not update, 'import': True, } - error = False + # register listeners + for schema in get_schema(): + model = schema.Meta.model + dumper[model] = schema(context=dump_context) + sqlalchemy.event.listen(model, 'after_insert', listen_insert) + sqlalchemy.event.listen(model, 'after_update', listen_update) + sqlalchemy.event.listen(model, 'after_delete', listen_delete) + + # special listener for dkim_key changes + sqlalchemy.event.listen(db.session, 'after_flush', listen_dkim) + + if verbose >= 4: + logging.basicConfig() + logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + try: - config = MailuSchema(context=context).loads(source) + with models.db.session.no_autoflush: + config = MailuSchema(only=SECTIONS, context=load_context).loads(source) except ValidationError as exc: - error = True - format_errors(exc.messages) - else: - print(config) + raise click.ClickException(format_errors(exc.messages)) from exc + except Exception as exc: + # (yaml.scanner.ScannerError, UnicodeDecodeError, ...) + raise click.ClickException(f'[{exc.__class__.__name__}] {" ".join(str(exc).split())}') from exc + + # flush session to show/count all changes + if dry_run or verbose >= 1: + db.session.flush() + + # check for duplicate domain names + dup = set() + for fqdn in chain(db.session.query(models.Domain.name), + db.session.query(models.Alternative.name), + db.session.query(models.Relay.name)): + if fqdn in dup: + raise click.ClickException(f'[ValidationError] Duplicate domain name: {fqdn}') + dup.add(fqdn) + + # TODO: implement special update "items" + # -pkey: which - remove item "which" + # -key: null or [] or {} - set key to default + # -pkey: null or [] or {} - remove all existing items in this list + + # don't commit when running dry + if dry_run: + db.session.rollback() + if not quiet: + print(*format_changes('Dry run. Not commiting changes.')) + # TODO: remove debug print(MailuSchema().dumps(config)) - # TODO: need to delete other entries - - # TODO: enable commit - error = True - - # don't commit when running dry or validation errors occured - if error: - print('An error occured. Not committing changes.') - db.session.rollback() - sys.exit(2) - elif dry_run: - print('Dry run. Not commiting changes.') - db.session.rollback() else: db.session.commit() + if not quiet: + print(*format_changes('Commited changes.')) @mailu.command() @@ -385,28 +542,35 @@ def config_import(verbose=False, dry_run=False, source=None): @click.option('-d', '--dns', is_flag=True, help='Include dns records') @click.option('-o', '--output-file', 'output', default=sys.stdout, type=click.File(mode='w'), help='save yaml to file') +@click.option('-j', '--json', 'as_json', is_flag=True, help='Dump in josn format') @click.argument('sections', nargs=-1) @with_appcontext -def config_export(full=False, secrets=False, dns=False, output=None, sections=None): - """ Export configuration as YAML to stdout or file +def config_export(full=False, secrets=False, dns=False, output=None, as_json=False, sections=None): + """ Export configuration as YAML or JSON to stdout or file """ if sections: for section in sections: if section not in SECTIONS: - print(f'[ERROR] Unknown section: {section!r}') - sys.exit(1) + print(f'[ERROR] Unknown section: {section}') + raise click.exceptions.Exit(1) sections = set(sections) else: sections = SECTIONS - context={ + context = { 'full': full, 'secrets': secrets, 'dns': dns, } - MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output) + if as_json: + schema = MailuSchema(only=sections, context=context) + schema.opts.render_module = json + print(schema.dumps(models.MailuConfig(), separators=(',',':')), file=output) + + else: + MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output) @mailu.command() diff --git a/core/admin/mailu/models.py b/core/admin/mailu/models.py index 3187f597..dac1dc70 100644 --- a/core/admin/mailu/models.py +++ b/core/admin/mailu/models.py @@ -12,7 +12,8 @@ from itertools import chain import flask_sqlalchemy import sqlalchemy -import passlib +import passlib.context +import passlib.hash import idna import dns @@ -79,11 +80,11 @@ class CommaSeparatedList(db.TypeDecorator): for item in value: if ',' in item: raise ValueError('list item must not contain ","') - return ','.join(sorted(value)) + return ','.join(sorted(set(value))) def process_result_value(self, value, dialect): """ split comma separated string to list """ - return list(filter(bool, value.split(','))) if value else [] + return list(filter(bool, [item.strip() for item in value.split(',')])) if value else [] python_type = list @@ -136,19 +137,11 @@ class Config(Base): value = db.Column(JSONEncoded) -# TODO: use sqlalchemy.event.listen() on a store method of object? -@sqlalchemy.event.listens_for(db.session, 'after_commit') -def store_dkim_key(session): - """ Store DKIM key on commit """ +def _save_dkim_keys(session): + """ store DKIM keys after commit """ for obj in session.identity_map.values(): if isinstance(obj, Domain): - if obj._dkim_key_changed: - file_path = obj._dkim_file() - if obj._dkim_key: - with open(file_path, 'wb') as handle: - handle.write(obj._dkim_key) - elif os.path.exists(file_path): - os.unlink(file_path) + obj.save_dkim_key() class Domain(Base): """ A DNS domain that has mail addresses associated to it. @@ -165,7 +158,7 @@ class Domain(Base): signup_enabled = db.Column(db.Boolean, nullable=False, default=False) _dkim_key = None - _dkim_key_changed = False + _dkim_key_on_disk = None def _dkim_file(self): """ return filename for active DKIM key """ @@ -174,6 +167,17 @@ class Domain(Base): selector=app.config['DKIM_SELECTOR'] ) + def save_dkim_key(self): + """ save changed DKIM key to disk """ + if self._dkim_key != self._dkim_key_on_disk: + file_path = self._dkim_file() + if self._dkim_key: + with open(file_path, 'wb') as handle: + handle.write(self._dkim_key) + elif os.path.exists(file_path): + os.unlink(file_path) + self._dkim_key_on_disk = self._dkim_key + @property def dns_mx(self): """ return MX record for domain """ @@ -189,7 +193,7 @@ class Domain(Base): @property def dns_dkim(self): """ return DKIM record for domain """ - if os.path.exists(self._dkim_file()): + if self.dkim_key: selector = app.config['DKIM_SELECTOR'] return ( f'{selector}._domainkey.{self.name}. 600 IN TXT' @@ -199,7 +203,7 @@ class Domain(Base): @property def dns_dmarc(self): """ return DMARC record for domain """ - if os.path.exists(self._dkim_file()): + if self.dkim_key: domain = app.config['DOMAIN'] rua = app.config['DMARC_RUA'] rua = f' rua=mailto:{rua}@{domain};' if rua else '' @@ -214,19 +218,19 @@ class Domain(Base): file_path = self._dkim_file() if os.path.exists(file_path): with open(file_path, 'rb') as handle: - self._dkim_key = handle.read() + self._dkim_key = self._dkim_key_on_disk = handle.read() else: - self._dkim_key = b'' + self._dkim_key = self._dkim_key_on_disk = b'' return self._dkim_key if self._dkim_key else None @dkim_key.setter def dkim_key(self, value): """ set private DKIM key """ old_key = self.dkim_key - if value is None: - value = b'' - self._dkim_key_changed = value != old_key - self._dkim_key = value + self._dkim_key = value if value is not None else b'' + if self._dkim_key != old_key: + if not sqlalchemy.event.contains(db.session, 'after_commit', _save_dkim_keys): + sqlalchemy.event.listen(db.session, 'after_commit', _save_dkim_keys) @property def dkim_publickey(self): @@ -331,14 +335,14 @@ class Email(object): def sendmail(self, subject, body): """ send an email to the address """ - from_address = f'{app.config["POSTMASTER"]}@{idna.encode(app.config["DOMAIN"]).decode("ascii")}' + f_addr = f'{app.config["POSTMASTER"]}@{idna.encode(app.config["DOMAIN"]).decode("ascii")}' with smtplib.SMTP(app.config['HOST_AUTHSMTP'], port=10025) as smtp: to_address = f'{self.localpart}@{idna.encode(self.domain_name).decode("ascii")}' msg = text.MIMEText(body) msg['Subject'] = subject - msg['From'] = from_address + msg['From'] = f_addr msg['To'] = to_address - smtp.sendmail(from_address, [to_address], msg.as_string()) + smtp.sendmail(f_addr, [to_address], msg.as_string()) @classmethod def resolve_domain(cls, email): @@ -589,7 +593,6 @@ class Alias(Base, Email): return None -# TODO: where are Tokens used / validated? # TODO: what about API tokens? class Token(Base): """ A token is an application password for a given user. @@ -650,20 +653,22 @@ class MailuConfig: and loading """ - # TODO: add sqlalchemy session updating (.add & .del) class MailuCollection: - """ Provides dict- and list-like access to all instances + """ Provides dict- and list-like access to instances of a sqlalchemy model """ def __init__(self, model : db.Model): - self._model = model + self.model = model + + def __str__(self): + return f'<{self.model.__name__}-Collection>' @cached_property def _items(self): return { inspect(item).identity: item - for item in self._model.query.all() + for item in self.model.query.all() } def __len__(self): @@ -676,8 +681,8 @@ class MailuConfig: return self._items[key] def __setitem__(self, key, item): - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') if key != inspect(item).identity: raise ValueError(f'item identity != key {key!r}') self._items[key] = item @@ -685,23 +690,24 @@ class MailuConfig: def __delitem__(self, key): del self._items[key] - def append(self, item): + def append(self, item, update=False): """ list-like append """ - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') key = inspect(item).identity if key in self._items: - raise ValueError(f'item {key!r} already present in collection') + if not update: + raise ValueError(f'item {key!r} already present in collection') self._items[key] = item - def extend(self, items): + def extend(self, items, update=False): """ list-like extend """ add = {} for item in items: - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') key = inspect(item).identity - if key in self._items: + if not update and key in self._items: raise ValueError(f'item {key!r} already present in collection') add[key] = item self._items.update(add) @@ -721,8 +727,8 @@ class MailuConfig: def remove(self, item): """ list-like remove """ - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') key = inspect(item).identity if not key in self._items: raise ValueError(f'item {key!r} not found in collection') @@ -739,12 +745,11 @@ class MailuConfig: def update(self, items): """ dict-like update """ for key, item in items: - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') if key != inspect(item).identity: raise ValueError(f'item identity != key {key!r}') - if key in self._items: - raise ValueError(f'item {key!r} already present in collection') + self._items.update(items) def setdefault(self, key, item=None): """ dict-like setdefault """ @@ -752,13 +757,86 @@ class MailuConfig: return self._items[key] if item is None: return None - if not isinstance(item, self._model): - raise TypeError(f'expected {self._model.name}') + if not isinstance(item, self.model): + raise TypeError(f'expected {self.model.name}') if key != inspect(item).identity: raise ValueError(f'item identity != key {key!r}') self._items[key] = item return item + def __init__(self): + + # section-name -> attr + self._sections = { + name: getattr(self, name) + for name in dir(self) + if isinstance(getattr(self, name), self.MailuCollection) + } + + # known models + self._models = tuple(section.model for section in self._sections.values()) + + # model -> attr + self._sections.update({ + section.model: section for section in self._sections.values() + }) + + def _get_model(self, section): + if section is None: + return None + model = self._sections.get(section) + if model is None: + raise ValueError(f'Invalid section: {section!r}') + if isinstance(model, self.MailuCollection): + return model.model + return model + + def _add(self, items, section, update): + + model = self._get_model(section) + if isinstance(items, self._models): + items = [items] + elif not hasattr(items, '__iter__'): + raise ValueError(f'{items!r} is not iterable') + + for item in items: + if model is not None and not isinstance(item, model): + what = item.__class__.__name__.capitalize() + raise ValueError(f'{what} can not be added to section {section!r}') + self._sections[type(item)].append(item, update=update) + + def add(self, items, section=None): + """ add item to config """ + self._add(items, section, update=False) + + def update(self, items, section=None): + """ add or replace item in config """ + self._add(items, section, update=True) + + def remove(self, items, section=None): + """ remove item from config """ + model = self._get_model(section) + if isinstance(items, self._models): + items = [items] + elif not hasattr(items, '__iter__'): + raise ValueError(f'{items!r} is not iterable') + + for item in items: + if isinstance(item, str): + if section is None: + raise ValueError(f'Cannot remove key {item!r} without section') + del self._sections[model][item] + elif model is not None and not isinstance(item, model): + what = item.__class__.__name__.capitalize() + raise ValueError(f'{what} can not be removed from section {section!r}') + self._sections[type(item)].remove(item,) + + def clear(self, models=None): + """ remove complete configuration """ + for model in self._models: + if models is None or model in models: + db.session.query(model).delete() + domains = MailuCollection(Domain) relays = MailuCollection(Relay) users = MailuCollection(User) diff --git a/core/admin/mailu/schemas.py b/core/admin/mailu/schemas.py index 5dc10e17..04512f6d 100644 --- a/core/admin/mailu/schemas.py +++ b/core/admin/mailu/schemas.py @@ -24,6 +24,23 @@ ma = Marshmallow() # - fields which are the primary key => unchangeable when updating +### map model to schema ### + +_model2schema = {} + +def get_schema(model=None): + """ return schema class for model or instance of model """ + if model is None: + return _model2schema.values() + else: + return _model2schema.get(model) or _model2schema.get(model.__class__) + +def mapped(cls): + """ register schema in model2schema map """ + _model2schema[cls.Meta.model] = cls + return cls + + ### yaml render module ### # allow yaml module to dump OrderedDict @@ -79,26 +96,6 @@ class RenderYAML: return yaml.dump(*args, **kwargs) -### functions ### - -def handle_email(data): - """ merge separate localpart and domain to email - """ - - localpart = 'localpart' in data - domain = 'domain' in data - - if 'email' in data: - if localpart or domain: - raise ValidationError('duplicate email and localpart/domain') - elif localpart and domain: - data['email'] = f'{data["localpart"]}@{data["domain"]}' - elif localpart or domain: - raise ValidationError('incomplete localpart/domain') - - return data - - ### field definitions ### class LazyStringField(fields.String): @@ -177,9 +174,7 @@ class DkimKeyField(fields.String): return dkim.gen_key() # remember some keydata for error message - keydata = value - if len(keydata) > 40: - keydata = keydata[:25] + '...' + keydata[-10:] + keydata = f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value # wrap value into valid pem layout and check validity value = ( @@ -197,6 +192,26 @@ class DkimKeyField(fields.String): ### base definitions ### +def handle_email(data): + """ merge separate localpart and domain to email + """ + + localpart = 'localpart' in data + domain = 'domain' in data + + if 'email' in data: + if localpart or domain: + raise ValidationError('duplicate email and localpart/domain') + data['localpart'], data['domain_name'] = data['email'].rsplit('@', 1) + elif localpart and domain: + data['domain_name'] = data['domain'] + del data['domain'] + data['email'] = f'{data["localpart"]}@{data["domain_name"]}' + elif localpart or domain: + raise ValidationError('incomplete localpart/domain') + + return data + class BaseOpts(SQLAlchemyAutoSchemaOpts): """ Option class with sqla session """ @@ -238,12 +253,15 @@ class BaseSchema(ma.SQLAlchemyAutoSchema): # update excludes kwargs['exclude'] = exclude + # init SQLAlchemyAutoSchema + super().__init__(*args, **kwargs) + # exclude_by_value self._exclude_by_value = getattr(self.Meta, 'exclude_by_value', {}) # exclude default values if not context.get('full'): - for column in getattr(self.Meta, 'model').__table__.columns: + for column in getattr(self.opts, 'model').__table__.columns: if column.name not in exclude: self._exclude_by_value.setdefault(column.name, []).append( None if column.default is None else column.default.arg @@ -256,10 +274,7 @@ class BaseSchema(ma.SQLAlchemyAutoSchema): if not flags & set(need): self._hide_by_context |= set(what) - # init SQLAlchemyAutoSchema - super().__init__(*args, **kwargs) - - # init order + # initialize attribute order if hasattr(self.Meta, 'order'): # use user-defined order self._order = list(reversed(getattr(self.Meta, 'order'))) @@ -267,17 +282,35 @@ class BaseSchema(ma.SQLAlchemyAutoSchema): # default order is: primary_key + other keys alphabetically self._order = list(sorted(self.fields.keys())) primary = self.opts.model.__table__.primary_key.columns.values()[0].name - self._order.remove(primary) - self._order.reverse() - self._order.append(primary) + if primary in self._order: + self._order.remove(primary) + self._order.reverse() + self._order.append(primary) + + # move pre_load hook "_track_import" to the front + hooks = self._hooks[('pre_load', False)] + if '_track_import' in hooks: + hooks.remove('_track_import') + hooks.insert(0, '_track_import') + # and post_load hook "_fooo" to the end + hooks = self._hooks[('post_load', False)] + if '_add_instance' in hooks: + hooks.remove('_add_instance') + hooks.append('_add_instance') @pre_load def _track_import(self, data, many, **kwargs): # pylint: disable=unused-argument - call = self.context.get('callback') - if call is not None: - call(self=self, data=data, many=many, **kwargs) +# TODO: also handle reset, prune and delete in pre_load / post_load hooks! +# print('!!!', repr(data)) + if callback := self.context.get('callback'): + callback(self, data) return data + @post_load + def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument + self.opts.sqla_session.add(item) + return item + @post_dump def _hide_and_order(self, data, many, **kwargs): # pylint: disable=unused-argument @@ -306,6 +339,7 @@ class BaseSchema(ma.SQLAlchemyAutoSchema): ### schema definitions ### +@mapped class DomainSchema(BaseSchema): """ Marshmallow schema for Domain model """ class Meta: @@ -339,6 +373,7 @@ class DomainSchema(BaseSchema): dns_dmarc = fields.String(dump_only=True) +@mapped class TokenSchema(BaseSchema): """ Marshmallow schema for Token model """ class Meta: @@ -347,6 +382,7 @@ class TokenSchema(BaseSchema): load_instance = True +@mapped class FetchSchema(BaseSchema): """ Marshmallow schema for Fetch model """ class Meta: @@ -361,6 +397,7 @@ class FetchSchema(BaseSchema): } +@mapped class UserSchema(BaseSchema): """ Marshmallow schema for User model """ class Meta: @@ -368,7 +405,7 @@ class UserSchema(BaseSchema): model = models.User load_instance = True include_relationships = True - exclude = ['localpart', 'domain', 'quota_bytes_used'] + exclude = ['domain', 'quota_bytes_used'] exclude_by_value = { 'forward_destination': [[]], @@ -395,7 +432,7 @@ class UserSchema(BaseSchema): raise ValidationError(f'invalid hashed password {password!r}') elif 'password_hash' in data and 'hash_scheme' in data: if data['hash_scheme'] not in self.Meta.model.scheme_dict: - raise ValidationError(f'invalid password scheme {scheme!r}') + raise ValidationError(f'invalid password scheme {data["hash_scheme"]!r}') data['password'] = f'{{{data["hash_scheme"]}}}{data["password_hash"]}' del data['hash_scheme'] del data['password_hash'] @@ -409,17 +446,20 @@ class UserSchema(BaseSchema): # ctx.verify('', hashed) # =>? ValueError: hash could not be identified + localpart = fields.Str(load_only=True) + domain_name = fields.Str(load_only=True) tokens = fields.Nested(TokenSchema, many=True) fetches = fields.Nested(FetchSchema, many=True) +@mapped class AliasSchema(BaseSchema): """ Marshmallow schema for Alias model """ class Meta: """ Schema config """ model = models.Alias load_instance = True - exclude = ['localpart'] + exclude = ['domain'] exclude_by_value = { 'destination': [[]], @@ -429,9 +469,12 @@ class AliasSchema(BaseSchema): def _handle_email(self, data, many, **kwargs): # pylint: disable=unused-argument return handle_email(data) + localpart = fields.Str(load_only=True) + domain_name = fields.Str(load_only=True) destination = CommaSeparatedListField() +@mapped class ConfigSchema(BaseSchema): """ Marshmallow schema for Config model """ class Meta: @@ -440,6 +483,7 @@ class ConfigSchema(BaseSchema): load_instance = True +@mapped class RelaySchema(BaseSchema): """ Marshmallow schema for Relay model """ class Meta: @@ -453,18 +497,43 @@ class MailuSchema(Schema): class Meta: """ Schema config """ render_module = RenderYAML + ordered = True order = ['config', 'domains', 'users', 'aliases', 'relays'] - @post_dump(pass_many=True) - def _order(self, data : OrderedDict, many : bool, **kwargs): # pylint: disable=unused-argument - for key in reversed(self.Meta.order): - try: - data.move_to_end(key, False) - except KeyError: - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # order fields + for field_list in self.load_fields, self.dump_fields, self.fields: + for section in reversed(self.Meta.order): + try: + field_list.move_to_end(section, False) + except KeyError: + pass + + @pre_load + def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument + """ create config object in context if missing + and clear it if requested + """ + if 'config' not in self.context: + self.context['config'] = models.MailuConfig() + if self.context.get('clear'): + self.context['config'].clear( + models = {field.nested.opts.model for field in self.fields.values()} + ) return data + @post_load + def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument + """ update and return config object """ + config = self.context['config'] + for section in self.Meta.order: + if section in data: + config.update(data[section], section) + + return config + config = fields.Nested(ConfigSchema, many=True) domains = fields.Nested(DomainSchema, many=True) users = fields.Nested(UserSchema, many=True)