""" Mailu marshmallow fields and schema
"""

from copy import deepcopy
from collections import Counter
from datetime import timezone

import json
import logging
import yaml

import sqlalchemy

from marshmallow import pre_load, post_load, post_dump, fields, Schema
from marshmallow.utils import ensure_text_type
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from marshmallow_sqlalchemy.fields import RelatedList

from flask_marshmallow import Marshmallow

from OpenSSL import crypto

from pygments import highlight
from pygments.token import Token
from pygments.lexers import get_lexer_by_name
from pygments.lexers.data import YamlLexer
from pygments.formatters import get_formatter_by_name

from mailu import models, dkim


ma = Marshmallow()


### import logging and schema colorization ###

# maps model class -> schema class (filled by the @mapped decorator)
_model2schema = {}

def get_schema(cls=None):
    """ look up the schema class registered for model `cls`

        returns None for unknown models; without an argument, returns
        a view of all registered schema classes
    """
    return _model2schema.values() if cls is None else _model2schema.get(cls)

def mapped(cls):
    """ class decorator: register schema `cls` under its Meta.model """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls

class Logger:
    """ helps with counting and colorizing
        imported and exported data

        On creation, sqlalchemy listeners are registered for every model
        that has a schema registered via @mapped. The listeners count
        created, modified and deleted items (reported by changes()) and
        print details when `verbose` is set.
    """

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """
        def get_tokens(self, text, unfiltered=False):
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    if value in {'true', 'false', 'null'}:
                        typ = Token.Keyword.Constant
                    elif value == HIDDEN:
                        typ = Token.Error
                    else:
                        try:
                            int(value, 10)
                        except ValueError:
                            try:
                                float(value)
                            except ValueError:
                                pass
                            else:
                                typ = Token.Literal.Number.Float
                        else:
                            typ = Token.Literal.Number.Integer
                yield typ, value

    def __init__(self, want_color=None, can_color=False, debug=False, secrets=False):
        """ set up output options and register sqlalchemy listeners

            want_color/can_color: colorize output when either is true
            debug: enable sqlalchemy engine logging; also makes
                   format_exception() return None for non-validation errors
            secrets: do not hide secret attributes in log output
        """

        self.lexer = 'yaml'
        self.formatter = 'terminal'
        self.strip = False
        self.verbose = 0
        self.quiet = False
        self.secrets = secrets
        self.debug = debug
        self.print = print

        # NOTE(review): an explicit want_color=False does not disable color
        # when can_color is true - confirm this is intended
        self.color = want_color or can_color

        self._counter = Counter()
        self._schemas = {}

        # log contexts
        self._diff_context = {
            'full': True,
            'secrets': secrets,
        }
        log_context = {
            'secrets': secrets,
        }

        # register listeners
        for schema in get_schema():
            model = schema.Meta.model
            self._schemas[model] = schema(context=log_context)
            sqlalchemy.event.listen(model, 'after_insert', self._listen_insert)
            sqlalchemy.event.listen(model, 'after_update', self._listen_update)
            sqlalchemy.event.listen(model, 'after_delete', self._listen_delete)

        # special listener for dkim_key changes
        # TODO: _listen_dkim can be removed when dkim keys are stored in database
        self._dedupe_dkim = set()
        sqlalchemy.event.listen(models.db.session, 'after_flush', self._listen_dkim)

        # register debug logger for sqlalchemy
        if self.debug:
            logging.basicConfig()
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    def _log(self, action, target, message=None):
        """ print "<action> <table>: <message>" (colorized)

            without message, target is dumped using its registered schema
            (or used as is when no schema is known)
        """
        if message is None:
            try:
                message = self._schemas[target.__class__].dump(target)
            except KeyError:
                message = target
        if not isinstance(message, str):
            message = repr(message)
        self.print(f'{action} {target.__table__}: {self.colorize(message)}')

    def _listen_insert(self, mapper, connection, target): # pylint: disable=unused-argument
        """ after_insert callback: count (and log) created item """
        self._counter.update([('Created', target.__table__.name)])
        if self.verbose:
            self._log('Created', target)

    def _listen_update(self, mapper, connection, target): # pylint: disable=unused-argument
        """ after_update callback: count (and log) modified item

            an item is only counted when at least one column value really
            changed; in verbose mode every changed column is logged
        """

        changes = {}
        inspection = sqlalchemy.inspect(target)
        for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
            history = getattr(inspection.attrs, attr.key).history
            if history.has_changes() and history.deleted:
                before = history.deleted[-1]
                after = getattr(target, attr.key)
                # TODO: this can be removed when comment is not nullable in model
                if attr.key == 'comment' and not before and not after:
                    pass
                # only remember changed keys
                elif before != after:
                    # always record the change, otherwise the item would
                    # never be counted when not running verbose
                    changes[str(attr.key)] = (before, after)
                    # one change is enough to count the item when not logging details
                    if not self.verbose:
                        break

        if self.verbose:
            # use schema to log changed attributes
            schema = get_schema(target.__class__)
            only = set(changes.keys()) & set(schema().fields.keys())
            if only:
                for key, value in schema(
                    only=only,
                    context=self._diff_context
                ).dump(target).items():
                    before, after = changes[key]
                    if value == HIDDEN:
                        before = HIDDEN if before else before
                        after = HIDDEN if after else after
                    else:
                        # also hide this
                        after = value
                    self._log('Modified', target, f'{str(target)!r} {key}: {before!r} -> {after!r}')

        if changes:
            self._counter.update([('Modified', target.__table__.name)])

    def _listen_delete(self, mapper, connection, target): # pylint: disable=unused-argument
        """ after_delete callback: count (and log) deleted item """
        self._counter.update([('Deleted', target.__table__.name)])
        if self.verbose:
            self._log('Deleted', target)

    # TODO: _listen_dkim can be removed when dkim keys are stored in database
    def _listen_dkim(self, session, flush_context): # pylint: disable=unused-argument
        """ after_flush callback: count (and log) dkim key changes
            of Domains loaded from the database
        """
        for target in session.identity_map.values():
            # look at Domains originally loaded from db
            if not isinstance(target, models.Domain) or not target._sa_instance_state.load_path:
                continue
            before = target._dkim_key_on_disk
            after = target._dkim_key
            # "de-dupe" messages; this event is fired at every flush
            if before == after or (target, before, after) in self._dedupe_dkim:
                continue
            self._dedupe_dkim.add((target, before, after))
            self._counter.update([('Modified', target.__table__.name)])
            if self.verbose:
                if self.secrets:
                    # either side may be unset (None) - render it as empty
                    before = (before or b'').decode('ascii', 'ignore')
                    after = (after or b'').decode('ascii', 'ignore')
                else:
                    before = HIDDEN if before else ''
                    after = HIDDEN if after else ''
                self._log('Modified', target, f'{str(target)!r} dkim_key: {before!r} -> {after!r}')

    def track_serialize(self, obj, item, backref=None):
        """ callback method to track import

            obj: schema handling the item; item: input data of one item
            backref: when given, a dict with target/key/before/after used
                     to log a backref modification instead
        """
        # called for backref modification?
        if backref is not None:
            self._log(
                'Modified', item, '{target!r} {key}: {before!r} -> {after!r}'.format_map(backref))
            return
        # show input data?
        if self.verbose < 2:
            return
        # hide secrets in data
        if not self.secrets:
            item = self._schemas[obj.opts.model].hide(item)
            if 'hash_password' in item:
                item['password'] = HIDDEN
            # input is not validated yet: only patch well-formed fetch entries
            fetches = item.get('fetches')
            if isinstance(fetches, list):
                for fetch in fetches:
                    if isinstance(fetch, dict):
                        fetch['password'] = HIDDEN
        self._log('Handling', obj.opts.model, item)

    def changes(self, *messages, **kwargs):
        """ show changes gathered in counter """
        if self.quiet:
            return
        if self._counter:
            changes = []
            last = None
            for (action, what), count in sorted(self._counter.items()):
                if action != last:
                    if last:
                        changes.append('/')
                    changes.append(f'{action}:')
                    last = action
                changes.append(f'{what}({count})')
        else:
            changes = ['No changes.']
        self.print(*messages, *changes, **kwargs)

    def _format_errors(self, store, path=None):
        """ format marshmallow error messages

            store: (nested) error messages as in ValidationError.messages;
                   nested dicts are flattened into dotted locations
            returns a list of (location, message) tuples for recursive
            calls (path given), else the final multi-line text
        """

        res = []
        if path is None:
            path = []
        # messages without field name may be a plain list or string
        if not isinstance(store, dict):
            store = {'_schema': store}
        # list indices (int) and field names (str) may be mixed: sort ints
        # first so they are never compared with strs
        for key in sorted(store, key=lambda k: (isinstance(k, str), k)):
            location = path + [str(key)]
            value = store[key]
            if isinstance(value, dict):
                res.extend(self._format_errors(value, location))
            else:
                # a single message must not be iterated character-wise
                if isinstance(value, str):
                    value = [value]
                for message in value:
                    res.append((".".join(location), message))

        if path:
            return res

        maxlen = max((len(loc) for loc, msg in res), default=0)
        res = [f'     - {loc.ljust(maxlen)} : {msg}' for loc, msg in res]
        errors = f'{len(res)} error{"" if len(res) == 1 else "s"}'
        res.insert(0, f'[ValidationError] {errors} occurred during input validation')

        return '\n'.join(res)

    def _is_validation_error(self, exc):
        """ walk traceback to extract invalid field from marshmallow

            returns a message naming the invalid filter fields, or None
        """
        path = []
        trace = exc.__traceback__
        while trace:
            local_vars = trace.tb_frame.f_locals
            if trace.tb_frame.f_code.co_name == '_serialize':
                if 'attr' in local_vars:
                    path.append(local_vars['attr'])
            elif trace.tb_frame.f_code.co_name == '_init_fields':
                # the ValueError may have been raised elsewhere in _init_fields
                if invalid := local_vars.get('invalid_fields'):
                    spec = ', '.join(
                        '.'.join(path + [key])
                        for key in invalid)
                    return f'Invalid filter: {spec}'
            trace = trace.tb_next
        return None

    def format_exception(self, exc):
        """ format ValidationErrors and other exceptions when not debugging

            returns None when debugging and exc is no (filter) validation error
        """
        if isinstance(exc, ValidationError):
            return self._format_errors(exc.messages)
        if isinstance(exc, ValueError):
            if msg := self._is_validation_error(exc):
                return msg
        if self.debug:
            return None
        msg = ' '.join(str(exc).split())
        return f'[{exc.__class__.__name__}] {msg}'

    colorscheme = {
        Token:                  ('',        ''),
        Token.Name.Tag:         ('cyan',    'cyan'),
        Token.Literal.Scalar:   ('green',   'green'),
        Token.Literal.String:   ('green',   'green'),
        Token.Name.Constant:    ('green',   'green'), # multiline strings
        Token.Keyword.Constant: ('magenta', 'magenta'),
        Token.Literal.Number:   ('magenta', 'magenta'),
        Token.Error:            ('red',     'red'),
        Token.Name:             ('red',     'red'),
        Token.Operator:         ('red',     'red'),
    }

    def colorize(self, data, lexer=None, formatter=None, color=None, strip=None):
        """ add ANSI color to data

            returns data unchanged when color is False or coloring is disabled
        """

        if color is False or not self.color:
            return data

        lexer = lexer or self.lexer
        lexer = Logger.MyYamlLexer() if lexer == 'yaml' else get_lexer_by_name(lexer)
        formatter = get_formatter_by_name(formatter or self.formatter, colorscheme=self.colorscheme)
        if strip is None:
            strip = self.strip

        res = highlight(data, lexer, formatter)
        if strip:
            return res.rstrip('\n')
        return res


### marshmallow render modules ###

# hidden attributes
class _Hidden:
    """ placeholder marking a value hidden in exports and logs

        - is falsy, so "x if x else ..." checks treat it like an empty value
        - copy() and deepcopy() return the same object
        - compares equal to anything whose str() is '<hidden>'
        - hashes like the string '<hidden>' (consistent with __eq__)
    """
    def __bool__(self):
        return False
    def __copy__(self):
        return self
    def __deepcopy__(self, _):
        return self
    def __eq__(self, other):
        return str(other) == '<hidden>'
    def __hash__(self):
        # defining __eq__ alone makes a class unhashable; objects comparing
        # equal must hash equal, and __eq__ matches the string '<hidden>'
        return hash('<hidden>')
    def __repr__(self):
        return '<hidden>'
    __str__ = __repr__

def _represent_hidden(dumper, data):
    """ render hidden values in YAML using their plain string form """
    return dumper.represent_data(str(data))

yaml.add_representer(_Hidden, _represent_hidden)

# the placeholder used to mask secret values
HIDDEN = _Hidden()

# multiline attributes
class _Multiline(str):
    """ str subclass marking text to be dumped as YAML literal block ("|") """

def _represent_multiline(dumper, data):
    """ render _Multiline strings using the literal block scalar style """
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')

yaml.add_representer(_Multiline, _represent_multiline)

# yaml render module
class RenderYAML:
    """ Marshmallow YAML Render Module

        loads() parses with yaml.safe_load; dumps() renders block-style
        YAML via SpacedDumper with unicode output and unsorted keys
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper adding a blank line between main sections
            and never emitting indentless sequences
        """

        def write_line_break(self, data=None):
            super().write_line_break(data)
            # indent stack of depth one: we are at a main section
            if len(self.indents) == 1:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            # ignore the requested "indentless" so nested lists get indented
            return super().increase_indent(flow, False)

    @staticmethod
    def _augment(kwargs, defaults):
        """ fill in defaults for keys not given in kwargs (in place) """
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

    _load_defaults = {}
    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse yaml data from a string (safe loader) """
        cls._augment(kwargs, cls._load_defaults)
        return yaml.safe_load(*args, **kwargs)

    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }
    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as a yaml string """
        cls._augment(kwargs, cls._dump_defaults)
        return yaml.dump(*args, **kwargs)

# json encoder
class JSONEncoder(json.JSONEncoder):
    """ JSONEncoder rendering hidden values as their string form """
    def default(self, o):
        """ serialize _Hidden instances; defer all else to json.JSONEncoder """
        if not isinstance(o, _Hidden):
            return json.JSONEncoder.default(self, o)
        return str(o)

# json render module
class RenderJSON:
    """ Marshmallow JSON Render Module

        loads() wraps json.loads; dumps() wraps json.dumps producing
        compact output and serializing hidden values via JSONEncoder
    """

    @staticmethod
    def _augment(kwargs, defaults):
        """ fill in defaults for keys not given in kwargs (in place) """
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

    _load_defaults = {}
    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse json data from a string """
        cls._augment(kwargs, cls._load_defaults)
        return json.loads(*args, **kwargs)

    _dump_defaults = {
        'cls': JSONEncoder,
        'separators': (',', ':'),
    }
    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as a compact json string """
        cls._augment(kwargs, cls._dump_defaults)
        return json.dumps(*args, **kwargs)


### marshmallow: custom fields ###

def _rfc3339(datetime):
    """ render a datetime as RFC 3339 string

        naive datetimes are converted with astimezone(timezone.utc), i.e.
        they are interpreted as local time; a +00:00 offset is written as 'Z'
    """
    aware = datetime if datetime.tzinfo is not None else datetime.astimezone(timezone.utc)
    text = aware.isoformat()
    return f'{text[:-6]}Z' if text.endswith('+00:00') else text

# register an 'rfc3339' DateTime format (rendered by _rfc3339, parsed like
# 'iso') and make it the DEFAULT_FORMAT (presumably used by fields that do
# not set an explicit format)
fields.DateTime.SERIALIZATION_FUNCS.update(rfc3339=_rfc3339)
fields.DateTime.DESERIALIZATION_FUNCS.update(
    rfc3339=fields.DateTime.DESERIALIZATION_FUNCS['iso'],
)
fields.DateTime.DEFAULT_FORMAT = 'rfc3339'

class LazyStringField(fields.String):
    """ String field that serializes any "false" value
        (None, '', ...) to the empty string
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """ return value unchanged when truthy, else the empty string """
        return value or ''

