implemented config_update and config_dump

enhanced data model with to_dict and from_dict methods
added config_dump function to manage command
config_update now uses new data model methods
master
Alexander Graf 4 years ago
parent c26ddd3c68
commit 5c0efe82cf

@ -8,6 +8,8 @@ import os
import socket import socket
import uuid import uuid
import click import click
import yaml
import sys
db = models.db db = models.db
@ -169,148 +171,149 @@ def user_import(localpart, domain_name, password_hash, hash_scheme = None):
db.session.commit() db.session.commit()
yaml_sections = [
('domains', models.Domain),
('relays', models.Relay),
('users', models.User),
('aliases', models.Alias),
# ('config', models.Config),
]
@mailu.command() @mailu.command()
@click.option('-v', '--verbose') @click.option('-v', '--verbose', is_flag=True)
@click.option('-d', '--delete-objects') @click.option('-d', '--delete-objects', is_flag=True)
@click.option('-n', '--dry-run', is_flag=True)
@flask_cli.with_appcontext @flask_cli.with_appcontext
def config_update(verbose=False, delete_objects=False): def config_update(verbose=False, delete_objects=False, dry_run=False, file=None):
"""sync configuration with data from YAML-formatted stdin""" """sync configuration with data from YAML-formatted stdin"""
import yaml
import sys out = (lambda *args: print('(DRY RUN)', *args)) if dry_run else print
try:
new_config = yaml.safe_load(sys.stdin) new_config = yaml.safe_load(sys.stdin)
# print new_config except (yaml.scanner.ScannerError, yaml.parser.ParserError) as reason:
domains = new_config.get('domains', []) out(f'[ERROR] Invalid yaml: {reason}')
tracked_domains = set() sys.exit(1)
for domain_config in domains:
if verbose:
print(str(domain_config))
domain_name = domain_config['name']
max_users = domain_config.get('max_users', -1)
max_aliases = domain_config.get('max_aliases', -1)
max_quota_bytes = domain_config.get('max_quota_bytes', 0)
tracked_domains.add(domain_name)
domain = models.Domain.query.get(domain_name)
if not domain:
domain = models.Domain(name=domain_name,
max_users=max_users,
max_aliases=max_aliases,
max_quota_bytes=max_quota_bytes)
db.session.add(domain)
print("Added " + str(domain_config))
else: else:
domain.max_users = max_users if type(new_config) is str:
domain.max_aliases = max_aliases out(f'[ERROR] Invalid yaml: {new_config!r}')
domain.max_quota_bytes = max_quota_bytes sys.exit(1)
db.session.add(domain) elif new_config is None or not len(new_config):
print("Updated " + str(domain_config)) out('[ERROR] Empty yaml: Please pipe yaml into stdin')
sys.exit(1)
users = new_config.get('users', []) error = False
tracked_users = set() tracked = {}
user_optional_params = ('comment', 'quota_bytes', 'global_admin', for section, model in yaml_sections:
'enable_imap', 'enable_pop', 'forward_enabled',
'forward_destination', 'reply_enabled',
'reply_subject', 'reply_body', 'displayed_name',
'spam_enabled', 'email', 'spam_threshold')
for user_config in users:
if verbose:
print(str(user_config))
localpart = user_config['localpart']
domain_name = user_config['domain']
password_hash = user_config.get('password_hash', None)
hash_scheme = user_config.get('hash_scheme', None)
domain = models.Domain.query.get(domain_name)
email = '{0}@{1}'.format(localpart, domain_name)
optional_params = {}
for k in user_optional_params:
if k in user_config:
optional_params[k] = user_config[k]
if not domain:
domain = models.Domain(name=domain_name)
db.session.add(domain)
user = models.User.query.get(email)
tracked_users.add(email)
tracked_domains.add(domain_name)
if not user:
user = models.User(
localpart=localpart,
domain=domain,
**optional_params
)
else:
for k in optional_params:
setattr(user, k, optional_params[k])
user.set_password(password_hash, hash_scheme=hash_scheme, raw=True)
db.session.add(user)
aliases = new_config.get('aliases', [])
tracked_aliases = set()
for alias_config in aliases:
if verbose:
print(str(alias_config))
localpart = alias_config['localpart']
domain_name = alias_config['domain']
if type(alias_config['destination']) is str:
destination = alias_config['destination'].split(',')
else:
destination = alias_config['destination']
wildcard = alias_config.get('wildcard', False)
domain = models.Domain.query.get(domain_name)
email = '{0}@{1}'.format(localpart, domain_name)
if not domain:
domain = models.Domain(name=domain_name)
db.session.add(domain)
alias = models.Alias.query.get(email)
tracked_aliases.add(email)
tracked_domains.add(domain_name)
if not alias:
alias = models.Alias(
localpart=localpart,
domain=domain,
wildcard=wildcard,
destination=destination,
email=email
)
else:
alias.destination = destination
alias.wildcard = wildcard
db.session.add(alias)
db.session.commit()
managers = new_config.get('managers', [])
# tracked_managers=set()
for manager_config in managers:
if verbose:
print(str(manager_config))
domain_name = manager_config['domain']
user_name = manager_config['user']
domain = models.Domain.query.get(domain_name)
manageruser = models.User.query.get(user_name + '@' + domain_name)
if manageruser not in domain.managers:
domain.managers.append(manageruser)
db.session.add(domain)
db.session.commit()
items = new_config.get(section)
if items is None:
if delete_objects: if delete_objects:
for user in db.session.query(models.User).all(): out(f'[ERROR] Invalid yaml: Section "{section}" is missing')
if not (user.email in tracked_users): error = True
break
else:
continue
del new_config[section]
if type(items) is not list:
out(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}')
error = True
break
elif not items:
continue
# create items
for data in items:
if verbose: if verbose:
print("Deleting user: " + str(user.email)) out(f'Handling {model.__table__} data: {data!r}')
db.session.delete(user)
for alias in db.session.query(models.Alias).all(): try:
if not (alias.email in tracked_aliases): changed = model.from_dict(data, delete_objects)
except Exception as reason:
out(f'[ERROR] {reason.args[0]} in data: {data}')
error = True
break
for item, created in changed:
if created is True:
# flush newly created item
db.session.add(item)
db.session.flush()
if verbose: if verbose:
print("Deleting alias: " + str(alias.email)) out(f'Added {item!r}: {item.to_dict()}')
db.session.delete(alias) else:
for domain in db.session.query(models.Domain).all(): out(f'Added {item!r}')
if not (domain.name in tracked_domains):
elif len(created):
# modified instance
if verbose: if verbose:
print("Deleting domain: " + str(domain.name)) for key, old, new in created:
db.session.delete(domain) out(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}')
else:
out(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}')
# track primary key of all items
tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()]))
if error:
break
# on error: stop early
if error:
out('An error occured. Not committing changes.')
db.session.rollback()
sys.exit(1)
# are there sections left in new_config?
if new_config:
out(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}')
error = True
# test for conflicting domains
domains = set()
for model, items in tracked.items():
if model in (models.Domain, models.Alternative, models.Relay):
if domains & items:
for domain in domains & items:
out(f'[ERROR] Duplicate domain name used: {domain}')
error = True
domains.update(items)
# delete items not tracked
if delete_objects:
for model, items in tracked.items():
for item in model.query.all():
if not item._dict_pval() in items:
out(f'Deleted {item!r} {item}')
db.session.delete(item)
# don't commit when running dry
if dry_run:
db.session.rollback()
else:
db.session.commit() db.session.commit()
@mailu.command()
@click.option('-v', '--verbose', is_flag=True)
@click.option('-s', '--secrets', is_flag=True)
@flask_cli.with_appcontext
def config_dump(verbose=False, secrets=False):
"""dump configuration as YAML-formatted data to stdout"""
config = {}
for section, model in yaml_sections:
dump = [item.to_dict(verbose, secrets) for item in model.query.all()]
if len(dump):
config[section] = dump
yaml.dump(config, sys.stdout, default_flow_style=False, allow_unicode=True)
@mailu.command() @mailu.command()
@click.argument('email') @click.argument('email')
@flask_cli.with_appcontext @flask_cli.with_appcontext

