next step for import/export yaml & json

master
Alexander Graf 4 years ago
parent 65b1ad46d9
commit 902b398127

@ -4,9 +4,15 @@
import sys import sys
import os import os
import socket import socket
import json
import logging
import uuid import uuid
from collections import Counter
from itertools import chain
import click import click
import sqlalchemy
import yaml import yaml
from flask import current_app as app from flask import current_app as app
@ -14,7 +20,7 @@ from flask.cli import FlaskGroup, with_appcontext
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from . import models from . import models
from .schemas import MailuSchema from .schemas import MailuSchema, get_schema
db = models.db db = models.db
@ -322,60 +328,211 @@ SECTIONS = {'domains', 'relays', 'users', 'aliases'}
@mailu.command() @mailu.command()
@click.option('-v', '--verbose', is_flag=True, help='Increase verbosity') @click.option('-v', '--verbose', count=True, help='Increase verbosity')
@click.option('-q', '--quiet', is_flag=True, help='Quiet mode - only show errors')
@click.option('-u', '--update', is_flag=True, help='Update mode - merge input with existing config')
@click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made') @click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made')
@click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin) @click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin)
@with_appcontext @with_appcontext
def config_import(verbose=False, dry_run=False, source=None): def config_import(verbose=0, quiet=False, update=False, dry_run=False, source=None):
""" Import configuration from YAML """ Import configuration as YAML or JSON from stdin or file
""" """
def log(**data): # verbose
caller = sys._getframe(1).f_code.co_name # pylint: disable=protected-access # 0 : show number of changes
if caller == '_track_import': # 1 : also show changes
print(f'Handling {data["self"].opts.model.__table__} data: {data["data"]!r}') # 2 : also show secrets
# 3 : also show input data
# 4 : also show sql queries
if quiet:
verbose = -1
counter = Counter()
dumper = {}
def format_errors(store, path=None): def format_errors(store, path=None):
res = []
if path is None: if path is None:
path = [] path = []
for key in sorted(store): for key in sorted(store):
location = path + [str(key)] location = path + [str(key)]
value = store[key] value = store[key]
if isinstance(value, dict): if isinstance(value, dict):
format_errors(value, location) res.extend(format_errors(value, location))
else: else:
for message in value: for message in value:
print(f'[ERROR] {".".join(location)}: {message}') res.append((".".join(location), message))
context = { if path:
'callback': log if verbose else None, return res
fmt = f' - {{:<{max([len(loc) for loc, msg in res])}}} : {{}}'
res = [fmt.format(loc, msg) for loc, msg in res]
num = f'error{["s",""][len(res)==1]}'
res.insert(0, f'[ValidationError] {len(res)} {num} occured during input validation')
return '\n'.join(res)
def format_changes(*message):
if counter:
changes = []
last = None
for (action, what), count in sorted(counter.items()):
if action != last:
if last:
changes.append('/')
changes.append(f'{action}:')
last = action
changes.append(f'{what}({count})')
else:
changes = 'no changes.'
return chain(message, changes)
def log(action, target, message=None):
if message is None:
message = json.dumps(dumper[target.__class__].dump(target), ensure_ascii=False)
print(f'{action} {target.__table__}: {message}')
def listen_insert(mapper, connection, target): # pylint: disable=unused-argument
""" callback function to track import """
counter.update([('Added', target.__table__.name)])
if verbose >= 1:
log('Added', target)
def listen_update(mapper, connection, target): # pylint: disable=unused-argument
""" callback function to track import """
changed = {}
inspection = sqlalchemy.inspect(target)
for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
if getattr(inspection.attrs, attr.key).history.has_changes():
if sqlalchemy.orm.attributes.get_history(target, attr.key)[2]:
before = sqlalchemy.orm.attributes.get_history(target, attr.key)[2].pop()
after = getattr(target, attr.key)
# only remember changed keys
if before != after and (before or after):
if verbose >= 1:
changed[str(attr.key)] = (before, after)
else:
break
if verbose >= 1:
# use schema with dump_context to hide secrets and sort keys
primary = json.dumps(str(target), ensure_ascii=False)
dumped = get_schema(target)(only=changed.keys(), context=dump_context).dump(target)
for key, value in dumped.items():
before, after = changed[key]
if value == '<hidden>':
before = '<hidden>' if before else before
after = '<hidden>' if after else after
else:
# TODO: use schema to "convert" before value?
after = value
before = json.dumps(before, ensure_ascii=False)
after = json.dumps(after, ensure_ascii=False)
log('Modified', target, f'{primary} {key}: {before} -> {after}')
if changed:
counter.update([('Modified', target.__table__.name)])
def listen_delete(mapper, connection, target): # pylint: disable=unused-argument
""" callback function to track import """
counter.update([('Deleted', target.__table__.name)])
if verbose >= 1:
log('Deleted', target)
# this listener should not be necessary, when:
# dkim keys should be stored in database and it should be possible to store multiple
# keys per domain. the active key would be also stored on disk on commit.
def listen_dkim(session, flush_context): # pylint: disable=unused-argument
""" callback function to track import """
for target in session.identity_map.values():
if not isinstance(target, models.Domain):
continue
primary = json.dumps(str(target), ensure_ascii=False)
before = target._dkim_key_on_disk
after = target._dkim_key
if before != after and (before or after):
if verbose >= 2:
before = before.decode('ascii', 'ignore')
after = after.decode('ascii', 'ignore')
else:
before = '<hidden>' if before else ''
after = '<hidden>' if after else ''
before = json.dumps(before, ensure_ascii=False)
after = json.dumps(after, ensure_ascii=False)
log('Modified', target, f'{primary} dkim_key: {before} -> {after}')
counter.update([('Modified', target.__table__.name)])
def track_serialize(self, item):
""" callback function to track import """
log('Handling', self.opts.model, item)
# configure contexts
dump_context = {
'secrets': verbose >= 2,
}
load_context = {
'callback': track_serialize if verbose >= 3 else None,
'clear': not update,
'import': True, 'import': True,
} }
error = False # register listeners
for schema in get_schema():
model = schema.Meta.model
dumper[model] = schema(context=dump_context)
sqlalchemy.event.listen(model, 'after_insert', listen_insert)
sqlalchemy.event.listen(model, 'after_update', listen_update)
sqlalchemy.event.listen(model, 'after_delete', listen_delete)
# special listener for dkim_key changes
sqlalchemy.event.listen(db.session, 'after_flush', listen_dkim)
if verbose >= 4:
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
try: try:
config = MailuSchema(context=context).loads(source) with models.db.session.no_autoflush:
config = MailuSchema(only=SECTIONS, context=load_context).loads(source)
except ValidationError as exc: except ValidationError as exc:
error = True raise click.ClickException(format_errors(exc.messages)) from exc
format_errors(exc.messages) except Exception as exc:
else: # (yaml.scanner.ScannerError, UnicodeDecodeError, ...)
print(config) raise click.ClickException(f'[{exc.__class__.__name__}] {" ".join(str(exc).split())}') from exc
# flush session to show/count all changes
if dry_run or verbose >= 1:
db.session.flush()
# check for duplicate domain names
dup = set()
for fqdn in chain(db.session.query(models.Domain.name),
db.session.query(models.Alternative.name),
db.session.query(models.Relay.name)):
if fqdn in dup:
raise click.ClickException(f'[ValidationError] Duplicate domain name: {fqdn}')
dup.add(fqdn)
# TODO: implement special update "items"
# -pkey: which - remove item "which"
# -key: null or [] or {} - set key to default
# -pkey: null or [] or {} - remove all existing items in this list
# don't commit when running dry
if dry_run:
db.session.rollback()
if not quiet:
print(*format_changes('Dry run. Not commiting changes.'))
# TODO: remove debug
print(MailuSchema().dumps(config)) print(MailuSchema().dumps(config))
# TODO: need to delete other entries
# TODO: enable commit
error = True
# don't commit when running dry or validation errors occured
if error:
print('An error occured. Not committing changes.')
db.session.rollback()
sys.exit(2)
elif dry_run:
print('Dry run. Not commiting changes.')
db.session.rollback()
else: else:
db.session.commit() db.session.commit()
if not quiet:
print(*format_changes('Commited changes.'))
@mailu.command() @mailu.command()
@ -385,27 +542,34 @@ def config_import(verbose=False, dry_run=False, source=None):
@click.option('-d', '--dns', is_flag=True, help='Include dns records') @click.option('-d', '--dns', is_flag=True, help='Include dns records')
@click.option('-o', '--output-file', 'output', default=sys.stdout, type=click.File(mode='w'), @click.option('-o', '--output-file', 'output', default=sys.stdout, type=click.File(mode='w'),
help='save yaml to file') help='save yaml to file')
@click.option('-j', '--json', 'as_json', is_flag=True, help='Dump in josn format')
@click.argument('sections', nargs=-1) @click.argument('sections', nargs=-1)
@with_appcontext @with_appcontext
def config_export(full=False, secrets=False, dns=False, output=None, sections=None): def config_export(full=False, secrets=False, dns=False, output=None, as_json=False, sections=None):
""" Export configuration as YAML to stdout or file """ Export configuration as YAML or JSON to stdout or file
""" """
if sections: if sections:
for section in sections: for section in sections:
if section not in SECTIONS: if section not in SECTIONS:
print(f'[ERROR] Unknown section: {section!r}') print(f'[ERROR] Unknown section: {section}')
sys.exit(1) raise click.exceptions.Exit(1)
sections = set(sections) sections = set(sections)
else: else:
sections = SECTIONS sections = SECTIONS
context={ context = {
'full': full, 'full': full,
'secrets': secrets, 'secrets': secrets,
'dns': dns, 'dns': dns,
} }
if as_json:
schema = MailuSchema(only=sections, context=context)
schema.opts.render_module = json
print(schema.dumps(models.MailuConfig(), separators=(',',':')), file=output)
else:
MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output) MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output)

