added config_import using marshmallow

master
Alexander Graf 4 years ago
parent 7413f9b7b4
commit c24bff1c1b

@ -26,7 +26,7 @@ def register(app):
# add redirect to current api version
@app.route(f'{ROOT}/')
def redir():
def _redirect_to_active_api():
return redirect(url_for(f'{ACTIVE.blueprint.name}.root'))
# swagger ui config

@ -1,21 +1,25 @@
from mailu import models
from .schemas import MailuConfig, MailuSchema
from flask import current_app as app
from flask.cli import FlaskGroup, with_appcontext
""" Mailu command line interface
"""
import sys
import os
import socket
import uuid
import click
import yaml
import sys
from flask import current_app as app
from flask.cli import FlaskGroup, with_appcontext
from marshmallow.exceptions import ValidationError
from . import models
from .schemas import MailuSchema
db = models.db
@click.group(cls=FlaskGroup)
@click.group(cls=FlaskGroup, context_settings={'help_option_names': ['-?', '-h', '--help']})
def mailu():
""" Mailu command line
"""
@ -26,17 +30,17 @@ def mailu():
def advertise():
""" Advertise this server against statistic services.
"""
if os.path.isfile(app.config["INSTANCE_ID_PATH"]):
with open(app.config["INSTANCE_ID_PATH"], "r") as handle:
if os.path.isfile(app.config['INSTANCE_ID_PATH']):
with open(app.config['INSTANCE_ID_PATH'], 'r') as handle:
instance_id = handle.read()
else:
instance_id = str(uuid.uuid4())
with open(app.config["INSTANCE_ID_PATH"], "w") as handle:
with open(app.config['INSTANCE_ID_PATH'], 'w') as handle:
handle.write(instance_id)
if not app.config["DISABLE_STATISTICS"]:
if not app.config['DISABLE_STATISTICS']:
try:
socket.gethostbyname(app.config["STATS_ENDPOINT"].format(instance_id))
except:
socket.gethostbyname(app.config['STATS_ENDPOINT'].format(instance_id))
except OSError:
pass
@ -171,156 +175,196 @@ def user_import(localpart, domain_name, password_hash, hash_scheme = None):
db.session.commit()
yaml_sections = [
('domains', models.Domain),
('relays', models.Relay),
('users', models.User),
('aliases', models.Alias),
# ('config', models.Config),
]
# @mailu.command()
# @click.option('-v', '--verbose', is_flag=True, help='Increase verbosity')
# @click.option('-d', '--delete-objects', is_flag=True, help='Remove objects not included in yaml')
# @click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made')
# @click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin)
# @with_appcontext
# def config_update(verbose=False, delete_objects=False, dry_run=False, source=None):
# """ Update configuration with data from YAML-formatted input
# """
# try:
# new_config = yaml.safe_load(source)
# except (yaml.scanner.ScannerError, yaml.parser.ParserError) as exc:
# print(f'[ERROR] Invalid yaml: {exc}')
# sys.exit(1)
# else:
# if isinstance(new_config, str):
# print(f'[ERROR] Invalid yaml: {new_config!r}')
# sys.exit(1)
# elif new_config is None or not new_config:
# print('[ERROR] Empty yaml: Please pipe yaml into stdin')
# sys.exit(1)
# error = False
# tracked = {}
# for section, model in yaml_sections:
# items = new_config.get(section)
# if items is None:
# if delete_objects:
# print(f'[ERROR] Invalid yaml: Section "{section}" is missing')
# error = True
# break
# else:
# continue
# del new_config[section]
# if not isinstance(items, list):
# print(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}')
# error = True
# break
# elif not items:
# continue
# # create items
# for data in items:
# if verbose:
# print(f'Handling {model.__table__} data: {data!r}')
# try:
# changed = model.from_dict(data, delete_objects)
# except Exception as exc:
# print(f'[ERROR] {exc.args[0]} in data: {data}')
# error = True
# break
# for item, created in changed:
# if created is True:
# # flush newly created item
# db.session.add(item)
# db.session.flush()
# if verbose:
# print(f'Added {item!r}: {item.to_dict()}')
# else:
# print(f'Added {item!r}')
# elif created:
# # modified instance
# if verbose:
# for key, old, new in created:
# print(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}')
# else:
# print(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}')
# # track primary key of all items
# tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()]))
# if error:
# break
# # on error: stop early
# if error:
# print('[ERROR] An error occured. Not committing changes.')
# db.session.rollback()
# sys.exit(1)
# # are there sections left in new_config?
# if new_config:
# print(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}')
# error = True
# # test for conflicting domains
# domains = set()
# for model, items in tracked.items():
# if model in (models.Domain, models.Alternative, models.Relay):
# if domains & items:
# for fqdn in domains & items:
# print(f'[ERROR] Duplicate domain name used: {fqdn}')
# error = True
# domains.update(items)
# # delete items not tracked
# if delete_objects:
# for model, items in tracked.items():
# for item in model.query.all():
# if not item._dict_pval() in items:
# print(f'Deleted {item!r} {item}')
# db.session.delete(item)
# # don't commit when running dry
# if dry_run:
# print('Dry run. Not commiting changes.')
# db.session.rollback()
# else:
# db.session.commit()
SECTIONS = {'domains', 'relays', 'users', 'aliases'}
@mailu.command()
@click.option('-v', '--verbose', is_flag=True, help='Increase verbosity')
@click.option('-d', '--delete-objects', is_flag=True, help='Remove objects not included in yaml')
@click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made')
@click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin)
@with_appcontext
def config_update(verbose=False, delete_objects=False, dry_run=False, file=None):
"""sync configuration with data from YAML-formatted stdin"""
def config_import(verbose=False, dry_run=False, source=None):
""" Import configuration YAML
"""
out = (lambda *args: print('(DRY RUN)', *args)) if dry_run else print
context = {
'verbose': verbose, # TODO: use callback function to be verbose?
'import': True,
}
try:
new_config = yaml.safe_load(sys.stdin)
except (yaml.scanner.ScannerError, yaml.parser.ParserError) as reason:
out(f'[ERROR] Invalid yaml: {reason}')
config = MailuSchema(context=context).loads(source)
except ValidationError as exc:
print(f'[ERROR] {exc}')
# TODO: show nice errors
from pprint import pprint
pprint(exc.messages)
sys.exit(1)
else:
if type(new_config) is str:
out(f'[ERROR] Invalid yaml: {new_config!r}')
sys.exit(1)
elif new_config is None or not len(new_config):
out('[ERROR] Empty yaml: Please pipe yaml into stdin')
sys.exit(1)
error = False
tracked = {}
for section, model in yaml_sections:
items = new_config.get(section)
if items is None:
if delete_objects:
out(f'[ERROR] Invalid yaml: Section "{section}" is missing')
error = True
break
else:
continue
del new_config[section]
if type(items) is not list:
out(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}')
error = True
break
elif not items:
continue
# create items
for data in items:
if verbose:
out(f'Handling {model.__table__} data: {data!r}')
try:
changed = model.from_dict(data, delete_objects)
except Exception as reason:
out(f'[ERROR] {reason.args[0]} in data: {data}')
error = True
break
for item, created in changed:
if created is True:
# flush newly created item
db.session.add(item)
db.session.flush()
if verbose:
out(f'Added {item!r}: {item.to_dict()}')
else:
out(f'Added {item!r}')
elif len(created):
# modified instance
if verbose:
for key, old, new in created:
out(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}')
else:
out(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}')
# track primary key of all items
tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()]))
if error:
break
# on error: stop early
if error:
out('An error occured. Not committing changes.')
db.session.rollback()
sys.exit(1)
# are there sections left in new_config?
if new_config:
out(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}')
error = True
# test for conflicting domains
domains = set()
for model, items in tracked.items():
if model in (models.Domain, models.Alternative, models.Relay):
if domains & items:
for domain in domains & items:
out(f'[ERROR] Duplicate domain name used: {domain}')
error = True
domains.update(items)
# delete items not tracked
if delete_objects:
for model, items in tracked.items():
for item in model.query.all():
if not item._dict_pval() in items:
out(f'Deleted {item!r} {item}')
db.session.delete(item)
print(config)
print(MailuSchema().dumps(config))
# TODO: does not commit yet.
# TODO: delete other entries?
# don't commit when running dry
if dry_run:
if True: #dry_run:
print('Dry run. Not commiting changes.')
db.session.rollback()
else:
db.session.commit()
@mailu.command()
@click.option('-f', '--full', is_flag=True, help='Include default attributes')
@click.option('-s', '--secrets', is_flag=True, help='Include secrets (dkim-key, plain-text / not hashed)')
@click.option('-f', '--full', is_flag=True, help='Include attributes with default value')
@click.option('-s', '--secrets', is_flag=True,
help='Include secret attributes (dkim-key, passwords)')
@click.option('-d', '--dns', is_flag=True, help='Include dns records')
@click.option('-o', '--output-file', 'output', default=sys.stdout, type=click.File(mode='w'),
help='save yaml to file')
@click.argument('sections', nargs=-1)
@with_appcontext
def config_dump(full=False, secrets=False, dns=False, sections=None):
"""dump configuration as YAML-formatted data to stdout
def config_dump(full=False, secrets=False, dns=False, output=None, sections=None):
""" Dump configuration as YAML to stdout or file
SECTIONS can be: domains, relays, users, aliases
"""
try:
config = MailuConfig(sections)
except ValueError as reason:
print(f'[ERROR] {reason}')
return 1
if sections:
for section in sections:
if section not in SECTIONS:
print(f'[ERROR] Unknown section: {section!r}')
sys.exit(1)
sections = set(sections)
else:
sections = SECTIONS
MailuSchema(context={
context={
'full': full,
'secrets': secrets,
'dns': dns,
}).dumps(config, sys.stdout)
}
MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output)
@mailu.command()

@ -1,23 +1,26 @@
from mailu import dkim
""" Mailu config storage model
"""
from sqlalchemy.ext import declarative
from datetime import datetime, date
import re
import os
import smtplib
import json
from datetime import date
from email.mime import text
from flask import current_app as app
from textwrap import wrap
import flask_sqlalchemy
import sqlalchemy
import re
import time
import os
import passlib
import glob
import smtplib
import idna
import dns
import json
import itertools
from flask import current_app as app
from sqlalchemy.ext import declarative
from sqlalchemy.inspection import inspect
from werkzeug.utils import cached_property
from . import dkim
db = flask_sqlalchemy.SQLAlchemy()
@ -30,9 +33,11 @@ class IdnaDomain(db.TypeDecorator):
impl = db.String(80)
def process_bind_param(self, value, dialect):
""" encode unicode domain name to punycode """
return idna.encode(value).decode('ascii').lower()
def process_result_value(self, value, dialect):
""" decode punycode domain name to unicode """
return idna.decode(value)
python_type = str
@ -44,6 +49,7 @@ class IdnaEmail(db.TypeDecorator):
impl = db.String(255)
def process_bind_param(self, value, dialect):
""" encode unicode domain part of email address to punycode """
try:
localpart, domain_name = value.split('@')
return '{0}@{1}'.format(
@ -54,6 +60,7 @@ class IdnaEmail(db.TypeDecorator):
pass
def process_result_value(self, value, dialect):
""" decode punycode domain part of email to unicode """
localpart, domain_name = value.split('@')
return '{0}@{1}'.format(
localpart,
@ -69,14 +76,16 @@ class CommaSeparatedList(db.TypeDecorator):
impl = db.String
def process_bind_param(self, value, dialect):
if not isinstance(value, (list, set)):
raise TypeError('Must be a list')
""" join list of items to comma separated string """
if not isinstance(value, (list, tuple, set)):
raise TypeError('Must be a list of strings')
for item in value:
if ',' in item:
raise ValueError('Item must not contain a comma')
return ','.join(sorted(value))
def process_result_value(self, value, dialect):
""" split comma separated string to list """
return list(filter(bool, value.split(','))) if value else []
python_type = list
@ -88,9 +97,11 @@ class JSONEncoded(db.TypeDecorator):
impl = db.String
def process_bind_param(self, value, dialect):
""" encode data as json """
return json.dumps(value) if value else None
def process_result_value(self, value, dialect):
""" decode json to data """
return json.loads(value) if value else None
python_type = str
@ -112,246 +123,172 @@ class Base(db.Model):
updated_at = db.Column(db.Date, nullable=True, onupdate=date.today)
comment = db.Column(db.String(255), nullable=True, default='')
@classmethod
def _dict_pkey(cls):
return cls.__mapper__.primary_key[0].name
# @classmethod
# def from_dict(cls, data, delete=False):
def _dict_pval(self):
return getattr(self, self._dict_pkey())
# changed = []
def to_dict(self, full=False, include_secrets=False, include_extra=None, recursed=False, hide=None):
""" Return a dictionary representation of this model.
"""
# pkey = cls._dict_pkey()
if recursed and not getattr(self, '_dict_recurse', False):
return str(self)
# # handle "primary key" only
# if not isinstance(data, dict):
# data = {pkey: data}
hide = set(hide or []) | {'created_at', 'updated_at'}
if hasattr(self, '_dict_hide'):
hide |= self._dict_hide
# # modify input data
# if hasattr(cls, '_dict_input'):
# try:
# cls._dict_input(data)
# except Exception as exc:
# raise ValueError(f'{exc}', cls, None, data) from exc
secret = set()
if not include_secrets and hasattr(self, '_dict_secret'):
secret |= self._dict_secret
# # check for primary key (if not recursed)
# if not getattr(cls, '_dict_recurse', False):
# if not pkey in data:
# raise KeyError(f'primary key {cls.__table__}.{pkey} is missing', cls, pkey, data)
convert = getattr(self, '_dict_output', {})
# # check data keys and values
# for key in list(data.keys()):
extra_keys = getattr(self, '_dict_extra', {})
if include_extra is None:
include_extra = []
# # check key
# if not hasattr(cls, key) and not key in cls.__mapper__.relationships:
# raise KeyError(f'unknown key {cls.__table__}.{key}', cls, key, data)
res = {}
# # check value type
# value = data[key]
# col = cls.__mapper__.columns.get(key)
# if col is not None:
# if not ((value is None and col.nullable) or (isinstance(value, col.type.python_type))):
# raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
# else:
# rel = cls.__mapper__.relationships.get(key)
# if rel is None:
# itype = getattr(cls, '_dict_types', {}).get(key)
# if itype is not None:
# if itype is False: # ignore value. TODO: emit warning?
# del data[key]
# continue
# elif not isinstance(value, itype):
# raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
# else:
# raise NotImplementedError(f'type not defined for {cls.__table__}.{key}')
for key in itertools.chain(
self.__table__.columns.keys(),
getattr(self, '_dict_show', []),
*[extra_keys.get(extra, []) for extra in include_extra]
):
if key in hide:
continue
if key in self.__table__.columns:
default = self.__table__.columns[key].default
if isinstance(default, sqlalchemy.sql.schema.ColumnDefault):
default = default.arg
else:
default = None
value = getattr(self, key)
if full or ((default or value) and value != default):
if key in secret:
value = '<hidden>'
elif value is not None and key in convert:
value = convert[key](value)
res[key] = value
# # handle relationships
# if key in cls.__mapper__.relationships:
# rel_model = cls.__mapper__.relationships[key].argument
# if not isinstance(rel_model, sqlalchemy.orm.Mapper):
# add = rel_model.from_dict(value, delete)
# assert len(add) == 1
# rel_item, updated = add[0]
# changed.append((rel_item, updated))
# data[key] = rel_item
for key in self.__mapper__.relationships.keys():
if key in hide:
continue
if self.__mapper__.relationships[key].uselist:
items = getattr(self, key)
if self.__mapper__.relationships[key].query_class is not None:
if hasattr(items, 'all'):
items = items.all()
if full or items:
if key in secret:
res[key] = '<hidden>'
else:
res[key] = [item.to_dict(full, include_secrets, include_extra, True) for item in items]
else:
value = getattr(self, key)
if full or value is not None:
if key in secret:
res[key] = '<hidden>'
else:
res[key] = value.to_dict(full, include_secrets, include_extra, True)
# # create item if necessary
# created = False
# item = cls.query.get(data[pkey]) if pkey in data else None
# if item is None:
return res
# # check for mandatory keys
# missing = getattr(cls, '_dict_mandatory', set()) - set(data.keys())
# if missing:
# raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {cls.__table__} missing', cls, missing, data)
@classmethod
def from_dict(cls, data, delete=False):
# # remove mapped relationships from data
# mapped = {}
# for key in list(data.keys()):
# if key in cls.__mapper__.relationships:
# if isinstance(cls.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper):
# mapped[key] = data[key]
# del data[key]
changed = []
# # create new item
# item = cls(**data)
# created = True
pkey = cls._dict_pkey()
# # and update mapped relationships (below)
# data = mapped
# handle "primary key" only
if isinstance(data, dict):
data = {pkey: data}
# # update item
# updated = []
# for key, value in data.items():
# modify input data
if hasattr(cls, '_dict_input'):
try:
cls._dict_input(data)
except Exception as reason:
raise ValueError(f'{reason}', cls, None, data)
# # skip primary key
# if key == pkey:
# continue
# check for primary key (if not recursed)
if not getattr(cls, '_dict_recurse', False):
if not pkey in data:
raise KeyError(f'primary key {cls.__table__}.{pkey} is missing', cls, pkey, data)
# if key in cls.__mapper__.relationships:
# # update relationship
# rel_model = cls.__mapper__.relationships[key].argument
# if isinstance(rel_model, sqlalchemy.orm.Mapper):
# rel_model = rel_model.class_
# # add (and create) referenced items
# cur = getattr(item, key)
# old = sorted(cur, key=id)
# new = []
# for rel_data in value:
# # get or create related item
# add = rel_model.from_dict(rel_data, delete)
# assert len(add) == 1
# rel_item, rel_updated = add[0]
# changed.append((rel_item, rel_updated))
# if rel_item not in cur:
# cur.append(rel_item)
# new.append(rel_item)
# check data keys and values
for key in list(data.keys()):
# # delete referenced items missing in yaml
# rel_pkey = rel_model._dict_pkey()
# new_data = list([i.to_dict(True, True, None, True, [rel_pkey]) for i in new])
# for rel_item in old:
# if rel_item not in new:
# # check if item with same data exists to stabilze import without primary key
# rel_data = rel_item.to_dict(True, True, None, True, [rel_pkey])
# try:
# same_idx = new_data.index(rel_data)
# except ValueError:
# same = None
# else:
# same = new[same_idx]
# check key
if not hasattr(cls, key) and not key in cls.__mapper__.relationships:
raise KeyError(f'unknown key {cls.__table__}.{key}', cls, key, data)
# if same is None:
# # delete items missing in new
# if delete:
# cur.remove(rel_item)
# else:
# new.append(rel_item)
# else:
# # swap found item with same data with newly created item
# new.append(rel_item)
# new_data.append(rel_data)
# new.remove(same)
# del new_data[same_idx]
# for i, (ch_item, _) in enumerate(changed):
# if ch_item is same:
# changed[i] = (rel_item, [])
# db.session.flush()
# db.session.delete(ch_item)
# break
# check value type
value = data[key]
col = cls.__mapper__.columns.get(key)
if col is not None:
if not ((value is None and col.nullable) or (isinstance(value, col.type.python_type))):
raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
else:
rel = cls.__mapper__.relationships.get(key)
if rel is None:
itype = getattr(cls, '_dict_types', {}).get(key)
if itype is not None:
if itype is False: # ignore value. TODO: emit warning?
del data[key]
continue
elif not isinstance(value, itype):
raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
else:
raise NotImplementedError(f'type not defined for {cls.__table__}.{key}')
# # remember changes
# new = sorted(new, key=id)
# if new != old:
# updated.append((key, old, new))
# handle relationships
if key in cls.__mapper__.relationships:
rel_model = cls.__mapper__.relationships[key].argument
if not isinstance(rel_model, sqlalchemy.orm.Mapper):
add = rel_model.from_dict(value, delete)
assert len(add) == 1
rel_item, updated = add[0]
changed.append((rel_item, updated))
data[key] = rel_item
# else:
# # update key
# old = getattr(item, key)
# if isinstance(old, list):
# # deduplicate list value
# assert isinstance(value, list)
# value = set(value)
# old = set(old)
# if not delete:
# value = old | value
# if value != old:
# updated.append((key, old, value))
# setattr(item, key, value)
# create item if necessary
created = False
item = cls.query.get(data[pkey]) if pkey in data else None
if item is None:
# changed.append((item, created if created else updated))
# check for mandatory keys
missing = getattr(cls, '_dict_mandatory', set()) - set(data.keys())
if missing:
raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {cls.__table__} missing', cls, missing, data)
# remove mapped relationships from data
mapped = {}
for key in list(data.keys()):
if key in cls.__mapper__.relationships:
if isinstance(cls.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper):
mapped[key] = data[key]
del data[key]
# create new item
item = cls(**data)
created = True
# and update mapped relationships (below)
data = mapped
# update item
updated = []
for key, value in data.items():
# skip primary key
if key == pkey:
continue
if key in cls.__mapper__.relationships:
# update relationship
rel_model = cls.__mapper__.relationships[key].argument
if isinstance(rel_model, sqlalchemy.orm.Mapper):
rel_model = rel_model.class_
# add (and create) referenced items
cur = getattr(item, key)
old = sorted(cur, key=id)
new = []
for rel_data in value:
# get or create related item
add = rel_model.from_dict(rel_data, delete)
assert len(add) == 1
rel_item, rel_updated = add[0]
changed.append((rel_item, rel_updated))
if rel_item not in cur:
cur.append(rel_item)
new.append(rel_item)
# delete referenced items missing in yaml
rel_pkey = rel_model._dict_pkey()
new_data = list([i.to_dict(True, True, None, True, [rel_pkey]) for i in new])
for rel_item in old:
if rel_item not in new:
# check if item with same data exists to stabilze import without primary key
rel_data = rel_item.to_dict(True, True, None, True, [rel_pkey])
try:
same_idx = new_data.index(rel_data)
except ValueError:
same = None
else:
same = new[same_idx]
if same is None:
# delete items missing in new
if delete:
cur.remove(rel_item)
else:
new.append(rel_item)
else:
# swap found item with same data with newly created item
new.append(rel_item)
new_data.append(rel_data)
new.remove(same)
del new_data[same_idx]
for i, (ch_item, _) in enumerate(changed):
if ch_item is same:
changed[i] = (rel_item, [])
db.session.flush()
db.session.delete(ch_item)
break
# remember changes
new = sorted(new, key=id)
if new != old:
updated.append((key, old, new))
else:
# update key
old = getattr(item, key)
if isinstance(old, list):
# deduplicate list value
assert isinstance(value, list)
value = set(value)
old = set(old)
if not delete:
value = old | value
if value != old:
updated.append((key, old, value))
setattr(item, key, value)
changed.append((item, created if created else updated))
return changed
# return changed
# Many-to-many association table for domain managers
@ -391,48 +328,6 @@ class Domain(Base):
__tablename__ = 'domain'
_dict_hide = {'users', 'managers', 'aliases'}
_dict_show = {'dkim_key'}
_dict_extra = {'dns':{'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'}}
_dict_secret = {'dkim_key'}
_dict_types = {
'dkim_key': (bytes, type(None)),
'dkim_publickey': False,
'dns_mx': False,
'dns_spf': False,
'dns_dkim': False,
'dns_dmarc': False,
}
_dict_output = {'dkim_key': lambda key: key.decode('utf-8').strip().split('\n')[1:-1]}
@staticmethod
def _dict_input(data):
if 'dkim_key' in data:
key = data['dkim_key']
if key is not None:
if isinstance(key, list):
key = ''.join(key)
if isinstance(key, str):
key = ''.join(key.strip().split()) # removes all whitespace
if key == 'generate':
data['dkim_key'] = dkim.gen_key()
elif key:
match = re.match('^-----BEGIN (RSA )?PRIVATE KEY-----', key)
if match is not None:
key = key[match.end():]
match = re.search('-----END (RSA )?PRIVATE KEY-----$', key)
if match is not None:
key = key[:match.start()]
key = '\n'.join(wrap(key, 64))
key = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii')
try:
dkim.strip_key(key)
except:
raise ValueError('invalid dkim key')
else:
data['dkim_key'] = key
else:
data['dkim_key'] = None
name = db.Column(IdnaDomain, primary_key=True, nullable=False)
managers = db.relationship('User', secondary=managers,
backref=db.backref('manager_of'), lazy='dynamic')
@ -462,7 +357,10 @@ class Domain(Base):
def dns_dkim(self):
if os.path.exists(self._dkim_file()):
selector = app.config['DKIM_SELECTOR']
return f'{selector}._domainkey.{self.name}. 600 IN TXT "v=DKIM1; k=rsa; p={self.dkim_publickey}"'
return (
f'{selector}._domainkey.{self.name}. 600 IN TXT'
f'"v=DKIM1; k=rsa; p={self.dkim_publickey}"'
)
@property
def dns_dmarc(self):
@ -525,7 +423,11 @@ class Domain(Base):
try:
return self.name == other.name
except AttributeError:
return False
return NotImplemented
def __hash__(self):
return hash(str(self.name))
class Alternative(Base):
@ -551,8 +453,6 @@ class Relay(Base):
__tablename__ = 'relay'
_dict_mandatory = {'smtp'}
name = db.Column(IdnaDomain, primary_key=True, nullable=False)
smtp = db.Column(db.String(80), nullable=True)
@ -566,18 +466,8 @@ class Email(object):
localpart = db.Column(db.String(80), nullable=False)
@staticmethod
def _dict_input(data):
if 'email' in data:
if 'localpart' in data or 'domain' in data:
raise ValueError('ambigous key email and localpart/domain')
elif isinstance(data['email'], str):
data['localpart'], data['domain'] = data['email'].rsplit('@', 1)
else:
data['email'] = f'{data["localpart"]}@{data["domain"]}'
@declarative.declared_attr
def domain_name(cls):
def domain_name(self):
return db.Column(IdnaDomain, db.ForeignKey(Domain.name),
nullable=False, default=IdnaDomain)
@ -585,7 +475,7 @@ class Email(object):
# It is however very useful for quick lookups without joining tables,
# especially when the mail server is reading the database.
@declarative.declared_attr
def email(cls):
def email(self):
updater = lambda context: '{0}@{1}'.format(
context.current_parameters['localpart'],
context.current_parameters['domain_name'],
@ -662,30 +552,6 @@ class User(Base, Email):
__tablename__ = 'user'
_dict_hide = {'domain_name', 'domain', 'localpart', 'quota_bytes_used'}
_dict_mandatory = {'localpart', 'domain', 'password'}
@classmethod
def _dict_input(cls, data):
Email._dict_input(data)
# handle password
if 'password' in data:
if 'password_hash' in data or 'hash_scheme' in data:
raise ValueError('ambigous key password and password_hash/hash_scheme')
# check (hashed) password
password = data['password']
if password.startswith('{') and '}' in password:
scheme = password[1:password.index('}')]
if scheme not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
else:
raise ValueError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
del data['hash_scheme']
del data['password_hash']
domain = db.relationship(Domain,
backref=db.backref('users', cascade='all, delete-orphan'))
password = db.Column(db.String(255), nullable=False)
@ -775,7 +641,8 @@ class User(Base, Email):
if raw:
self.password = '{'+hash_scheme+'}' + password
else:
self.password = '{'+hash_scheme+'}' + self.get_password_context().encrypt(password, self.scheme_dict[hash_scheme])
self.password = '{'+hash_scheme+'}' + \
self.get_password_context().encrypt(password, self.scheme_dict[hash_scheme])
def get_managed_domains(self):
if self.global_admin:
@ -812,15 +679,6 @@ class Alias(Base, Email):
__tablename__ = 'alias'
_dict_hide = {'domain_name', 'domain', 'localpart'}
@staticmethod
def _dict_input(data):
Email._dict_input(data)
# handle comma delimited string for backwards compability
dst = data.get('destination')
if isinstance(dst, str):
data['destination'] = list([adr.strip() for adr in dst.split(',')])
domain = db.relationship(Domain,
backref=db.backref('aliases', cascade='all, delete-orphan'))
wildcard = db.Column(db.Boolean, nullable=False, default=False)
@ -832,10 +690,10 @@ class Alias(Base, Email):
sqlalchemy.and_(cls.domain_name == domain_name,
sqlalchemy.or_(
sqlalchemy.and_(
cls.wildcard == False,
cls.wildcard is False,
cls.localpart == localpart
), sqlalchemy.and_(
cls.wildcard == True,
cls.wildcard is True,
sqlalchemy.bindparam('l', localpart).like(cls.localpart)
)
)
@ -847,10 +705,10 @@ class Alias(Base, Email):
sqlalchemy.and_(cls.domain_name == domain_name,
sqlalchemy.or_(
sqlalchemy.and_(
cls.wildcard == False,
cls.wildcard is False,
sqlalchemy.func.lower(cls.localpart) == localpart_lower
), sqlalchemy.and_(
cls.wildcard == True,
cls.wildcard is True,
sqlalchemy.bindparam('l', localpart_lower).like(sqlalchemy.func.lower(cls.localpart))
)
)
@ -875,10 +733,6 @@ class Token(Base):
__tablename__ = 'token'
_dict_recurse = True
_dict_hide = {'user', 'user_email'}
_dict_mandatory = {'password'}
id = db.Column(db.Integer, primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False)
@ -904,11 +758,6 @@ class Fetch(Base):
__tablename__ = 'fetch'
_dict_recurse = True
_dict_hide = {'user_email', 'user', 'last_check', 'error'}
_dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'}
_dict_secret = {'password'}
id = db.Column(db.Integer, primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False)
@ -926,3 +775,124 @@ class Fetch(Base):
def __str__(self):
return f'{self.protocol}{"s" if self.tls else ""}://{self.username}@{self.host}:{self.port}'
class MailuConfig:
""" Class which joins whole Mailu config for dumping
and loading
"""
# TODO: add sqlalchemy session updating (.add & .del)
class MailuCollection:
""" Provides dict- and list-like access to all instances
of a sqlalchemy model
"""
def __init__(self, model : db.Model):
self._model = model
@cached_property
def _items(self):
return {
inspect(item).identity: item
for item in self._model.query.all()
}
def __len__(self):
return len(self._items)
def __iter__(self):
return iter(self._items.values())
def __getitem__(self, key):
return self._items[key]
def __setitem__(self, key, item):
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
self._items[key] = item
def __delitem__(self, key):
del self._items[key]
def append(self, item):
""" list-like append """
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
self._items[key] = item
def extend(self, items):
""" list-like extend """
add = {}
for item in items:
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
add[key] = item
self._items.update(add)
def pop(self, *args):
""" list-like (no args) and dict-like (1 or 2 args) pop """
if args:
if len(args) > 2:
raise TypeError(f'pop expected at most 2 arguments, got {len(args)}')
return self._items.pop(*args)
else:
return self._items.popitem()[1]
def popitem(self):
""" dict-like popitem """
return self._items.popitem()
def remove(self, item):
""" list-like remove """
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if not key in self._items:
raise ValueError(f'item {key!r} not found in collection')
del self._items[key]
def clear(self):
""" dict-like clear """
while True:
try:
self.pop()
except IndexError:
break
def update(self, items):
""" dict-like update """
for key, item in items:
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
def setdefault(self, key, item=None):
""" dict-like setdefault """
if key in self._items:
return self._items[key]
if item is None:
return None
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
self._items[key] = item
return item
domains = MailuCollection(Domain)
relays = MailuCollection(Relay)
users = MailuCollection(User)
aliases = MailuCollection(Alias)
config = MailuCollection(Config)

@ -1,13 +1,15 @@
""" Mailu marshmallow fields and schema
"""
Mailu marshmallow schema
"""
import re
from textwrap import wrap
import re
import yaml
from marshmallow import post_dump, fields, Schema
from marshmallow import pre_load, post_dump, fields, Schema
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from flask_marshmallow import Marshmallow
from OpenSSL import crypto
@ -15,9 +17,9 @@ from . import models, dkim
ma = Marshmallow()
# TODO:
# how to mark keys as "required" while unserializing (in certain use cases/API)?
# - fields withoud default => required
# TODO: how and where to mark keys as "required" while unserializing (on commandline, in api)?
# - fields without default => required
# - fields which are the primary key => unchangeable when updating
@ -41,7 +43,7 @@ class RenderYAML:
return super().increase_indent(flow, False)
@staticmethod
def _update_dict(dict1, dict2):
def _update_items(dict1, dict2):
""" sets missing keys in dict1 to values of dict2
"""
for key, value in dict2.items():
@ -53,8 +55,8 @@ class RenderYAML:
def loads(cls, *args, **kwargs):
""" load yaml data from string
"""
cls._update_dict(kwargs, cls._load_defaults)
return yaml.load(*args, **kwargs)
cls._update_items(kwargs, cls._load_defaults)
return yaml.safe_load(*args, **kwargs)
_dump_defaults = {
'Dumper': SpacedDumper,
@ -65,13 +67,33 @@ class RenderYAML:
def dumps(cls, *args, **kwargs):
""" dump yaml data to string
"""
cls._update_dict(kwargs, cls._dump_defaults)
cls._update_items(kwargs, cls._dump_defaults)
return yaml.dump(*args, **kwargs)
### functions ###
def handle_email(data):
""" merge separate localpart and domain to email
"""
localpart = 'localpart' in data
domain = 'domain' in data
if 'email' in data:
if localpart or domain:
raise ValidationError('duplicate email and localpart/domain')
elif localpart and domain:
data['email'] = f'{data["localpart"]}@{data["domain"]}'
elif localpart or domain:
raise ValidationError('incomplete localpart/domain')
return data
### field definitions ###
class LazyString(fields.String):
class LazyStringField(fields.String):
""" Field that serializes a "false" value to the empty string
"""
@ -81,14 +103,27 @@ class LazyString(fields.String):
return value if value else ''
class CommaSeparatedList(fields.Raw):
class CommaSeparatedListField(fields.Raw):
""" Field that deserializes a string containing comma-separated values to
a list of strings
"""
# TODO: implement this
def _deserialize(self, value, attr, data, **kwargs):
""" deserialize comma separated string to list of strings
"""
# empty
if not value:
return []
# split string
if isinstance(value, str):
return list([item.strip() for item in value.split(',') if item.strip()])
else:
return value
class DkimKey(fields.String):
class DkimKeyField(fields.String):
""" Field that serializes a dkim key to a list of strings (lines) and
deserializes a string or list of strings.
"""
@ -120,7 +155,7 @@ class DkimKey(fields.String):
# only strings are allowed
if not isinstance(value, str):
raise TypeError(f'invalid type: {type(value).__name__!r}')
raise ValidationError(f'invalid type {type(value).__name__!r}')
# clean value (remove whitespace and header/footer)
value = self._clean_re.sub('', value.strip())
@ -133,6 +168,11 @@ class DkimKey(fields.String):
elif value == 'generate':
return dkim.gen_key()
# remember some keydata for error message
keydata = value
if len(keydata) > 40:
keydata = keydata[:25] + '...' + keydata[-10:]
# wrap value into valid pem layout and check validity
value = (
'-----BEGIN PRIVATE KEY-----\n' +
@ -142,26 +182,37 @@ class DkimKey(fields.String):
try:
crypto.load_privatekey(crypto.FILETYPE_PEM, value)
except crypto.Error as exc:
raise ValueError('invalid dkim key') from exc
raise ValidationError(f'invalid dkim key {keydata!r}') from exc
else:
return value
### schema definitions ###
### base definitions ###
class BaseOpts(SQLAlchemyAutoSchemaOpts):
""" Option class with sqla session
"""
def __init__(self, meta, ordered=False):
if not hasattr(meta, 'sqla_session'):
meta.sqla_session = models.db.session
super(BaseOpts, self).__init__(meta, ordered=ordered)
class BaseSchema(ma.SQLAlchemyAutoSchema):
""" Marshmallow base schema with custom exclude logic
and option to hide sqla defaults
"""
OPTIONS_CLASS = BaseOpts
class Meta:
""" Schema config """
model = None
def __init__(self, *args, **kwargs):
# get and remove config from kwargs
# context?
context = kwargs.get('context', {})
flags = set([key for key, value in context.items() if value is True])
# compile excludes
exclude = set(kwargs.get('exclude', []))
@ -171,8 +222,8 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# add include_by_context
if context is not None:
for ctx, what in getattr(self.Meta, 'include_by_context', {}).items():
if not context.get(ctx):
for need, what in getattr(self.Meta, 'include_by_context', {}).items():
if not flags & set(need):
exclude |= set(what)
# update excludes
@ -192,8 +243,8 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# hide by context
self._hide_by_context = set()
if context is not None:
for ctx, what in getattr(self.Meta, 'hide_by_context', {}).items():
if not context.get(ctx):
for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
if not flags & set(need):
self._hide_by_context |= set(what)
# init SQLAlchemyAutoSchema
@ -212,23 +263,26 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
if full or key not in self._exclude_by_value or value not in self._exclude_by_value[key]
}
# TODO: remove LazyString and fix model definition (comment should not be nullable)
comment = LazyString()
# TODO: remove LazyString and change model (IMHO comment should not be nullable)
comment = LazyStringField()
### schema definitions ###
class DomainSchema(BaseSchema):
""" Marshmallow schema for Domain model """
class Meta:
""" Schema config """
model = models.Domain
load_instance = True
include_relationships = True
#include_fk = True
exclude = ['users', 'managers', 'aliases']
include_by_context = {
'dns': {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
}
hide_by_context = {
'secrets': {'dkim_key'},
('secrets',): {'dkim_key'},
}
exclude_by_value = {
'alternatives': [[]],
@ -240,40 +294,20 @@ class DomainSchema(BaseSchema):
'dns_dmarc': [None],
}
dkim_key = DkimKey()
dkim_key = DkimKeyField(allow_none=True)
dkim_publickey = fields.String(dump_only=True)
dns_mx = fields.String(dump_only=True)
dns_spf = fields.String(dump_only=True)
dns_dkim = fields.String(dump_only=True)
dns_dmarc = fields.String(dump_only=True)
# _dict_types = {
# 'dkim_key': (bytes, type(None)),
# 'dkim_publickey': False,
# 'dns_mx': False,
# 'dns_spf': False,
# 'dns_dkim': False,
# 'dns_dmarc': False,
# }
class TokenSchema(BaseSchema):
""" Marshmallow schema for Token model """
class Meta:
""" Schema config """
model = models.Token
# _dict_recurse = True
# _dict_hide = {'user', 'user_email'}
# _dict_mandatory = {'password'}
# id = db.Column(db.Integer(), primary_key=True)
# user_email = db.Column(db.String(255), db.ForeignKey(User.email),
# nullable=False)
# user = db.relationship(User,
# backref=db.backref('tokens', cascade='all, delete-orphan'))
# password = db.Column(db.String(255), nullable=False)
# ip = db.Column(db.String(255))
load_instance = True
class FetchSchema(BaseSchema):
@ -281,58 +315,57 @@ class FetchSchema(BaseSchema):
class Meta:
""" Schema config """
model = models.Fetch
load_instance = True
include_by_context = {
'full': {'last_check', 'error'},
('full', 'import'): {'last_check', 'error'},
}
hide_by_context = {
'secrets': {'password'},
('secrets',): {'password'},
}
# TODO: What about mandatory keys?
# _dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'}
class UserSchema(BaseSchema):
""" Marshmallow schema for User model """
class Meta:
""" Schema config """
model = models.User
load_instance = True
include_relationships = True
exclude = ['localpart', 'domain', 'quota_bytes_used']
exclude_by_value = {
'forward_destination': [[]],
'tokens': [[]],
'manager_of': [[]],
'reply_enddate': ['2999-12-31'],
'reply_startdate': ['1900-01-01'],
}
@pre_load
def _handle_password(self, data, many, **kwargs): # pylint: disable=unused-argument
data = handle_email(data)
if 'password' in data:
if 'password_hash' in data or 'hash_scheme' in data:
raise ValidationError('ambigous key password and password_hash/hash_scheme')
# check (hashed) password
password = data['password']
if password.startswith('{') and '}' in password:
scheme = password[1:password.index('}')]
if scheme not in self.Meta.model.scheme_dict:
raise ValidationError(f'invalid password scheme {scheme!r}')
else:
raise ValidationError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in self.Meta.model.scheme_dict:
raise ValidationError(f'invalid password scheme {scheme!r}')
data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
del data['hash_scheme']
del data['password_hash']
return data
tokens = fields.Nested(TokenSchema, many=True)
fetches = fields.Nested(FetchSchema, many=True)
# TODO: deserialize password/password_hash! What about mandatory keys?
# _dict_mandatory = {'localpart', 'domain', 'password'}
# @classmethod
# def _dict_input(cls, data):
# Email._dict_input(data)
# # handle password
# if 'password' in data:
# if 'password_hash' in data or 'hash_scheme' in data:
# raise ValueError('ambigous key password and password_hash/hash_scheme')
# # check (hashed) password
# password = data['password']
# if password.startswith('{') and '}' in password:
# scheme = password[1:password.index('}')]
# if scheme not in cls.scheme_dict:
# raise ValueError(f'invalid password scheme {scheme!r}')
# else:
# raise ValueError(f'invalid hashed password {password!r}')
# elif 'password_hash' in data and 'hash_scheme' in data:
# if data['hash_scheme'] not in cls.scheme_dict:
# raise ValueError(f'invalid password scheme {scheme!r}')
# data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
# del data['hash_scheme']
# del data['password_hash']
class AliasSchema(BaseSchema):
@ -340,20 +373,18 @@ class AliasSchema(BaseSchema):
class Meta:
""" Schema config """
model = models.Alias
load_instance = True
exclude = ['localpart']
exclude_by_value = {
'destination': [[]],
}
# TODO: deserialize destination!
# @staticmethod
# def _dict_input(data):
# Email._dict_input(data)
# # handle comma delimited string for backwards compability
# dst = data.get('destination')
# if type(dst) is str:
# data['destination'] = list([adr.strip() for adr in dst.split(',')])
@pre_load
def _handle_password(self, data, many, **kwargs): # pylint: disable=unused-argument
return handle_email(data)
destination = CommaSeparatedListField()
class ConfigSchema(BaseSchema):
@ -361,6 +392,7 @@ class ConfigSchema(BaseSchema):
class Meta:
""" Schema config """
model = models.Config
load_instance = True
class RelaySchema(BaseSchema):
@ -368,45 +400,17 @@ class RelaySchema(BaseSchema):
class Meta:
""" Schema config """
model = models.Relay
load_instance = True
class MailuSchema(Schema):
""" Marshmallow schema for Mailu config """
""" Marshmallow schema for complete Mailu config """
class Meta:
""" Schema config """
render_module = RenderYAML
domains = fields.Nested(DomainSchema, many=True)
relays = fields.Nested(RelaySchema, many=True)
users = fields.Nested(UserSchema, many=True)
aliases = fields.Nested(AliasSchema, many=True)
config = fields.Nested(ConfigSchema, many=True)
### config class ###
class MailuConfig:
""" Class which joins whole Mailu config for dumping
"""
_models = {
'domains': models.Domain,
'relays': models.Relay,
'users': models.User,
'aliases': models.Alias,
# 'config': models.Config,
}
def __init__(self, sections):
if sections:
for section in sections:
if section not in self._models:
raise ValueError(f'Unknown section: {section!r}')
self._sections = set(sections)
else:
self._sections = set(self._models.keys())
def __getattr__(self, section):
if section in self._sections:
return self._models[section].query.all()
else:
raise AttributeError

Loading…
Cancel
Save