class CommaSeparatedListField(fields.Raw):
    """ Deserialize a string containing comma-separated values to
        a list of strings; a list is accepted with its items converted to text
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _deserialize(self, value, attr, data, **kwargs):
        """ turn a comma separated string (or a list) into a list of strings

            string input is split at commas, items are stripped and empty
            items dropped; list items are converted to text unchanged
        """

        # any falsy value means "no items"
        if not value:
            return []

        is_list = isinstance(value, list)
        if not is_list and not isinstance(value, (str, bytes)):
            raise self.make_error("invalid")

        try:
            if is_list:
                return [ensure_text_type(item) for item in value]
            text = ensure_text_type(value)
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc

        return [part for part in (piece.strip() for piece in text.split(',')) if part]


class DkimKeyField(fields.String):
    """ Serialize a dkim key to a multiline string and
        deserialize a dkim key data as string or list of strings
        to a valid dkim key
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _serialize(self, value, attr, obj, **kwargs):
        """ serialize dkim key (bytes) as multiline string

            a missing key (None or empty) is serialized as ''
        """

        # map empty value and None to the empty string
        if not value:
            return ''

        # return multiline string (dumped as yaml literal block)
        return _Multiline(value.decode('utf-8'))

    def _wrap_key(self, begin, data, end):
        """ generator to wrap key into RFC 7468 format

            yields the header, the key data in chunks of 64 characters,
            the footer and a final empty string (so that joining the
            parts with newlines yields a trailing newline)
        """
        yield begin
        for pos in range(0, len(data), 64):
            yield data[pos:pos+64]
        yield end
        yield ''

    @staticmethod
    def _shorten(value):
        """ abbreviate key data for error messages (never echo a whole key) """
        return f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize a string or list of strings to dkim key data
            with verification

            - '-generate-' (any case) returns a new key from dkim.gen_key()
            - an empty value returns None
            - otherwise header and footer are optional: a missing one is
              derived from the other (labels must match, see RFC 7468) or
              defaults to "PRIVATE KEY"; whitespace in the key data is
              removed and the data re-wrapped to 64 characters per line
            returns the PEM-encoded key as bytes
            raises ValidationError when the key can not be parsed or loaded
        """

        # convert list to str
        if isinstance(value, list):
            try:
                value = ''.join(ensure_text_type(item) for item in value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # only text is allowed
        else:
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                value = ensure_text_type(value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # generate new key?
        if value.lower() == '-generate-':
            return dkim.gen_key()

        # no key?
        if not value:
            return None

        # remember part of value for ValidationError
        bad_key = self._shorten(value)

        # split off optional header and footer
        header = footer = None
        try:
            if value.startswith('-----BEGIN '):
                end = value.index('-----', 11) + 5
                header = value[:end]
                value = value[end:]

            if (pos := value.find('-----END ')) >= 0:
                end = value.index('-----', pos+9) + 5
                footer = value[pos:end]
                value = value[:pos]
        except ValueError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # derive a missing header/footer from its counterpart, else use defaults
        if header is None and footer is None:
            header = '-----BEGIN PRIVATE KEY-----'
            footer = '-----END PRIVATE KEY-----'
        elif header is None:
            header = '-----BEGIN ' + footer[len('-----END '):]
        elif footer is None:
            footer = '-----END ' + header[len('-----BEGIN '):]

        # remove whitespace from key data
        value = ''.join(value.split())

        # remember part of value for ValidationError
        bad_key = self._shorten(value)

        # wrap key according to RFC 7468 (PEM is ascii only)
        try:
            value = ('\n'.join(self._wrap_key(header, value, footer))).encode('ascii')
        except UnicodeEncodeError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # check key validity
        try:
            crypto.load_privatekey(crypto.FILETYPE_PEM, value)
        except crypto.Error as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc
        else:
            return value

class PasswordField(fields.Str):
    """ Serialize a hashed password hash by stripping the obsolete {SCHEME}
        Deserialize a plain password or hashed password into a hashed password

        metadata['model'] must be the model class used to hash and check
        passwords (instantiated without and with `password=`).
    """

    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    def _strip_scheme(self, value):
        """ remove a leading obsolete {SCHEME} prefix naming a known hash """
        if value.startswith('{') and (end := value.find('}', 1)) >= 0:
            if value[1:end] in self._hashes:
                return value[end+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ strip obsolete {password-hash} when serializing

            None is passed through (like fields.Str does)
        """
        if value is None:
            return None
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ hashes plain password or checks hashed password
            also strips obsolete {password-hash} when deserializing

            raises ValidationError for non-text input or an unknown hash
        """

        # validate input type like fields.Str ("invalid"/"invalid_utf8")
        value = super()._deserialize(value, attr, data, **kwargs)

        # when hashing is requested: use model instance to hash plain password
        if data.get('hash_password'):
            inst = self.metadata['model']()
            inst.set_password(value)
            value = inst.password

        # strip scheme spec when specified - it's obsolete
        value = self._strip_scheme(value)

        # check if algorithm is supported
        inst = self.metadata['model'](password=value)
        try:
            # just check against empty string to see if hash is valid
            inst.check_password('')
        except ValueError as exc:
            # ValueError: hash could not be identified
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value


### base schema ###

class Storage:
    """ Storage class to save information in context

        Values live in context['_track'] under a (binding, key) tuple. The
        binding is None by default, the class of self for bind=True, the
        schema registered for the object previously stored under the key
        named by a str bind, or any other object passed as bind.
    """

    # class-level fallback: instances that never assign self.context share it
    context = {}

    def _bind(self, key, bind):
        if bind is True:
            owner = self.__class__
        elif isinstance(bind, str):
            owner = get_schema(self.recall(bind).__class__)
        else:
            owner = bind
        return (owner, key)

    def store(self, key, value, bind=None):
        """ store value under key """
        track = self.context.setdefault('_track', {})
        track[self._bind(key, bind)] = value

    def recall(self, key, bind=None):
        """ recall value from key (raises KeyError when unknown) """
        track = self.context['_track']
        return track[self._bind(key, bind)]

class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Option class defaulting Meta.sqla_session to the db session
        and Meta.sibling to False
    """
    def __init__(self, meta, ordered=False):
        # defaults are written onto the Meta class itself
        if not hasattr(meta, 'sibling'):
            meta.sibling = False
        if not hasattr(meta, 'sqla_session'):
            meta.sqla_session = models.db.session
        super().__init__(meta, ordered=ordered)