@ -12,7 +12,8 @@ from itertools import chain
import flask_sqlalchemy import flask_sqlalchemy
import sqlalchemy import sqlalchemy
import passlib import passlib.context
import passlib.hash
import idna import idna
import dns import dns
@ -79,11 +80,11 @@ class CommaSeparatedList(db.TypeDecorator):
for item in value: for item in value:
if ',' in item: if ',' in item:
raise ValueError('list item must not contain ","') raise ValueError('list item must not contain ","')
return ','.join(sorted(value)) return ','.join(sorted(set(value)))
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
""" split comma separated string to list """ """ split comma separated string to list """
return list(filter(bool, value.split(','))) if value else [] return list(filter(bool, [item.strip() for item in value.split(',')])) if value else []
python_type = list python_type = list
@ -136,19 +137,11 @@ class Config(Base):
value = db.Column(JSONEncoded) value = db.Column(JSONEncoded)
# TODO: use sqlalchemy.event.listen() on a store method of object? def _save_dkim_keys(session):
@sqlalchemy.event.listens_for(db.session, 'after_commit') """ store DKIM keys after commit """
def store_dkim_key(session):
""" Store DKIM key on commit """
for obj in session.identity_map.values(): for obj in session.identity_map.values():
if isinstance(obj, Domain): if isinstance(obj, Domain):
if obj._dkim_key_changed: obj.save_dkim_key()
file_path = obj._dkim_file()
if obj._dkim_key:
with open(file_path, 'wb') as handle:
handle.write(obj._dkim_key)
elif os.path.exists(file_path):
os.unlink(file_path)
class Domain(Base): class Domain(Base):
""" A DNS domain that has mail addresses associated to it. """ A DNS domain that has mail addresses associated to it.
@ -165,7 +158,7 @@ class Domain(Base):
signup_enabled = db.Column(db.Boolean, nullable=False, default=False) signup_enabled = db.Column(db.Boolean, nullable=False, default=False)
_dkim_key = None _dkim_key = None
_dkim_key_changed = False _dkim_key_on_disk = None
def _dkim_file(self): def _dkim_file(self):
""" return filename for active DKIM key """ """ return filename for active DKIM key """
@ -174,6 +167,17 @@ class Domain(Base):
selector=app.config['DKIM_SELECTOR'] selector=app.config['DKIM_SELECTOR']
) )
def save_dkim_key(self):
""" save changed DKIM key to disk """
if self._dkim_key != self._dkim_key_on_disk:
file_path = self._dkim_file()
if self._dkim_key:
with open(file_path, 'wb') as handle:
handle.write(self._dkim_key)
elif os.path.exists(file_path):
os.unlink(file_path)
self._dkim_key_on_disk = self._dkim_key
@property @property
def dns_mx(self): def dns_mx(self):
""" return MX record for domain """ """ return MX record for domain """
@ -189,7 +193,7 @@ class Domain(Base):
@property @property
def dns_dkim(self): def dns_dkim(self):
""" return DKIM record for domain """ """ return DKIM record for domain """
if os.path.exists(self._dkim_file()): if self.dkim_key:
selector = app.config['DKIM_SELECTOR'] selector = app.config['DKIM_SELECTOR']
return ( return (
f'{selector}._domainkey.{self.name}. 600 IN TXT' f'{selector}._domainkey.{self.name}. 600 IN TXT'
@ -199,7 +203,7 @@ class Domain(Base):
@property @property
def dns_dmarc(self): def dns_dmarc(self):
""" return DMARC record for domain """ """ return DMARC record for domain """
if os.path.exists(self._dkim_file()): if self.dkim_key:
domain = app.config['DOMAIN'] domain = app.config['DOMAIN']
rua = app.config['DMARC_RUA'] rua = app.config['DMARC_RUA']
rua = f' rua=mailto:{rua}@{domain};' if rua else '' rua = f' rua=mailto:{rua}@{domain};' if rua else ''
@ -214,19 +218,19 @@ class Domain(Base):
file_path = self._dkim_file() file_path = self._dkim_file()
if os.path.exists(file_path): if os.path.exists(file_path):
with open(file_path, 'rb') as handle: with open(file_path, 'rb') as handle:
self._dkim_key = handle.read() self._dkim_key = self._dkim_key_on_disk = handle.read()
else: else:
self._dkim_key = b'' self._dkim_key = self._dkim_key_on_disk = b''
return self._dkim_key if self._dkim_key else None return self._dkim_key if self._dkim_key else None
@dkim_key.setter @dkim_key.setter
def dkim_key(self, value): def dkim_key(self, value):
""" set private DKIM key """ """ set private DKIM key """
old_key = self.dkim_key old_key = self.dkim_key
if value is None: self._dkim_key = value if value is not None else b''
value = b'' if self._dkim_key != old_key:
self._dkim_key_changed = value != old_key if not sqlalchemy.event.contains(db.session, 'after_commit', _save_dkim_keys):
self._dkim_key = value sqlalchemy.event.listen(db.session, 'after_commit', _save_dkim_keys)
@property @property
def dkim_publickey(self): def dkim_publickey(self):
@ -331,14 +335,14 @@ class Email(object):
def sendmail(self, subject, body): def sendmail(self, subject, body):
""" send an email to the address """ """ send an email to the address """
from_address = f'{app.config["POSTMASTER"]}@{idna.encode(app.config["DOMAIN"]).decode("ascii")}' f_addr = f'{app.config["POSTMASTER"]}@{idna.encode(app.config["DOMAIN"]).decode("ascii")}'
with smtplib.SMTP(app.config['HOST_AUTHSMTP'], port=10025) as smtp: with smtplib.SMTP(app.config['HOST_AUTHSMTP'], port=10025) as smtp:
to_address = f'{self.localpart}@{idna.encode(self.domain_name).decode("ascii")}' to_address = f'{self.localpart}@{idna.encode(self.domain_name).decode("ascii")}'
msg = text.MIMEText(body) msg = text.MIMEText(body)
msg['Subject'] = subject msg['Subject'] = subject
msg['From'] = from_address msg['From'] = f_addr
msg['To'] = to_address msg['To'] = to_address
smtp.sendmail(from_address, [to_address], msg.as_string()) smtp.sendmail(f_addr, [to_address], msg.as_string())
@classmethod @classmethod
def resolve_domain(cls, email): def resolve_domain(cls, email):
@ -589,7 +593,6 @@ class Alias(Base, Email):
return None return None
# TODO: where are Tokens used / validated?
# TODO: what about API tokens? # TODO: what about API tokens?
class Token(Base): class Token(Base):
""" A token is an application password for a given user. """ A token is an application password for a given user.
@ -650,20 +653,22 @@ class MailuConfig:
and loading and loading
""" """
# TODO: add sqlalchemy session updating (.add & .del)
class MailuCollection: class MailuCollection:
""" Provides dict- and list-like access to all instances """ Provides dict- and list-like access to instances
of a sqlalchemy model of a sqlalchemy model
""" """
def __init__(self, model : db.Model): def __init__(self, model : db.Model):
self._model = model self.model = model
def __str__(self):
return f'<{self.model.__name__}-Collection>'
@cached_property @cached_property
def _items(self): def _items(self):
return { return {
inspect(item).identity: item inspect(item).identity: item
for item in self._model.query.all() for item in self.model.query.all()
} }
def __len__(self): def __len__(self):
@ -676,8 +681,8 @@ class MailuConfig:
return self._items[key] return self._items[key]
def __setitem__(self, key, item): def __setitem__(self, key, item):
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
if key != inspect(item).identity: if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}') raise ValueError(f'item identity != key {key!r}')
self._items[key] = item self._items[key] = item
@ -685,23 +690,24 @@ class MailuConfig:
def __delitem__(self, key): def __delitem__(self, key):
del self._items[key] del self._items[key]
def append(self, item): def append(self, item, update=False):
""" list-like append """ """ list-like append """
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
key = inspect(item).identity key = inspect(item).identity
if key in self._items: if key in self._items:
if not update:
raise ValueError(f'item {key!r} already present in collection') raise ValueError(f'item {key!r} already present in collection')
self._items[key] = item self._items[key] = item
def extend(self, items): def extend(self, items, update=False):
""" list-like extend """ """ list-like extend """
add = {} add = {}
for item in items: for item in items:
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
key = inspect(item).identity key = inspect(item).identity
if key in self._items: if not update and key in self._items:
raise ValueError(f'item {key!r} already present in collection') raise ValueError(f'item {key!r} already present in collection')
add[key] = item add[key] = item
self._items.update(add) self._items.update(add)
@ -721,8 +727,8 @@ class MailuConfig:
def remove(self, item): def remove(self, item):
""" list-like remove """ """ list-like remove """
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
key = inspect(item).identity key = inspect(item).identity
if not key in self._items: if not key in self._items:
raise ValueError(f'item {key!r} not found in collection') raise ValueError(f'item {key!r} not found in collection')
@ -739,12 +745,11 @@ class MailuConfig:
def update(self, items): def update(self, items):
""" dict-like update """ """ dict-like update """
for key, item in items: for key, item in items:
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
if key != inspect(item).identity: if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}') raise ValueError(f'item identity != key {key!r}')
if key in self._items: self._items.update(items)
raise ValueError(f'item {key!r} already present in collection')
def setdefault(self, key, item=None): def setdefault(self, key, item=None):
""" dict-like setdefault """ """ dict-like setdefault """
@ -752,13 +757,86 @@ class MailuConfig:
return self._items[key] return self._items[key]
if item is None: if item is None:
return None return None
if not isinstance(item, self._model): if not isinstance(item, self.model):
raise TypeError(f'expected {self._model.name}') raise TypeError(f'expected {self.model.name}')
if key != inspect(item).identity: if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}') raise ValueError(f'item identity != key {key!r}')
self._items[key] = item self._items[key] = item
return item return item
def __init__(self):
# section-name -> attr
self._sections = {
name: getattr(self, name)
for name in dir(self)
if isinstance(getattr(self, name), self.MailuCollection)
}
# known models
self._models = tuple(section.model for section in self._sections.values())
# model -> attr
self._sections.update({
section.model: section for section in self._sections.values()
})
def _get_model(self, section):
if section is None:
return None
model = self._sections.get(section)
if model is None:
raise ValueError(f'Invalid section: {section!r}')
if isinstance(model, self.MailuCollection):
return model.model
return model
def _add(self, items, section, update):
model = self._get_model(section)
if isinstance(items, self._models):
items = [items]
elif not hasattr(items, '__iter__'):
raise ValueError(f'{items!r} is not iterable')
for item in items:
if model is not None and not isinstance(item, model):
what = item.__class__.__name__.capitalize()
raise ValueError(f'{what} can not be added to section {section!r}')
self._sections[type(item)].append(item, update=update)
def add(self, items, section=None):
""" add item to config """
self._add(items, section, update=False)
def update(self, items, section=None):
""" add or replace item in config """
self._add(items, section, update=True)
def remove(self, items, section=None):
""" remove item from config """
model = self._get_model(section)
if isinstance(items, self._models):
items = [items]
elif not hasattr(items, '__iter__'):
raise ValueError(f'{items!r} is not iterable')
for item in items:
if isinstance(item, str):
if section is None:
raise ValueError(f'Cannot remove key {item!r} without section')
del self._sections[model][item]
elif model is not None and not isinstance(item, model):
what = item.__class__.__name__.capitalize()
raise ValueError(f'{what} can not be removed from section {section!r}')
self._sections[type(item)].remove(item,)
def clear(self, models=None):
""" remove complete configuration """
for model in self._models:
if models is None or model in models:
db.session.query(model).delete()
domains = MailuCollection(Domain) domains = MailuCollection(Domain)
relays = MailuCollection(Relay) relays = MailuCollection(Relay)
users = MailuCollection(User) users = MailuCollection(User)