@ -5,6 +5,7 @@ from passlib import context, hash
from datetime import datetime, date from datetime import datetime, date
from email.mime import text from email.mime import text
from flask import current_app as app from flask import current_app as app
from textwrap import wrap
import flask_sqlalchemy import flask_sqlalchemy
import sqlalchemy import sqlalchemy
@ -15,6 +16,8 @@ import glob
import smtplib import smtplib
import idna import idna
import dns import dns
import json
import itertools
db = flask_sqlalchemy.SQLAlchemy() db = flask_sqlalchemy.SQLAlchemy()
@ -32,6 +35,7 @@ class IdnaDomain(db.TypeDecorator):
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
return idna.decode(value) return idna.decode(value)
python_type = str
class IdnaEmail(db.TypeDecorator): class IdnaEmail(db.TypeDecorator):
""" Stores a Unicode string in it's IDNA representation (ASCII only) """ Stores a Unicode string in it's IDNA representation (ASCII only)
@ -56,6 +60,7 @@ class IdnaEmail(db.TypeDecorator):
idna.decode(domain_name), idna.decode(domain_name),
) )
python_type = str
class CommaSeparatedList(db.TypeDecorator): class CommaSeparatedList(db.TypeDecorator):
""" Stores a list as a comma-separated string, compatible with Postfix. """ Stores a list as a comma-separated string, compatible with Postfix.
@ -74,6 +79,7 @@ class CommaSeparatedList(db.TypeDecorator):
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
return list(filter(bool, value.split(","))) if value else [] return list(filter(bool, value.split(","))) if value else []
python_type = list
class JSONEncoded(db.TypeDecorator): class JSONEncoded(db.TypeDecorator):
""" Represents an immutable structure as a json-encoded string. """ Represents an immutable structure as a json-encoded string.
@ -87,6 +93,7 @@ class JSONEncoded(db.TypeDecorator):
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
return json.loads(value) if value else None return json.loads(value) if value else None
python_type = str
class Base(db.Model): class Base(db.Model):
""" Base class for all models """ Base class for all models
@ -105,6 +112,219 @@ class Base(db.Model):
updated_at = db.Column(db.Date, nullable=True, onupdate=date.today) updated_at = db.Column(db.Date, nullable=True, onupdate=date.today)
comment = db.Column(db.String(255), nullable=True) comment = db.Column(db.String(255), nullable=True)
@classmethod
def _dict_pkey(model):
return model.__mapper__.primary_key[0].name
def _dict_pval(self):
return getattr(self, self._dict_pkey())
def to_dict(self, full=False, include_secrets=False, recursed=False, hide=None):
""" Return a dictionary representation of this model.
"""
if recursed and not getattr(self, '_dict_recurse', False):
return str(self)
hide = set(hide or []) | {'created_at', 'updated_at'}
if hasattr(self, '_dict_hide'):
hide |= self._dict_hide
secret = set()
if not include_secrets and hasattr(self, '_dict_secret'):
secret |= self._dict_secret
convert = getattr(self, '_dict_output', {})
res = {}
for key in itertools.chain(self.__table__.columns.keys(), getattr(self, '_dict_show', [])):
if key in hide:
continue
if key in self.__table__.columns:
default = self.__table__.columns[key].default
if isinstance(default, sqlalchemy.sql.schema.ColumnDefault):
default = default.arg
else:
default = None
value = getattr(self, key)
if full or ((default or value) and value != default):
if key in secret:
value = '<hidden>'
elif value is not None and key in convert:
value = convert[key](value)
res[key] = value
for key in self.__mapper__.relationships.keys():
if key in hide:
continue
if self.__mapper__.relationships[key].uselist:
items = getattr(self, key)
if self.__mapper__.relationships[key].query_class is not None:
if hasattr(items, 'all'):
items = items.all()
if full or len(items):
if key in secret:
res[key] = '<hidden>'
else:
res[key] = [item.to_dict(full, include_secrets, True) for item in items]
else:
value = getattr(self, key)
if full or value is not None:
if key in secret:
res[key] = '<hidden>'
else:
res[key] = value.to_dict(full, include_secrets, True)
return res
@classmethod
def from_dict(model, data, delete=False):
changed = []
pkey = model._dict_pkey()
# handle "primary key" only
if type(data) is not dict:
data = {pkey: data}
# modify input data
if hasattr(model, '_dict_input'):
try:
model._dict_input(data)
except Exception as reason:
raise ValueError(f'{reason}', model, None, data)
# check for primary key (if not recursed)
if not getattr(model, '_dict_recurse', False):
if not pkey in data:
raise KeyError(f'primary key {model.__table__}.{pkey} is missing', model, pkey, data)
# check data keys and values
for key, value in data.items():
# check key
if not hasattr(model, key):
raise KeyError(f'unknown key {model.__table__}.{key}', model, key, data)
# check value type
col = model.__mapper__.columns.get(key)
if col is not None:
if not type(value) is col.type.python_type:
raise TypeError(f'{model.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', model, key, data)
else:
rel = model.__mapper__.relationships.get(key)
if rel is None:
itype = getattr(model, '_dict_types', {}).get(key)
if itype is not None:
if type(value) is not itype:
raise TypeError(f'{model.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', model, key, data)
else:
raise NotImplementedError(f'type not defined for {model.__table__}.{key}')
# handle relationships
if key in model.__mapper__.relationships:
rel_model = model.__mapper__.relationships[key].argument
if not isinstance(rel_model, sqlalchemy.orm.Mapper):
add = rel_model.from_dict(value, delete)
assert len(add) == 1
item, updated = add[0]
changed.append((item, updated))
data[key] = item
# create or update item?
item = model.query.get(data[pkey]) if pkey in data else None
if item is None:
# create item
# check for mandatory keys
missing = getattr(model, '_dict_mandatory', set()) - set(data.keys())
if missing:
raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {model.__table__} missing', model, missing, data)
changed.append((model(**data), True))
else:
# update item
updated = []
for key, value in data.items():
# skip primary key
if key == pkey:
continue
if key in model.__mapper__.relationships:
# update relationship
rel_model = model.__mapper__.relationships[key].argument
if isinstance(rel_model, sqlalchemy.orm.Mapper):
rel_model = rel_model.class_
# add (and create) referenced items
cur = getattr(item, key)
old = sorted(cur, key=lambda i:id(i))
new = []
for rel_data in value:
# get or create related item
add = rel_model.from_dict(rel_data, delete)
assert len(add) == 1
rel_item, rel_updated = add[0]
changed.append((rel_item, rel_updated))
if rel_item not in cur:
cur.append(rel_item)
new.append(rel_item)
# delete referenced items missing in yaml
rel_pkey = rel_model._dict_pkey()
new_data = list([i.to_dict(True, True, True, [rel_pkey]) for i in new])
for rel_item in old:
if rel_item not in new:
# check if item with same data exists to stabilze import without primary key
rel_data = rel_item.to_dict(True, True, True, [rel_pkey])
try:
same_idx = new_data.index(rel_data)
except ValueError:
same = None
else:
same = new[same_idx]
if same is None:
# delete items missing in new
if delete:
cur.remove(rel_item)
else:
new.append(rel_item)
else:
# swap found item with same data with newly created item
new.append(rel_item)
new_data.append(rel_data)
new.remove(same)
del new_data[same_idx]
for i, (ch_item, ch_update) in enumerate(changed):
if ch_item is same:
changed[i] = (rel_item, [])
db.session.flush()
db.session.delete(ch_item)
break
# remember changes
new = sorted(new, key=lambda i:id(i))
if new != old:
updated.append((key, old, new))
else:
# update key
old = getattr(item, key)
if type(old) is list and not delete:
value = old + value
if value != old:
updated.append((key, old, value))
setattr(item, key, value)
changed.append((item, updated))
return changed
# Many-to-many association table for domain managers # Many-to-many association table for domain managers
managers = db.Table('manager', Base.metadata, managers = db.Table('manager', Base.metadata,
@ -126,6 +346,27 @@ class Domain(Base):
""" """
__tablename__ = "domain" __tablename__ = "domain"
_dict_hide = {'users', 'managers', 'aliases'}
_dict_show = {'dkim_key'}
_dict_secret = {'dkim_key'}
_dict_types = {'dkim_key': bytes}
_dict_output = {'dkim_key': lambda v: v.decode('utf-8').strip().split('\n')[1:-1]}
@staticmethod
def _dict_input(data):
key = data.get('dkim_key')
if key is not None:
key = data['dkim_key']
if type(key) is list:
key = ''.join(key)
if type(key) is str:
key = ''.join(key.strip().split())
if key.startswith('-----BEGIN PRIVATE KEY-----'):
key = key[25:]
if key.endswith('-----END PRIVATE KEY-----'):
key = key[:-23]
key = '\n'.join(wrap(key, 64))
data['dkim_key'] = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii')
name = db.Column(IdnaDomain, primary_key=True, nullable=False) name = db.Column(IdnaDomain, primary_key=True, nullable=False)
managers = db.relationship('User', secondary=managers, managers = db.relationship('User', secondary=managers,
backref=db.backref('manager_of'), lazy='dynamic') backref=db.backref('manager_of'), lazy='dynamic')
@ -208,6 +449,8 @@ class Relay(Base):
__tablename__ = "relay" __tablename__ = "relay"
_dict_mandatory = {'smtp'}
name = db.Column(IdnaDomain, primary_key=True, nullable=False) name = db.Column(IdnaDomain, primary_key=True, nullable=False)
smtp = db.Column(db.String(80), nullable=True) smtp = db.Column(db.String(80), nullable=True)
@ -221,6 +464,16 @@ class Email(object):
localpart = db.Column(db.String(80), nullable=False) localpart = db.Column(db.String(80), nullable=False)
@staticmethod
def _dict_input(data):
if 'email' in data:
if 'localpart' in data or 'domain' in data:
raise ValueError('ambigous key email and localpart/domain')
elif type(data['email']) is str:
data['localpart'], data['domain'] = data['email'].rsplit('@', 1)
else:
data['email'] = f"{data['localpart']}@{data['domain']}"
@declarative.declared_attr @declarative.declared_attr
def domain_name(cls): def domain_name(cls):
return db.Column(IdnaDomain, db.ForeignKey(Domain.name), return db.Column(IdnaDomain, db.ForeignKey(Domain.name),
@ -306,6 +559,28 @@ class User(Base, Email):
""" """
__tablename__ = "user" __tablename__ = "user"
_dict_hide = {'domain_name', 'domain', 'localpart', 'quota_bytes_used'}
_dict_mandatory = {'localpart', 'domain', 'password'}
@classmethod
def _dict_input(cls, data):
Email._dict_input(data)
# handle password
if 'password' in data:
if 'password_hash' in data or 'hash_scheme' in data:
raise ValueError('ambigous key password and password_hash/hash_scheme')
# check (hashed) password
password = data['password']
if password.startswith('{') and '}' in password:
scheme = password[1:password.index('}')]
if scheme not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
else:
raise ValueError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
domain = db.relationship(Domain, domain = db.relationship(Domain,
backref=db.backref('users', cascade='all, delete-orphan')) backref=db.backref('users', cascade='all, delete-orphan'))
password = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False)
@ -431,6 +706,14 @@ class Alias(Base, Email):
""" """
__tablename__ = "alias" __tablename__ = "alias"
_dict_hide = {'domain_name', 'domain', 'localpart'}
@staticmethod
def _dict_input(data):
# handle comma delimited string for backwards compability
dst = data.get('destination')
if type(dst) is str:
data['destination'] = list([adr.strip() for adr in dst.split(',')])
domain = db.relationship(Domain, domain = db.relationship(Domain,
backref=db.backref('aliases', cascade='all, delete-orphan')) backref=db.backref('aliases', cascade='all, delete-orphan'))
wildcard = db.Column(db.Boolean(), nullable=False, default=False) wildcard = db.Column(db.Boolean(), nullable=False, default=False)
@ -484,6 +767,10 @@ class Token(Base):
""" """
__tablename__ = "token" __tablename__ = "token"
_dict_recurse = True
_dict_hide = {'user', 'user_email'}
_dict_mandatory = {'password'}
id = db.Column(db.Integer(), primary_key=True) id = db.Column(db.Integer(), primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email), user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False) nullable=False)
@ -499,7 +786,7 @@ class Token(Base):
self.password = hash.sha256_crypt.using(rounds=1000).hash(password) self.password = hash.sha256_crypt.using(rounds=1000).hash(password)
def __str__(self): def __str__(self):
return self.comment return self.comment or self.ip
class Fetch(Base): class Fetch(Base):
@ -508,6 +795,11 @@ class Fetch(Base):
""" """
__tablename__ = "fetch" __tablename__ = "fetch"
_dict_recurse = True
_dict_hide = {'user_email', 'user', 'last_check', 'error'}
_dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'}
_dict_secret = {'password'}
id = db.Column(db.Integer(), primary_key=True) id = db.Column(db.Integer(), primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email), user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False) nullable=False)
@ -516,9 +808,12 @@ class Fetch(Base):
protocol = db.Column(db.Enum('imap', 'pop3'), nullable=False) protocol = db.Column(db.Enum('imap', 'pop3'), nullable=False)
host = db.Column(db.String(255), nullable=False) host = db.Column(db.String(255), nullable=False)
port = db.Column(db.Integer(), nullable=False) port = db.Column(db.Integer(), nullable=False)
tls = db.Column(db.Boolean(), nullable=False) tls = db.Column(db.Boolean(), nullable=False, default=False)
username = db.Column(db.String(255), nullable=False) username = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False)
keep = db.Column(db.Boolean(), nullable=False) keep = db.Column(db.Boolean(), nullable=False, default=False)
last_check = db.Column(db.DateTime, nullable=True) last_check = db.Column(db.DateTime, nullable=True)
error = db.Column(db.String(1023), nullable=True) error = db.Column(db.String(1023), nullable=True)
def __str__(self):
return f'{self.protocol}{"s" if self.tls else ""}://{self.username}@{self.host}:{self.port}'

Loading…
Cancel
Save