class BaseSchema(ma.SQLAlchemyAutoSchema, Storage):
    """ Marshmallow base schema with custom exclude logic
        and option to hide sqla defaults
    """

    OPTIONS_CLASS = BaseOpts

    class Meta:
        """ Schema config (defaults) """
        # explicit field order used by __init__ (when Meta defines `order`)
        order = []
        # context flags -> attributes excluded unless one of the flags is True
        include_by_context = {}
        # context flags -> attributes hidden by hide() unless a flag is True
        hide_by_context = {}
        # attribute -> values; collected in __init__ as _exclude_by_value
        exclude_by_value = {}
        # when True: no session flush in _patch_many, no "parent" tracking
        sibling = False

    def __init__(self, *args, **kwargs):
        """ create schema and compute context dependent settings

            - always excludes 'created_at'/'updated_at' and excludes
              Meta.include_by_context attributes whose flags are not set
              in context (attributes listed in `only` are never excluded)
            - collects _exclude_by_value (incl. column defaults unless
              context['full'] is set) and _hide_by_context
            - orders fields by Meta.order (or primary key + alphabetically)
            - moves post_load hook "_add_instance" to the end
        """

        # prepare only to auto-include explicitly specified attributes
        only = set(kwargs.get('only') or [])

        # get context (an explicit context=None is treated as empty context)
        context = kwargs.get('context') or {}
        flags = {key for key, value in context.items() if value is True}

        # compile excludes (an explicit exclude=None is treated as empty)
        exclude = set(kwargs.get('exclude') or [])

        # always exclude
        exclude.update({'created_at', 'updated_at'} - only)

        # add include_by_context
        for need, what in getattr(self.Meta, 'include_by_context', {}).items():
            if not flags & set(need):
                exclude |= what - only

        # update excludes
        kwargs['exclude'] = exclude

        # init SQLAlchemyAutoSchema
        super().__init__(*args, **kwargs)

        # exclude_by_value (copy the lists: they belong to the Meta class and
        # are extended below - appending in place would grow them forever)
        self._exclude_by_value = {
            key: list(values) for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
            if key not in only
        }

        # exclude default values
        if not context.get('full'):
            for column in self.opts.model.__table__.columns:
                if column.name not in exclude and column.name not in only:
                    self._exclude_by_value.setdefault(column.name, []).append(
                        None if column.default is None else column.default.arg
                    )

        # hide by context
        self._hide_by_context = set()
        for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
            if not flags & set(need):
                self._hide_by_context |= what - only

        # remember primary keys
        self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)

        # determine attribute order
        if hasattr(self.Meta, 'order'):
            # use user-defined order
            order = self.Meta.order
        else:
            # default order is: primary_key + other keys alphabetically
            order = list(sorted(self.fields.keys()))
            if self._primary in order:
                order.remove(self._primary)
                order.insert(0, self._primary)

        # order fieldlists
        for fieldlist in (self.fields, self.load_fields, self.dump_fields):
            for field in order:
                if field in fieldlist:
                    fieldlist[field] = fieldlist.pop(field)

        # move post_load hook "_add_instance" to the end (after load_instance mixin)
        # (idempotent, so repeating it for every instance is harmless)
        hooks = self._hooks[('post_load', False)]
        hooks.remove('_add_instance')
        hooks.append('_add_instance')

    def hide(self, data):
        """ return a copy of data suitable for logging

            values of keys in _hide_by_context are replaced by HIDDEN,
            all other values are deep-copied (callers may modify the result)
        """
        masked = {}
        for key, value in data.items():
            masked[key] = HIDDEN if key in self._hide_by_context else deepcopy(value)
        return masked

    def _call_and_store(self, *args, **kwargs):
        """ remember the field being processed (for pruning), then delegate """
        current_field = kwargs['field_name']
        self.store('field', current_field, True)
        return super()._call_and_store(*args, **kwargs)

    # this is only needed to work around the declared attr "email" primary key in model
    def get_instance(self, data):
        """ return the existing instance for data, or None

            transient schemas never load instances. When Meta.primary_keys
            is set and data provides all of them, query by these keys;
            otherwise use the default lookup.
        """
        if self.transient:
            return None
        keys = getattr(self.Meta, 'primary_keys', None)
        if keys:
            filters = {key: data.get(key) for key in keys}
            if None not in filters.values():
                query = self.session.query(self.opts.model)
                return query.filter_by(**filters).first()
        return super().get_instance(data)

    @pre_load(pass_many=True)
    def _patch_many(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ - flush sqla session before serializing a section unless
              Meta.sibling is set
              (make sure all objects that could be referred to later are created)
            - when in update mode: patch input data before deserialization
              - handle "prune" and "delete" items
              - replace values in keys starting with '-' with default

            returns the (patched) items; items are walked with enumerate(),
            so presumably a list of item dicts - TODO confirm for many=False
            raises ValidationError for invalid "-key" usage or input
            containing "__delete__"
        """

        # flush sqla session
        if not self.Meta.sibling:
            self.opts.sqla_session.flush()

        # stop early when not updating
        if not self.context.get('update'):
            return items

        # patch "delete", "prune" and "default"
        # (list used as a mutable flag that the nested patch() can set)
        want_prune = []
        def patch(count, data):
            # returns None (prune marker), a delete marker dict or the
            # patched item; count is the item index used in error locations

            # don't allow __delete__ coming from input
            if '__delete__' in data:
                raise ValidationError('Unknown field.', f'{count}.__delete__')

            # fail when hash_password is specified without password
            if 'hash_password' in data and not 'password' in data:
                raise ValidationError(
                    'Nothing to hash. Field "password" is missing.',
                    field_name = f'{count}.hash_password',
                )

            # handle "prune list" and "delete item" (-pkey: none and -pkey: id)
            for key in data:
                if key.startswith('-'):
                    if key[1:] == self._primary:
                        # delete or prune
                        if data[key] is None:
                            # prune
                            want_prune.append(True)
                            return None
                        # mark item for deletion
                        return {key[1:]: data[key], '__delete__': count}

            # handle "set to default value" (-key: none)
            def set_default(key, value):
                if not key.startswith('-'):
                    return (key, value)
                key = key[1:]
                # "-key" naming no table column: value is replaced by None
                if not key in self.opts.model.__table__.columns:
                    return (key, None)
                if value is not None:
                    raise ValidationError(
                        'Value must be "null" when resetting to default.',
                        f'{count}.{key}'
                    )
                value = self.opts.model.__table__.columns[key].default
                if value is None:
                    raise ValidationError(
                        'Field has no default value.',
                        f'{count}.{key}'
                    )
                return (key, value.arg)

            return dict(set_default(key, value) for key, value in data.items())

        # convert items to "delete" and filter "prune" item
        # NOTE(review): "if item" also drops items that are empty dicts
        items = [
            item for item in [
                patch(count, item) for count, item in enumerate(items)
            ] if item
        ]

        # remember if prune was requested for _prune_items@post_load
        self.store('prune', bool(want_prune), True)

        # remember original items to stabilize password-changes in _add_instance@post_load
        # NOTE(review): these are the already patched/filtered items
        self.store('original', items, True)

        return items

    @pre_load
    def _patch_item(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ - call callback function to track import
            - stabilize import of items with auto-increment primary key
            - delete items
            - delete/prune list attributes
            - add missing required attributes

            data is the raw input dict of a single item; it is modified
            in place and returned.
            Raises ValidationError when an item or a list entry marked for
            deletion does not exist.
        """

        # callback
        if callback := self.context.get('callback'):
            callback(self, data)

        # stop early when not updating
        if not self.opts.load_instance or not self.context.get('update'):
            return data

        # stabilize import of auto-increment primary keys (not required),
        # by matching import data to existing items and setting primary key
        if self._primary not in data:
            for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
                existing = self.dump(item, many=False)
                this = existing.pop(self._primary)
                if data == existing:
                    data[self._primary] = this
                    break

        # try to load instance
        instance = self.instance or self.get_instance(data)
        if instance is None:

            if '__delete__' in data:
                # deletion of non-existent item requested
                # (the primary key may be missing when it could not be matched above)
                raise ValidationError(
                    f'Item to delete not found: {data.get(self._primary)!r}.',
                    field_name = f'{data["__delete__"]}.{self._primary}',
                )

            return data

        # from here on we are updating an existing instance
        # (the early return above guarantees the 'update' context flag is set)

        # remember instance as parent for pruning siblings
        if not self.Meta.sibling:
            self.store('parent', instance)
        # delete instance from session when marked
        if '__delete__' in data:
            self.opts.sqla_session.delete(instance)
        # delete item from lists or prune lists
        # currently: domain.alternatives, user.forward_destination,
        # user.manager_of, aliases.destination
        for key, value in data.items():
            if not isinstance(self.fields.get(key), (
                RelatedList, CommaSeparatedListField, fields.Raw)
            ) or not isinstance(value, list):
                continue
            # deduplicate new value
            new_value = set(value)
            # handle list pruning
            if '-prune-' in value:
                new_value.discard('-prune-')
            else:
                for old in getattr(instance, key):
                    # using str() is okay for now (see above)
                    new_value.add(str(old))
            # handle item deletion: dict.fromkeys visits every distinct marker
            # once, in input order. Without it, a repeated marker would fail with
            # a bare KeyError on its second removal.
            for item in dict.fromkeys(value):
                # only strings can be deletion markers; '-prune-' is handled above
                if not isinstance(item, str) or not item.startswith('-'):
                    continue
                if item == '-prune-':
                    continue
                new_value.discard(item)
                try:
                    new_value.remove(item[1:])
                except KeyError as exc:
                    raise ValidationError(
                        f'Item to delete not found: {item[1:]!r}.',
                        field_name=f'?.{key}',
                    ) from exc
            # sort list of new values
            data[key] = sorted(new_value)
            # log backref modification not catched by modify hook
            if isinstance(self.fields[key], RelatedList):
                if callback := self.context.get('callback'):
                    before = {str(v) for v in getattr(instance, key)}
                    after = set(data[key])
                    if before != after:
                        callback(self, instance, {
                            'key': key,
                            'target': str(instance),
                            'before': before,
                            'after': after,
                        })

        # add attributes required for validation from db
        for attr_name, field_obj in self.load_fields.items():
            if field_obj.required and attr_name not in data:
                data[attr_name] = getattr(instance, attr_name)

        return data

    @post_load(pass_many=True)
    def _prune_items(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ complete the loaded list with existing db items missing from the input

            Existing items of the parent that were not part of the import are
            either re-added (kept) or appended with a '__delete__' marker:
            - sibling schemas are pruned by their parent, so their old items
              are re-added unless pruning was requested
            - other schemas are not pruned automatically, so their old items
              are marked for deletion only when pruning was requested
        """

        if not self.context.get('update'):
            return items

        # prune flag was stored by _patch_many@pre_load
        want_prune = self.recall('prune', True)

        sibling = self.Meta.sibling
        keep_old = sibling and not want_prune
        drop_old = not sibling and want_prune

        if not (keep_old or drop_old):
            return items

        primary = self._primary
        seen = {entry[primary] for entry in items if primary in entry}
        for old in getattr(self.recall('parent'), self.recall('field', 'parent')):
            old_key = getattr(old, primary)
            if old_key in seen:
                continue
            if keep_old:
                items.append({primary: old_key})
            else:
                items.append({primary: old_key, '__delete__': '?'})

        return items

    @post_load
    def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
        """ - undo password change in existing instances when plain password did not change
            - add new instances to sqla session

            item is the loaded object (assumed to be the model instance — confirm
            hook ordering with marshmallow_sqlalchemy's make_instance).
            Returns the same object.
        """

        # new instances are added to the session and returned unchanged
        if item not in self.opts.sqla_session:
            self.opts.sqla_session.add(item)
            return item

        # stop early when not updating or item has no password attribute
        if not self.context.get('update') or not hasattr(item, 'password'):
            return item

        # did we hash a new plaintext password?
        # (look up this item's raw input stored by _patch_many@pre_load)
        original = None
        pkey = getattr(item, self._primary)
        for data in self.recall('original', True):
            if 'hash_password' in data and data.get(self._primary) == pkey:
                # .get: hash_password may be given without a password
                original = data.get('password')
                break
        if original is None:
            # password was not hashed by us (or no plain password was given)
            return item

        # reset hash if plain password matches hash from db
        if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
            if attr.history.has_changes() and attr.history.deleted:
                try:
                    # throwaway instance holding the previous hash from the db
                    inst = type(item)(password=attr.history.deleted[-1])
                    if inst.check_password(original):
                        # plain password unchanged: keep the hash from the db
                        item.password = inst.password
                except ValueError:
                    # hash in db is invalid
                    pass
                else:
                    del inst

        return item

    @post_dump
    def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ hide secrets and drop uninteresting values from dumped items

            - keys listed in _hide_by_context get the HIDDEN placeholder as value
            - keys whose dumped value is listed in _exclude_by_value[key] are
              omitted, unless the 'full' context flag is set
        """

        exclude = self._exclude_by_value
        hide = self._hide_by_context

        # nothing to exclude or hide: pass data through unchanged
        if not (exclude or hide):
            return data

        show_all = self.context.get('full')

        def wanted(key, value):
            return show_all or key not in exclude or value not in exclude[key]

        pairs = (
            (key, HIDDEN if key in hide else value)
            for key, value in data.items()
            if wanted(key, value)
        )
        return type(data)(pairs)

    # this field is used to mark items for deletion
    # (load-only; read from the '__delete__' key of the input data, which
    #  _patch_item checks to delete the matching instance from the session)
    mark_delete = fields.Boolean(data_key='__delete__', load_only=True)

    # TODO: this can be removed when comment is not nullable in model
    # (LazyStringField is defined elsewhere in this file; presumably it
    #  copes with a None comment — confirm)
    comment = LazyStringField()


### schema definitions ###

@mapped
class DomainSchema(BaseSchema):
    """ Marshmallow schema for Domain model

        users and aliases are excluded here; they are dumped as their own
        sections of MailuSchema (managers are presumably covered by the
        users' manager_of list — confirm).
    """
    class Meta:
        """ Schema config """
        model = models.Domain
        # deserialize into model instances
        load_instance = True
        # also serialize relationships (e.g. alternatives)
        include_relationships = True
        exclude = ['users', 'managers', 'aliases']

        # computed DNS fields, only included when the 'dns' context flag is
        # set (assumed semantics of BaseSchema's include_by_context — confirm)
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }
        # values replaced by the HIDDEN placeholder in _hide_values,
        # presumably unless the 'secrets' context flag is set — confirm
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }
        # keys whose dumped value is in the given list are omitted from the
        # output by _hide_values, unless the 'full' context flag is set
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }

    dkim_key = DkimKeyField(allow_none=True)
    # read-only values computed from the model, never loaded
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)