@ -24,6 +24,23 @@ ma = Marshmallow()
# - fields which are the primary key => unchangeable when updating # - fields which are the primary key => unchangeable when updating
### map model to schema ###
_model2schema = {}
def get_schema(model=None):
""" return schema class for model or instance of model """
if model is None:
return _model2schema.values()
else:
return _model2schema.get(model) or _model2schema.get(model.__class__)
def mapped(cls):
""" register schema in model2schema map """
_model2schema[cls.Meta.model] = cls
return cls
### yaml render module ### ### yaml render module ###
# allow yaml module to dump OrderedDict # allow yaml module to dump OrderedDict
@ -79,26 +96,6 @@ class RenderYAML:
return yaml.dump(*args, **kwargs) return yaml.dump(*args, **kwargs)
### functions ###
def handle_email(data):
""" merge separate localpart and domain to email
"""
localpart = 'localpart' in data
domain = 'domain' in data
if 'email' in data:
if localpart or domain:
raise ValidationError('duplicate email and localpart/domain')
elif localpart and domain:
data['email'] = f'{data["localpart"]}@{data["domain"]}'
elif localpart or domain:
raise ValidationError('incomplete localpart/domain')
return data
### field definitions ### ### field definitions ###
class LazyStringField(fields.String): class LazyStringField(fields.String):
@ -177,9 +174,7 @@ class DkimKeyField(fields.String):
return dkim.gen_key() return dkim.gen_key()
# remember some keydata for error message # remember some keydata for error message
keydata = value keydata = f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value
if len(keydata) > 40:
keydata = keydata[:25] + '...' + keydata[-10:]
# wrap value into valid pem layout and check validity # wrap value into valid pem layout and check validity
value = ( value = (
@ -197,6 +192,26 @@ class DkimKeyField(fields.String):
### base definitions ### ### base definitions ###
def handle_email(data):
""" merge separate localpart and domain to email
"""
localpart = 'localpart' in data
domain = 'domain' in data
if 'email' in data:
if localpart or domain:
raise ValidationError('duplicate email and localpart/domain')
data['localpart'], data['domain_name'] = data['email'].rsplit('@', 1)
elif localpart and domain:
data['domain_name'] = data['domain']
del data['domain']
data['email'] = f'{data["localpart"]}@{data["domain_name"]}'
elif localpart or domain:
raise ValidationError('incomplete localpart/domain')
return data
class BaseOpts(SQLAlchemyAutoSchemaOpts): class BaseOpts(SQLAlchemyAutoSchemaOpts):
""" Option class with sqla session """ Option class with sqla session
""" """
@ -238,12 +253,15 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# update excludes # update excludes
kwargs['exclude'] = exclude kwargs['exclude'] = exclude
# init SQLAlchemyAutoSchema
super().__init__(*args, **kwargs)
# exclude_by_value # exclude_by_value
self._exclude_by_value = getattr(self.Meta, 'exclude_by_value', {}) self._exclude_by_value = getattr(self.Meta, 'exclude_by_value', {})
# exclude default values # exclude default values
if not context.get('full'): if not context.get('full'):
for column in getattr(self.Meta, 'model').__table__.columns: for column in getattr(self.opts, 'model').__table__.columns:
if column.name not in exclude: if column.name not in exclude:
self._exclude_by_value.setdefault(column.name, []).append( self._exclude_by_value.setdefault(column.name, []).append(
None if column.default is None else column.default.arg None if column.default is None else column.default.arg
@ -256,10 +274,7 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
if not flags & set(need): if not flags & set(need):
self._hide_by_context |= set(what) self._hide_by_context |= set(what)
# init SQLAlchemyAutoSchema # initialize attribute order
super().__init__(*args, **kwargs)
# init order
if hasattr(self.Meta, 'order'): if hasattr(self.Meta, 'order'):
# use user-defined order # use user-defined order
self._order = list(reversed(getattr(self.Meta, 'order'))) self._order = list(reversed(getattr(self.Meta, 'order')))
@ -267,17 +282,35 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# default order is: primary_key + other keys alphabetically # default order is: primary_key + other keys alphabetically
self._order = list(sorted(self.fields.keys())) self._order = list(sorted(self.fields.keys()))
primary = self.opts.model.__table__.primary_key.columns.values()[0].name primary = self.opts.model.__table__.primary_key.columns.values()[0].name
if primary in self._order:
self._order.remove(primary) self._order.remove(primary)
self._order.reverse() self._order.reverse()
self._order.append(primary) self._order.append(primary)
# move pre_load hook "_track_import" to the front
hooks = self._hooks[('pre_load', False)]
if '_track_import' in hooks:
hooks.remove('_track_import')
hooks.insert(0, '_track_import')
# and post_load hook "_fooo" to the end
hooks = self._hooks[('post_load', False)]
if '_add_instance' in hooks:
hooks.remove('_add_instance')
hooks.append('_add_instance')
@pre_load @pre_load
def _track_import(self, data, many, **kwargs): # pylint: disable=unused-argument def _track_import(self, data, many, **kwargs): # pylint: disable=unused-argument
call = self.context.get('callback') # TODO: also handle reset, prune and delete in pre_load / post_load hooks!
if call is not None: # print('!!!', repr(data))
call(self=self, data=data, many=many, **kwargs) if callback := self.context.get('callback'):
callback(self, data)
return data return data
@post_load
def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
self.opts.sqla_session.add(item)
return item
@post_dump @post_dump
def _hide_and_order(self, data, many, **kwargs): # pylint: disable=unused-argument def _hide_and_order(self, data, many, **kwargs): # pylint: disable=unused-argument
@ -306,6 +339,7 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
### schema definitions ### ### schema definitions ###
@mapped
class DomainSchema(BaseSchema): class DomainSchema(BaseSchema):
""" Marshmallow schema for Domain model """ """ Marshmallow schema for Domain model """
class Meta: class Meta:
@ -339,6 +373,7 @@ class DomainSchema(BaseSchema):
dns_dmarc = fields.String(dump_only=True) dns_dmarc = fields.String(dump_only=True)
@mapped
class TokenSchema(BaseSchema): class TokenSchema(BaseSchema):
""" Marshmallow schema for Token model """ """ Marshmallow schema for Token model """
class Meta: class Meta:
@ -347,6 +382,7 @@ class TokenSchema(BaseSchema):
load_instance = True load_instance = True
@mapped
class FetchSchema(BaseSchema): class FetchSchema(BaseSchema):
""" Marshmallow schema for Fetch model """ """ Marshmallow schema for Fetch model """
class Meta: class Meta:
@ -361,6 +397,7 @@ class FetchSchema(BaseSchema):
} }
@mapped
class UserSchema(BaseSchema): class UserSchema(BaseSchema):
""" Marshmallow schema for User model """ """ Marshmallow schema for User model """
class Meta: class Meta:
@ -368,7 +405,7 @@ class UserSchema(BaseSchema):
model = models.User model = models.User
load_instance = True load_instance = True
include_relationships = True include_relationships = True
exclude = ['localpart', 'domain', 'quota_bytes_used'] exclude = ['domain', 'quota_bytes_used']
exclude_by_value = { exclude_by_value = {
'forward_destination': [[]], 'forward_destination': [[]],
@ -395,7 +432,7 @@ class UserSchema(BaseSchema):
raise ValidationError(f'invalid hashed password {password!r}') raise ValidationError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data: elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in self.Meta.model.scheme_dict: if data['hash_scheme'] not in self.Meta.model.scheme_dict:
raise ValidationError(f'invalid password scheme {scheme!r}') raise ValidationError(f'invalid password scheme {data["hash_scheme"]!r}')
data['password'] = f'{{{data["hash_scheme"]}}}{data["password_hash"]}' data['password'] = f'{{{data["hash_scheme"]}}}{data["password_hash"]}'
del data['hash_scheme'] del data['hash_scheme']
del data['password_hash'] del data['password_hash']
@ -409,17 +446,20 @@ class UserSchema(BaseSchema):
# ctx.verify('', hashed) # ctx.verify('', hashed)
# =>? ValueError: hash could not be identified # =>? ValueError: hash could not be identified
localpart = fields.Str(load_only=True)
domain_name = fields.Str(load_only=True)
tokens = fields.Nested(TokenSchema, many=True) tokens = fields.Nested(TokenSchema, many=True)
fetches = fields.Nested(FetchSchema, many=True) fetches = fields.Nested(FetchSchema, many=True)
@mapped
class AliasSchema(BaseSchema): class AliasSchema(BaseSchema):
""" Marshmallow schema for Alias model """ """ Marshmallow schema for Alias model """
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Alias model = models.Alias
load_instance = True load_instance = True
exclude = ['localpart'] exclude = ['domain']
exclude_by_value = { exclude_by_value = {
'destination': [[]], 'destination': [[]],
@ -429,9 +469,12 @@ class AliasSchema(BaseSchema):
def _handle_email(self, data, many, **kwargs): # pylint: disable=unused-argument def _handle_email(self, data, many, **kwargs): # pylint: disable=unused-argument
return handle_email(data) return handle_email(data)
localpart = fields.Str(load_only=True)
domain_name = fields.Str(load_only=True)
destination = CommaSeparatedListField() destination = CommaSeparatedListField()
@mapped
class ConfigSchema(BaseSchema): class ConfigSchema(BaseSchema):
""" Marshmallow schema for Config model """ """ Marshmallow schema for Config model """
class Meta: class Meta:
@ -440,6 +483,7 @@ class ConfigSchema(BaseSchema):
load_instance = True load_instance = True
@mapped
class RelaySchema(BaseSchema): class RelaySchema(BaseSchema):
""" Marshmallow schema for Relay model """ """ Marshmallow schema for Relay model """
class Meta: class Meta:
@ -453,18 +497,43 @@ class MailuSchema(Schema):
class Meta: class Meta:
""" Schema config """ """ Schema config """
render_module = RenderYAML render_module = RenderYAML
ordered = True ordered = True
order = ['config', 'domains', 'users', 'aliases', 'relays'] order = ['config', 'domains', 'users', 'aliases', 'relays']
@post_dump(pass_many=True) def __init__(self, *args, **kwargs):
def _order(self, data : OrderedDict, many : bool, **kwargs): # pylint: disable=unused-argument super().__init__(*args, **kwargs)
for key in reversed(self.Meta.order): # order fields
for field_list in self.load_fields, self.dump_fields, self.fields:
for section in reversed(self.Meta.order):
try: try:
data.move_to_end(key, False) field_list.move_to_end(section, False)
except KeyError: except KeyError:
pass pass
@pre_load
def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
""" create config object in context if missing
and clear it if requested
"""
if 'config' not in self.context:
self.context['config'] = models.MailuConfig()
if self.context.get('clear'):
self.context['config'].clear(
models = {field.nested.opts.model for field in self.fields.values()}
)
return data return data
@post_load
def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
""" update and return config object """
config = self.context['config']
for section in self.Meta.order:
if section in data:
config.update(data[section], section)
return config
config = fields.Nested(ConfigSchema, many=True) config = fields.Nested(ConfigSchema, many=True)
domains = fields.Nested(DomainSchema, many=True) domains = fields.Nested(DomainSchema, many=True)
users = fields.Nested(UserSchema, many=True) users = fields.Nested(UserSchema, many=True)

Loading…
Cancel
Save