diff --git a/core/admin/mailu/manage.py b/core/admin/mailu/manage.py index 62f214d3..bf0148df 100644 --- a/core/admin/mailu/manage.py +++ b/core/admin/mailu/manage.py @@ -8,6 +8,8 @@ import os import socket import uuid import click +import yaml +import sys db = models.db @@ -169,146 +171,147 @@ def user_import(localpart, domain_name, password_hash, hash_scheme = None): db.session.commit() +yaml_sections = [ + ('domains', models.Domain), + ('relays', models.Relay), + ('users', models.User), + ('aliases', models.Alias), +# ('config', models.Config), +] + @mailu.command() -@click.option('-v', '--verbose') -@click.option('-d', '--delete-objects') +@click.option('-v', '--verbose', is_flag=True) +@click.option('-d', '--delete-objects', is_flag=True) +@click.option('-n', '--dry-run', is_flag=True) @flask_cli.with_appcontext -def config_update(verbose=False, delete_objects=False): +def config_update(verbose=False, delete_objects=False, dry_run=False, file=None): """sync configuration with data from YAML-formatted stdin""" - import yaml - import sys - new_config = yaml.safe_load(sys.stdin) - # print new_config - domains = new_config.get('domains', []) - tracked_domains = set() - for domain_config in domains: - if verbose: - print(str(domain_config)) - domain_name = domain_config['name'] - max_users = domain_config.get('max_users', -1) - max_aliases = domain_config.get('max_aliases', -1) - max_quota_bytes = domain_config.get('max_quota_bytes', 0) - tracked_domains.add(domain_name) - domain = models.Domain.query.get(domain_name) - if not domain: - domain = models.Domain(name=domain_name, - max_users=max_users, - max_aliases=max_aliases, - max_quota_bytes=max_quota_bytes) - db.session.add(domain) - print("Added " + str(domain_config)) - else: - domain.max_users = max_users - domain.max_aliases = max_aliases - domain.max_quota_bytes = max_quota_bytes - db.session.add(domain) - print("Updated " + str(domain_config)) - users = new_config.get('users', []) - tracked_users = set() - user_optional_params = ('comment', 'quota_bytes', 'global_admin', - 'enable_imap', 'enable_pop', 'forward_enabled', - 'forward_destination', 'reply_enabled', - 'reply_subject', 'reply_body', 'displayed_name', - 'spam_enabled', 'email', 'spam_threshold') - for user_config in users: - if verbose: - print(str(user_config)) - localpart = user_config['localpart'] - domain_name = user_config['domain'] - password_hash = user_config.get('password_hash', None) - hash_scheme = user_config.get('hash_scheme', None) - domain = models.Domain.query.get(domain_name) - email = '{0}@{1}'.format(localpart, domain_name) - optional_params = {} - for k in user_optional_params: - if k in user_config: - optional_params[k] = user_config[k] - if not domain: - domain = models.Domain(name=domain_name) - db.session.add(domain) - user = models.User.query.get(email) - tracked_users.add(email) - tracked_domains.add(domain_name) - if not user: - user = models.User( - localpart=localpart, - domain=domain, - **optional_params - ) - else: - for k in optional_params: - setattr(user, k, optional_params[k]) - user.set_password(password_hash, hash_scheme=hash_scheme, raw=True) - db.session.add(user) + out = (lambda *args: print('(DRY RUN)', *args)) if dry_run else print - aliases = new_config.get('aliases', []) - tracked_aliases = set() - for alias_config in aliases: - if verbose: - print(str(alias_config)) - localpart = alias_config['localpart'] - domain_name = alias_config['domain'] - if type(alias_config['destination']) is str: - destination = alias_config['destination'].split(',') - else: - destination = alias_config['destination'] - wildcard = alias_config.get('wildcard', False) - domain = models.Domain.query.get(domain_name) - email = '{0}@{1}'.format(localpart, domain_name) - if not domain: - domain = models.Domain(name=domain_name) - db.session.add(domain) - alias = models.Alias.query.get(email) - tracked_aliases.add(email) - tracked_domains.add(domain_name) - if not alias: - alias = models.Alias( - localpart=localpart, - domain=domain, - wildcard=wildcard, - destination=destination, - email=email - ) - else: - alias.destination = destination - alias.wildcard = wildcard - db.session.add(alias) + try: + new_config = yaml.safe_load(sys.stdin) + except (yaml.scanner.ScannerError, yaml.parser.ParserError) as reason: + out(f'[ERROR] Invalid yaml: {reason}') + sys.exit(1) + else: + if type(new_config) is str: + out(f'[ERROR] Invalid yaml: {new_config!r}') + sys.exit(1) + elif new_config is None or not len(new_config): + out('[ERROR] Empty yaml: Please pipe yaml into stdin') + sys.exit(1) - db.session.commit() + error = False + tracked = {} + for section, model in yaml_sections: - managers = new_config.get('managers', []) - # tracked_managers=set() - for manager_config in managers: - if verbose: - print(str(manager_config)) - domain_name = manager_config['domain'] - user_name = manager_config['user'] - domain = models.Domain.query.get(domain_name) - manageruser = models.User.query.get(user_name + '@' + domain_name) - if manageruser not in domain.managers: - domain.managers.append(manageruser) - db.session.add(domain) + items = new_config.get(section) + if items is None: + if delete_objects: + out(f'[ERROR] Invalid yaml: Section "{section}" is missing') + error = True + break + else: + continue - db.session.commit() + del new_config[section] + if type(items) is not list: + out(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}') + error = True + break + elif not items: + continue + + # create items + for data in items: + + if verbose: + out(f'Handling {model.__table__} data: {data!r}') + + try: + changed = model.from_dict(data, delete_objects) + except Exception as reason: + out(f'[ERROR] {reason.args[0]} in data: {data}') + error = True + break + + for item, created in changed: + + if created is True: + # flush newly created item + db.session.add(item) + db.session.flush() + if verbose: + out(f'Added {item!r}: {item.to_dict()}') + else: + out(f'Added {item!r}') + + elif len(created): + # modified instance + if verbose: + for key, old, new in created: + out(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}') + else: + out(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}') + + # track primary key of all items + tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()])) + + if error: + break + + # on error: stop early + if error: + out('An error occured. Not committing changes.') + db.session.rollback() + sys.exit(1) + + # are there sections left in new_config? + if new_config: + out(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}') + error = True + + # test for conflicting domains + domains = set() + for model, items in tracked.items(): + if model in (models.Domain, models.Alternative, models.Relay): + if domains & items: + for domain in domains & items: + out(f'[ERROR] Duplicate domain name used: {domain}') + error = True + domains.update(items) + + # delete items not tracked if delete_objects: - for user in db.session.query(models.User).all(): - if not (user.email in tracked_users): - if verbose: - print("Deleting user: " + str(user.email)) - db.session.delete(user) - for alias in db.session.query(models.Alias).all(): - if not (alias.email in tracked_aliases): - if verbose: - print("Deleting alias: " + str(alias.email)) - db.session.delete(alias) - for domain in db.session.query(models.Domain).all(): - if not (domain.name in tracked_domains): - if verbose: - print("Deleting domain: " + str(domain.name)) - db.session.delete(domain) - db.session.commit() + for model, items in tracked.items(): + for item in model.query.all(): + if not item._dict_pval() in items: + out(f'Deleted {item!r} {item}') + db.session.delete(item) + + # don't commit when running dry + if dry_run: + db.session.rollback() + else: + db.session.commit() + + +@mailu.command() +@click.option('-v', '--verbose', is_flag=True) +@click.option('-s', '--secrets', is_flag=True) +@flask_cli.with_appcontext +def config_dump(verbose=False, secrets=False): + """dump configuration as YAML-formatted data to stdout""" + + config = {} + for section, model in yaml_sections: + dump = [item.to_dict(verbose, secrets) for item in model.query.all()] + if len(dump): + config[section] = dump + + yaml.dump(config, sys.stdout, default_flow_style=False, allow_unicode=True) @mailu.command() diff --git a/core/admin/mailu/models.py b/core/admin/mailu/models.py index 0a447758..fde4d6f1 100644 --- a/core/admin/mailu/models.py +++ b/core/admin/mailu/models.py @@ -5,6 +5,7 @@ from passlib import context, hash from datetime import datetime, date from email.mime import text from flask import current_app as app +from textwrap import wrap import flask_sqlalchemy import sqlalchemy @@ -15,6 +16,8 @@ import glob import smtplib import idna import dns +import json +import itertools db = flask_sqlalchemy.SQLAlchemy() @@ -32,6 +35,7 @@ class IdnaDomain(db.TypeDecorator): def process_result_value(self, value, dialect): return idna.decode(value) + python_type = str class IdnaEmail(db.TypeDecorator): """ Stores a Unicode string in it's IDNA representation (ASCII only) @@ -56,6 +60,7 @@ class IdnaEmail(db.TypeDecorator): idna.decode(domain_name), ) + python_type = str class CommaSeparatedList(db.TypeDecorator): """ Stores a list as a comma-separated string, compatible with Postfix. @@ -74,6 +79,7 @@ class CommaSeparatedList(db.TypeDecorator): def process_result_value(self, value, dialect): return list(filter(bool, value.split(","))) if value else [] + python_type = list class JSONEncoded(db.TypeDecorator): """ Represents an immutable structure as a json-encoded string. @@ -87,6 +93,7 @@ class JSONEncoded(db.TypeDecorator): def process_result_value(self, value, dialect): return json.loads(value) if value else None + python_type = str class Base(db.Model): """ Base class for all models @@ -105,6 +112,219 @@ class Base(db.Model): updated_at = db.Column(db.Date, nullable=True, onupdate=date.today) comment = db.Column(db.String(255), nullable=True) + @classmethod + def _dict_pkey(model): + return model.__mapper__.primary_key[0].name + + def _dict_pval(self): + return getattr(self, self._dict_pkey()) + + def to_dict(self, full=False, include_secrets=False, recursed=False, hide=None): + """ Return a dictionary representation of this model. + """ + + if recursed and not getattr(self, '_dict_recurse', False): + return str(self) + + hide = set(hide or []) | {'created_at', 'updated_at'} + if hasattr(self, '_dict_hide'): + hide |= self._dict_hide + + secret = set() + if not include_secrets and hasattr(self, '_dict_secret'): + secret |= self._dict_secret + + convert = getattr(self, '_dict_output', {}) + + res = {} + + for key in itertools.chain(self.__table__.columns.keys(), getattr(self, '_dict_show', [])): + if key in hide: + continue + if key in self.__table__.columns: + default = self.__table__.columns[key].default + if isinstance(default, sqlalchemy.sql.schema.ColumnDefault): + default = default.arg + else: + default = None + value = getattr(self, key) + if full or ((default or value) and value != default): + if key in secret: + value = '' + elif value is not None and key in convert: + value = convert[key](value) + res[key] = value + + for key in self.__mapper__.relationships.keys(): + if key in hide: + continue + if self.__mapper__.relationships[key].uselist: + items = getattr(self, key) + if self.__mapper__.relationships[key].query_class is not None: + if hasattr(items, 'all'): + items = items.all() + if full or len(items): + if key in secret: + res[key] = '' + else: + res[key] = [item.to_dict(full, include_secrets, True) for item in items] + else: + value = getattr(self, key) + if full or value is not None: + if key in secret: + res[key] = '' + else: + res[key] = value.to_dict(full, include_secrets, True) + + return res + + @classmethod + def from_dict(model, data, delete=False): + + changed = [] + + pkey = model._dict_pkey() + + # handle "primary key" only + if type(data) is not dict: + data = {pkey: data} + + # modify input data + if hasattr(model, '_dict_input'): + try: + model._dict_input(data) + except Exception as reason: + raise ValueError(f'{reason}', model, None, data) + + # check for primary key (if not recursed) + if not getattr(model, '_dict_recurse', False): + if not pkey in data: + raise KeyError(f'primary key {model.__table__}.{pkey} is missing', model, pkey, data) + + # check data keys and values + for key, value in data.items(): + + # check key + if not hasattr(model, key): + raise KeyError(f'unknown key {model.__table__}.{key}', model, key, data) + + # check value type + col = model.__mapper__.columns.get(key) + if col is not None: + if not type(value) is col.type.python_type: + raise TypeError(f'{model.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', model, key, data) + else: + rel = model.__mapper__.relationships.get(key) + if rel is None: + itype = getattr(model, '_dict_types', {}).get(key) + if itype is not None: + if type(value) is not itype: + raise TypeError(f'{model.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', model, key, data) + else: + raise NotImplementedError(f'type not defined for {model.__table__}.{key}') + + # handle relationships + if key in model.__mapper__.relationships: + rel_model = model.__mapper__.relationships[key].argument + if not isinstance(rel_model, sqlalchemy.orm.Mapper): + add = rel_model.from_dict(value, delete) + assert len(add) == 1 + item, updated = add[0] + changed.append((item, updated)) + data[key] = item + + # create or update item? + item = model.query.get(data[pkey]) if pkey in data else None + if item is None: + # create item + + # check for mandatory keys + missing = getattr(model, '_dict_mandatory', set()) - set(data.keys()) + if missing: + raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {model.__table__} missing', model, missing, data) + + changed.append((model(**data), True)) + + else: + # update item + + updated = [] + for key, value in data.items(): + + # skip primary key + if key == pkey: + continue + + if key in model.__mapper__.relationships: + # update relationship + rel_model = model.__mapper__.relationships[key].argument + if isinstance(rel_model, sqlalchemy.orm.Mapper): + rel_model = rel_model.class_ + # add (and create) referenced items + cur = getattr(item, key) + old = sorted(cur, key=lambda i:id(i)) + new = [] + for rel_data in value: + # get or create related item + add = rel_model.from_dict(rel_data, delete) + assert len(add) == 1 + rel_item, rel_updated = add[0] + changed.append((rel_item, rel_updated)) + if rel_item not in cur: + cur.append(rel_item) + new.append(rel_item) + + # delete referenced items missing in yaml + rel_pkey = rel_model._dict_pkey() + new_data = list([i.to_dict(True, True, True, [rel_pkey]) for i in new]) + for rel_item in old: + if rel_item not in new: + # check if item with same data exists to stabilze import without primary key + rel_data = rel_item.to_dict(True, True, True, [rel_pkey]) + try: + same_idx = new_data.index(rel_data) + except ValueError: + same = None + else: + same = new[same_idx] + + if same is None: + # delete items missing in new + if delete: + cur.remove(rel_item) + else: + new.append(rel_item) + else: + # swap found item with same data with newly created item + new.append(rel_item) + new_data.append(rel_data) + new.remove(same) + del new_data[same_idx] + for i, (ch_item, ch_update) in enumerate(changed): + if ch_item is same: + changed[i] = (rel_item, []) + db.session.flush() + db.session.delete(ch_item) + break + + # remember changes + new = sorted(new, key=lambda i:id(i)) + if new != old: + updated.append((key, old, new)) + + else: + # update key + old = getattr(item, key) + if type(old) is list and not delete: + value = old + value + if value != old: + updated.append((key, old, value)) + setattr(item, key, value) + + changed.append((item, updated)) + + return changed + # Many-to-many association table for domain managers managers = db.Table('manager', Base.metadata, @@ -126,6 +346,27 @@ class Domain(Base): """ __tablename__ = "domain" + _dict_hide = {'users', 'managers', 'aliases'} + _dict_show = {'dkim_key'} + _dict_secret = {'dkim_key'} + _dict_types = {'dkim_key': bytes} + _dict_output = {'dkim_key': lambda v: v.decode('utf-8').strip().split('\n')[1:-1]} + @staticmethod + def _dict_input(data): + key = data.get('dkim_key') + if key is not None: + key = data['dkim_key'] + if type(key) is list: + key = ''.join(key) + if type(key) is str: + key = ''.join(key.strip().split()) + if key.startswith('-----BEGIN PRIVATE KEY-----'): + key = key[25:] + if key.endswith('-----END PRIVATE KEY-----'): + key = key[:-23] + key = '\n'.join(wrap(key, 64)) + data['dkim_key'] = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii') + name = db.Column(IdnaDomain, primary_key=True, nullable=False) managers = db.relationship('User', secondary=managers, backref=db.backref('manager_of'), lazy='dynamic') @@ -208,6 +449,8 @@ class Relay(Base): __tablename__ = "relay" + _dict_mandatory = {'smtp'} + name = db.Column(IdnaDomain, primary_key=True, nullable=False) smtp = db.Column(db.String(80), nullable=True) @@ -221,6 +464,16 @@ class Email(object): localpart = db.Column(db.String(80), nullable=False) + @staticmethod + def _dict_input(data): + if 'email' in data: + if 'localpart' in data or 'domain' in data: + raise ValueError('ambigous key email and localpart/domain') + elif type(data['email']) is str: + data['localpart'], data['domain'] = data['email'].rsplit('@', 1) + else: + data['email'] = f"{data['localpart']}@{data['domain']}" + @declarative.declared_attr def domain_name(cls): return db.Column(IdnaDomain, db.ForeignKey(Domain.name), @@ -306,6 +559,28 @@ class User(Base, Email): """ __tablename__ = "user" + _dict_hide = {'domain_name', 'domain', 'localpart', 'quota_bytes_used'} + _dict_mandatory = {'localpart', 'domain', 'password'} + @classmethod + def _dict_input(cls, data): + Email._dict_input(data) + # handle password + if 'password' in data: + if 'password_hash' in data or 'hash_scheme' in data: + raise ValueError('ambigous key password and password_hash/hash_scheme') + # check (hashed) password + password = data['password'] + if password.startswith('{') and '}' in password: + scheme = password[1:password.index('}')] + if scheme not in cls.scheme_dict: + raise ValueError(f'invalid password scheme {scheme!r}') + else: + raise ValueError(f'invalid hashed password {password!r}') + elif 'password_hash' in data and 'hash_scheme' in data: + if data['hash_scheme'] not in cls.scheme_dict: + raise ValueError(f'invalid password scheme {scheme!r}') + data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash'] + domain = db.relationship(Domain, backref=db.backref('users', cascade='all, delete-orphan')) password = db.Column(db.String(255), nullable=False) @@ -431,6 +706,14 @@ class Alias(Base, Email): """ __tablename__ = "alias" + _dict_hide = {'domain_name', 'domain', 'localpart'} + @staticmethod + def _dict_input(data): + # handle comma delimited string for backwards compability + dst = data.get('destination') + if type(dst) is str: + data['destination'] = list([adr.strip() for adr in dst.split(',')]) + domain = db.relationship(Domain, backref=db.backref('aliases', cascade='all, delete-orphan')) wildcard = db.Column(db.Boolean(), nullable=False, default=False) @@ -484,6 +767,10 @@ class Token(Base): """ __tablename__ = "token" + _dict_recurse = True + _dict_hide = {'user', 'user_email'} + _dict_mandatory = {'password'} + id = db.Column(db.Integer(), primary_key=True) user_email = db.Column(db.String(255), db.ForeignKey(User.email), nullable=False) @@ -499,7 +786,7 @@ class Token(Base): self.password = hash.sha256_crypt.using(rounds=1000).hash(password) def __str__(self): - return self.comment + return self.comment or self.ip class Fetch(Base): @@ -508,6 +795,11 @@ class Fetch(Base): """ __tablename__ = "fetch" + _dict_recurse = True + _dict_hide = {'user_email', 'user', 'last_check', 'error'} + _dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'} + _dict_secret = {'password'} + id = db.Column(db.Integer(), primary_key=True) user_email = db.Column(db.String(255), db.ForeignKey(User.email), nullable=False) @@ -516,9 +808,12 @@ class Fetch(Base): protocol = db.Column(db.Enum('imap', 'pop3'), nullable=False) host = db.Column(db.String(255), nullable=False) port = db.Column(db.Integer(), nullable=False) - tls = db.Column(db.Boolean(), nullable=False) + tls = db.Column(db.Boolean(), nullable=False, default=False) username = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - keep = db.Column(db.Boolean(), nullable=False) + keep = db.Column(db.Boolean(), nullable=False, default=False) last_check = db.Column(db.DateTime, nullable=True) error = db.Column(db.String(1023), nullable=True) + + def __str__(self): + return f'{self.protocol}{"s" if self.tls else ""}://{self.username}@{self.host}:{self.port}'