@mapped
class TokenSchema(BaseSchema):
    """ Marshmallow schema for Token model

        Tokens are nested into UserSchema.tokens.
    """
    class Meta:
        """ Schema config """
        model = models.Token
        load_instance = True

        # nested list owned by a parent object; see _prune_items for how
        # sibling lists are pruned
        sibling = True

    # NOTE(review): the hashing model is models.User although this schema
    # loads models.Token — confirm this is intended
    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only flag; presumably marks 'password' as plain text to be hashed
    # on load (see _add_instance). Note: `missing=` is called `load_default=`
    # in marshmallow >= 3.13.
    hash_password = fields.Boolean(load_only=True, missing=False)


@mapped
class FetchSchema(BaseSchema):
    """ Marshmallow schema for Fetch model

        Fetches are nested into UserSchema.fetches.
    """
    class Meta:
        """ Schema config """
        model = models.Fetch
        load_instance = True

        # nested list owned by a parent object; see _prune_items for how
        # sibling lists are pruned
        sibling = True
        # status fields, presumably only included when the 'full' or 'import'
        # context flag is set (BaseSchema include_by_context — confirm)
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }
        # value replaced by the HIDDEN placeholder in _hide_values,
        # presumably unless the 'secrets' context flag is set — confirm
        hide_by_context = {
            ('secrets',): {'password'},
        }


@mapped
class UserSchema(BaseSchema):
    """ Marshmallow schema for User model

        Users are identified by their full email address; tokens and
        fetches are serialized as nested lists.
    """
    class Meta:
        """ Schema config """
        model = models.User
        load_instance = True
        include_relationships = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']

        # use 'email' as the key to match input against existing users
        # (assumed to feed BaseSchema's _primary — confirm)
        primary_keys = ['email']
        # keys whose dumped value is in the given list are omitted from the
        # output by _hide_values, unless the 'full' context flag is set;
        # the reply dates presumably are the model's "unset" defaults — confirm
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens':              [[]],
            'fetches':             [[]],
            'manager_of':          [[]],
            'reply_enddate':       ['2999-12-31'],
            'reply_startdate':     ['1900-01-01'],
        }

    email = fields.String(required=True)
    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)

    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only flag; presumably marks 'password' as plain text to be hashed
    # on load (see _add_instance). Note: `missing=` is called `load_default=`
    # in marshmallow >= 3.13.
    hash_password = fields.Boolean(load_only=True, missing=False)


@mapped
class AliasSchema(BaseSchema):
    """ Marshmallow schema for Alias model

        Aliases are identified by their full email address.
    """
    class Meta:
        """ Schema config """
        model = models.Alias
        load_instance = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name']

        # use 'email' as the key to match input against existing aliases
        # (assumed to feed BaseSchema's _primary — confirm)
        primary_keys = ['email']
        # an empty destination list is omitted from dumps by _hide_values,
        # unless the 'full' context flag is set
        exclude_by_value = {
            'destination': [[]],
        }

    email = fields.String(required=True)
    # list of destinations (CommaSeparatedListField is defined elsewhere in
    # this file)
    destination = CommaSeparatedListField()


@mapped
class ConfigSchema(BaseSchema):
    """ Marshmallow schema for Config model

        Registered via @mapped, but its section in MailuSchema is
        currently commented out.
    """
    class Meta:
        """ Schema config """
        model = models.Config
        load_instance = True


@mapped
class RelaySchema(BaseSchema):
    """ Marshmallow schema for Relay model

        All fields are generated from the model by marshmallow_sqlalchemy.
    """
    class Meta:
        """ Schema config """
        model = models.Relay
        load_instance = True


@mapped
class MailuSchema(Schema, Storage):
    """ Marshmallow schema for a complete Mailu configuration

        Every section is a list of nested model schemas. Loading updates
        a MailuConfig object kept in the schema context under 'config'
        and returns that object.
    """
    class Meta:
        """ Schema config """
        model = models.MailuConfig
        render_module = RenderYAML

        # processing/output order of the sections ('config' is disabled)
        order = ['domain', 'user', 'alias', 'relay']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # move the sections named in Meta.order to the end of every field
        # mapping, one after another, so they appear in exactly that order
        ordered = self.Meta.order
        for mapping in (self.fields, self.load_fields, self.dump_fields):
            present = [name for name in ordered if name in mapping]
            for name in present:
                mapping[name] = mapping.pop(name)

    def _call_and_store(self, *args, **kwargs):
        """ remember the section being loaded and the config object as its
            parent (both are used for pruning), then defer to marshmallow
        """
        section = kwargs['field_name']
        self.store('field', section, True)
        parent = self.context.get('config')
        self.store('parent', parent)
        return super()._call_and_store(*args, **kwargs)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ make sure a config object exists in the context and empty
            it (for all section models) when the 'clear' flag is set
        """
        if 'config' not in self.context:
            self.context['config'] = models.MailuConfig()
        if not self.context.get('clear'):
            return data
        section_models = {field.nested.opts.model for field in self.fields.values()}
        self.context['config'].clear(models=section_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ apply the loaded sections to the config object following
            Meta.order and return the config object as load result
        """
        config = self.context['config']
        for section in (name for name in self.Meta.order if name in data):
            config.update(data[section], section)
        return config

    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
    # disabled section:
    # config = fields.Nested(ConfigSchema, many=True)