diff --git a/core/admin/mailu/models.py b/core/admin/mailu/models.py index 73d05801..3187f597 100644 --- a/core/admin/mailu/models.py +++ b/core/admin/mailu/models.py @@ -8,6 +8,7 @@ import json from datetime import date from email.mime import text +from itertools import chain import flask_sqlalchemy import sqlalchemy @@ -30,11 +31,12 @@ class IdnaDomain(db.TypeDecorator): """ Stores a Unicode string in it's IDNA representation (ASCII only) """ + # TODO: String(80) is too small? impl = db.String(80) def process_bind_param(self, value, dialect): """ encode unicode domain name to punycode """ - return idna.encode(value).decode('ascii').lower() + return idna.encode(value.lower()).decode('ascii') def process_result_value(self, value, dialect): """ decode punycode domain name to unicode """ @@ -46,26 +48,21 @@ class IdnaEmail(db.TypeDecorator): """ Stores a Unicode string in it's IDNA representation (ASCII only) """ + # TODO: String(255) is too small? impl = db.String(255) def process_bind_param(self, value, dialect): """ encode unicode domain part of email address to punycode """ - try: - localpart, domain_name = value.split('@') - return '{0}@{1}'.format( - localpart, - idna.encode(domain_name).decode('ascii'), - ).lower() - except ValueError: - pass + localpart, domain_name = value.rsplit('@', 1) + if '@' in localpart: + raise ValueError('email local part must not contain "@"') + domain_name = domain_name.lower() + return f'{localpart}@{idna.encode(domain_name).decode("ascii")}' def process_result_value(self, value, dialect): """ decode punycode domain part of email to unicode """ - localpart, domain_name = value.split('@') - return '{0}@{1}'.format( - localpart, - idna.decode(domain_name), - ) + localpart, domain_name = value.rsplit('@', 1) + return f'{localpart}@{idna.decode(domain_name)}' python_type = str @@ -81,7 +78,7 @@ class CommaSeparatedList(db.TypeDecorator): raise TypeError('Must be a list of strings') for item in value: if ',' in item: - raise ValueError('Item must not contain a comma') + raise ValueError('list item must not contain ","') return ','.join(sorted(value)) def process_result_value(self, value, dialect): @@ -123,173 +120,6 @@ class Base(db.Model): updated_at = db.Column(db.Date, nullable=True, onupdate=date.today) comment = db.Column(db.String(255), nullable=True, default='') - # @classmethod - # def from_dict(cls, data, delete=False): - - # changed = [] - - # pkey = cls._dict_pkey() - - # # handle "primary key" only - # if not isinstance(data, dict): - # data = {pkey: data} - - # # modify input data - # if hasattr(cls, '_dict_input'): - # try: - # cls._dict_input(data) - # except Exception as exc: - # raise ValueError(f'{exc}', cls, None, data) from exc - - # # check for primary key (if not recursed) - # if not getattr(cls, '_dict_recurse', False): - # if not pkey in data: - # raise KeyError(f'primary key {cls.__table__}.{pkey} is missing', cls, pkey, data) - - # # check data keys and values - # for key in list(data.keys()): - - # # check key - # if not hasattr(cls, key) and not key in cls.__mapper__.relationships: - # raise KeyError(f'unknown key {cls.__table__}.{key}', cls, key, data) - - # # check value type - # value = data[key] - # col = cls.__mapper__.columns.get(key) - # if col is not None: - # if not ((value is None and col.nullable) or (isinstance(value, col.type.python_type))): - # raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data) - # else: - # rel = cls.__mapper__.relationships.get(key) - # if rel is None: - # itype = getattr(cls, '_dict_types', {}).get(key) - # if itype is not None: - # if itype is False: # ignore value. TODO: emit warning? - # del data[key] - # continue - # elif not isinstance(value, itype): - # raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data) - # else: - # raise NotImplementedError(f'type not defined for {cls.__table__}.{key}') - - # # handle relationships - # if key in cls.__mapper__.relationships: - # rel_model = cls.__mapper__.relationships[key].argument - # if not isinstance(rel_model, sqlalchemy.orm.Mapper): - # add = rel_model.from_dict(value, delete) - # assert len(add) == 1 - # rel_item, updated = add[0] - # changed.append((rel_item, updated)) - # data[key] = rel_item - - # # create item if necessary - # created = False - # item = cls.query.get(data[pkey]) if pkey in data else None - # if item is None: - - # # check for mandatory keys - # missing = getattr(cls, '_dict_mandatory', set()) - set(data.keys()) - # if missing: - # raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {cls.__table__} missing', cls, missing, data) - - # # remove mapped relationships from data - # mapped = {} - # for key in list(data.keys()): - # if key in cls.__mapper__.relationships: - # if isinstance(cls.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper): - # mapped[key] = data[key] - # del data[key] - - # # create new item - # item = cls(**data) - # created = True - - # # and update mapped relationships (below) - # data = mapped - - # # update item - # updated = [] - # for key, value in data.items(): - - # # skip primary key - # if key == pkey: - # continue - - # if key in cls.__mapper__.relationships: - # # update relationship - # rel_model = cls.__mapper__.relationships[key].argument - # if isinstance(rel_model, sqlalchemy.orm.Mapper): - # rel_model = rel_model.class_ - # # add (and create) referenced items - # cur = getattr(item, key) - # old = sorted(cur, key=id) - # new = [] - # for rel_data in value: - # # get or create related item - # add = rel_model.from_dict(rel_data, delete) - # assert len(add) == 1 - # rel_item, rel_updated = add[0] - # changed.append((rel_item, rel_updated)) - # if rel_item not in cur: - # cur.append(rel_item) - # new.append(rel_item) - - # # delete referenced items missing in yaml - # rel_pkey = rel_model._dict_pkey() - # new_data = list([i.to_dict(True, True, None, True, [rel_pkey]) for i in new]) - # for rel_item in old: - # if rel_item not in new: - # # check if item with same data exists to stabilze import without primary key - # rel_data = rel_item.to_dict(True, True, None, True, [rel_pkey]) - # try: - # same_idx = new_data.index(rel_data) - # except ValueError: - # same = None - # else: - # same = new[same_idx] - - # if same is None: - # # delete items missing in new - # if delete: - # cur.remove(rel_item) - # else: - # new.append(rel_item) - # else: - # # swap found item with same data with newly created item - # new.append(rel_item) - # new_data.append(rel_data) - # new.remove(same) - # del new_data[same_idx] - # for i, (ch_item, _) in enumerate(changed): - # if ch_item is same: - # changed[i] = (rel_item, []) - # db.session.flush() - # db.session.delete(ch_item) - # break - - # # remember changes - # new = sorted(new, key=id) - # if new != old: - # updated.append((key, old, new)) - - # else: - # # update key - # old = getattr(item, key) - # if isinstance(old, list): - # # deduplicate list value - # assert isinstance(value, list) - # value = set(value) - # old = set(old) - # if not delete: - # value = old | value - # if value != old: - # updated.append((key, old, value)) - # setattr(item, key, value) - - # changed.append((item, created if created else updated)) - - # return changed - # Many-to-many association table for domain managers managers = db.Table('manager', Base.metadata, @@ -309,9 +139,7 @@ class Config(Base): # TODO: use sqlalchemy.event.listen() on a store method of object? @sqlalchemy.event.listens_for(db.session, 'after_commit') def store_dkim_key(session): - """ Store DKIM key on commit - """ - + """ Store DKIM key on commit """ for obj in session.identity_map.values(): if isinstance(obj, Domain): if obj._dkim_key_changed: @@ -340,21 +168,27 @@ class Domain(Base): _dkim_key_changed = False def _dkim_file(self): + """ return filename for active DKIM key """ return app.config['DKIM_PATH'].format( - domain=self.name, selector=app.config['DKIM_SELECTOR']) + domain=self.name, + selector=app.config['DKIM_SELECTOR'] + ) @property def dns_mx(self): - hostname = app.config['HOSTNAMES'].split(',')[0] + """ return MX record for domain """ + hostname = app.config['HOSTNAMES'].split(',', 1)[0] return f'{self.name}. 600 IN MX 10 {hostname}.' @property def dns_spf(self): - hostname = app.config['HOSTNAMES'].split(',')[0] + """ return SPF record for domain """ + hostname = app.config['HOSTNAMES'].split(',', 1)[0] return f'{self.name}. 600 IN TXT "v=spf1 mx a:{hostname} ~all"' @property def dns_dkim(self): + """ return DKIM record for domain """ if os.path.exists(self._dkim_file()): selector = app.config['DKIM_SELECTOR'] return ( @@ -364,6 +198,7 @@ class Domain(Base): @property def dns_dmarc(self): + """ return DMARC record for domain """ if os.path.exists(self._dkim_file()): domain = app.config['DOMAIN'] rua = app.config['DMARC_RUA'] @@ -374,6 +209,7 @@ class Domain(Base): @property def dkim_key(self): + """ return private DKIM key """ if self._dkim_key is None: file_path = self._dkim_file() if os.path.exists(file_path): @@ -385,6 +221,7 @@ class Domain(Base): @dkim_key.setter def dkim_key(self, value): + """ set private DKIM key """ old_key = self.dkim_key if value is None: value = b'' @@ -393,36 +230,40 @@ class Domain(Base): @property def dkim_publickey(self): + """ return public part of DKIM key """ dkim_key = self.dkim_key if dkim_key: return dkim.strip_key(dkim_key).decode('utf8') def generate_dkim_key(self): + """ generate and activate new DKIM key """ self.dkim_key = dkim.gen_key() def has_email(self, localpart): - for email in self.users + self.aliases: + """ checks if localpart is configured for domain """ + for email in chain(self.users, self.aliases): if email.localpart == localpart: return True return False def check_mx(self): + """ checks if MX record for domain points to mailu host """ try: - hostnames = app.config['HOSTNAMES'].split(',') + hostnames = set(app.config['HOSTNAMES'].split(',')) return any( - str(rset).split()[-1][:-1] in hostnames + rset.exchange.to_text().rstrip('.') in hostnames for rset in dns.resolver.query(self.name, 'MX') ) - except Exception: + except dns.exception.DNSException: return False def __str__(self): return str(self.name) def __eq__(self, other): - try: - return self.name == other.name - except AttributeError: + if isinstance(other, self.__class__): + return str(self.name) == str(other.name) + else: return NotImplemented def __hash__(self): @@ -432,7 +273,7 @@ class Domain(Base): class Alternative(Base): """ Alternative name for a served domain. - The name "domain alias" was avoided to prevent some confusion. + The name "domain alias" was avoided to prevent some confusion. """ __tablename__ = 'alternative' @@ -454,6 +295,7 @@ class Relay(Base): __tablename__ = 'relay' name = db.Column(IdnaDomain, primary_key=True, nullable=False) + # TODO: String(80) is too small? smtp = db.Column(db.String(80), nullable=True) def __str__(self): @@ -464,10 +306,14 @@ class Email(object): """ Abstraction for an email address (localpart and domain). """ + # TODO: validate max. total length of address (<=254) + + # TODO: String(80) is too large (>64)? localpart = db.Column(db.String(80), nullable=False) @declarative.declared_attr def domain_name(self): + """ the domain part of the email address """ return db.Column(IdnaDomain, db.ForeignKey(Domain.name), nullable=False, default=IdnaDomain) @@ -476,26 +322,18 @@ class Email(object): # especially when the mail server is reading the database. @declarative.declared_attr def email(self): - updater = lambda context: '{0}@{1}'.format( - context.current_parameters['localpart'], - context.current_parameters['domain_name'], - ) + """ the complete email address (localpart@domain) """ + updater = lambda ctx: '{localpart}@{domain_name}'.format(**ctx.current_parameters) return db.Column(IdnaEmail, primary_key=True, nullable=False, - default=updater) + default=updater + ) def sendmail(self, subject, body): - """ Send an email to the address. - """ - from_address = '{0}@{1}'.format( - app.config['POSTMASTER'], - idna.encode(app.config['DOMAIN']).decode('ascii'), - ) + """ send an email to the address """ + from_address = f'{app.config["POSTMASTER"]}@{idna.encode(app.config["DOMAIN"]).decode("ascii")}' with smtplib.SMTP(app.config['HOST_AUTHSMTP'], port=10025) as smtp: - to_address = '{0}@{1}'.format( - self.localpart, - idna.encode(self.domain_name).decode('ascii'), - ) + to_address = f'{self.localpart}@{idna.encode(self.domain_name).decode("ascii")}' msg = text.MIMEText(body) msg['Subject'] = subject msg['From'] = from_address @@ -504,7 +342,8 @@ class Email(object): @classmethod def resolve_domain(cls, email): - localpart, domain_name = email.split('@', 1) if '@' in email else (None, email) + """ resolves domain alternative to real domain """ + localpart, domain_name = email.rsplit('@', 1) if '@' in email else (None, email) alternative = Alternative.query.get(domain_name) if alternative: domain_name = alternative.domain_name @@ -512,17 +351,19 @@ class Email(object): @classmethod def resolve_destination(cls, localpart, domain_name, ignore_forward_keep=False): + """ return destination for email address localpart@domain_name """ + localpart_stripped = None stripped_alias = None if os.environ.get('RECIPIENT_DELIMITER') in localpart: localpart_stripped = localpart.rsplit(os.environ.get('RECIPIENT_DELIMITER'), 1)[0] - user = User.query.get('{}@{}'.format(localpart, domain_name)) + user = User.query.get(f'{localpart}@{domain_name}') if not user and localpart_stripped: - user = User.query.get('{}@{}'.format(localpart_stripped, domain_name)) + user = User.query.get(f'{localpart_stripped}@{domain_name}') if user: - email = '{}@{}'.format(localpart, domain_name) + email = f'{localpart}@{domain_name}' if user.forward_enabled: destination = user.forward_destination @@ -537,11 +378,15 @@ class Email(object): if pure_alias and not pure_alias.wildcard: return pure_alias.destination - elif stripped_alias: + + if stripped_alias: return stripped_alias.destination - elif pure_alias: + + if pure_alias: return pure_alias.destination + return None + def __str__(self): return str(self.email) @@ -586,11 +431,15 @@ class User(Base, Email): is_active = True is_anonymous = False + # TODO: remove unused user.get_id() def get_id(self): + """ return users email address """ return self.email + # TODO: remove unused user.destination @property def destination(self): + """ returns comma separated string of destinations """ if self.forward_enabled: result = list(self.forward_destination) if self.forward_keep: @@ -601,6 +450,7 @@ class User(Base, Email): @property def reply_active(self): + """ returns status of autoreply function """ now = date.today() return ( self.reply_enabled and @@ -608,49 +458,56 @@ class User(Base, Email): self.reply_enddate > now ) - scheme_dict = {'PBKDF2': 'pbkdf2_sha512', - 'BLF-CRYPT': 'bcrypt', - 'SHA512-CRYPT': 'sha512_crypt', - 'SHA256-CRYPT': 'sha256_crypt', - 'MD5-CRYPT': 'md5_crypt', - 'CRYPT': 'des_crypt'} + scheme_dict = { + 'PBKDF2': 'pbkdf2_sha512', + 'BLF-CRYPT': 'bcrypt', + 'SHA512-CRYPT': 'sha512_crypt', + 'SHA256-CRYPT': 'sha256_crypt', + 'MD5-CRYPT': 'md5_crypt', + 'CRYPT': 'des_crypt', + } - def get_password_context(self): + def _get_password_context(self): return passlib.context.CryptContext( schemes=self.scheme_dict.values(), default=self.scheme_dict[app.config['PASSWORD_SCHEME']], ) - def check_password(self, password): - context = self.get_password_context() - reference = re.match('({[^}]+})?(.*)', self.password).group(2) - result = context.verify(password, reference) - if result and context.identify(reference) != context.default_scheme(): - self.set_password(password) + def check_password(self, plain): + """ Check password against stored hash + Update hash when default scheme has changed + """ + context = self._get_password_context() + hashed = re.match('^({[^}]+})?(.*)$', self.password).group(2) + result = context.verify(plain, hashed) + if result and context.identify(hashed) != context.default_scheme(): + self.set_password(plain) db.session.add(self) db.session.commit() return result - def set_password(self, password, hash_scheme=None, raw=False): - """Set password for user with specified encryption scheme - @password: plain text password to encrypt (if raw == True the hash itself) + # TODO: remove kwarg hash_scheme - there is no point in setting a scheme, + # when the next check updates the password to the default scheme. + def set_password(self, new, hash_scheme=None, raw=False): + """ Set password for user with specified encryption scheme + @new: plain text password to encrypt (or, if raw is True: the hash itself) """ + # for the list of hash schemes see https://wiki2.dovecot.org/Authentication/PasswordSchemes if hash_scheme is None: hash_scheme = app.config['PASSWORD_SCHEME'] - # for the list of hash schemes see https://wiki2.dovecot.org/Authentication/PasswordSchemes - if raw: - self.password = '{'+hash_scheme+'}' + password - else: - self.password = '{'+hash_scheme+'}' + \ - self.get_password_context().encrypt(password, self.scheme_dict[hash_scheme]) + if not raw: + new = self._get_password_context().encrypt(new, self.scheme_dict[hash_scheme]) + self.password = f'{{{hash_scheme}}}{new}' def get_managed_domains(self): + """ return list of domains this user can manage """ if self.global_admin: return Domain.query.all() else: return self.manager_of def get_managed_emails(self, include_aliases=True): + """ returns list of email addresses this user can manage """ emails = [] for domain in self.get_managed_domains(): emails.extend(domain.users) @@ -659,16 +516,18 @@ class User(Base, Email): return emails def send_welcome(self): + """ send welcome email to user """ if app.config['WELCOME']: - self.sendmail(app.config['WELCOME_SUBJECT'], - app.config['WELCOME_BODY']) + self.sendmail(app.config['WELCOME_SUBJECT'], app.config['WELCOME_BODY']) @classmethod def get(cls, email): + """ find user object for email address """ return cls.query.get(email) @classmethod def login(cls, email, password): + """ login user when enabled and password is valid """ user = cls.query.get(email) return user if (user and user.enabled and user.check_password(password)) else None @@ -686,6 +545,8 @@ class Alias(Base, Email): @classmethod def resolve(cls, localpart, domain_name): + """ find aliases matching email address localpart@domain_name """ + alias_preserve_case = cls.query.filter( sqlalchemy.and_(cls.domain_name == domain_name, sqlalchemy.or_( @@ -709,24 +570,27 @@ class Alias(Base, Email): sqlalchemy.func.lower(cls.localpart) == localpart_lower ), sqlalchemy.and_( cls.wildcard is True, - sqlalchemy.bindparam('l', localpart_lower).like(sqlalchemy.func.lower(cls.localpart)) + sqlalchemy.bindparam('l', localpart_lower).like( + sqlalchemy.func.lower(cls.localpart)) ) ) ) - ).order_by(cls.wildcard, sqlalchemy.func.char_length(sqlalchemy.func.lower(cls.localpart)).desc()).first() + ).order_by(cls.wildcard, sqlalchemy.func.char_length( + sqlalchemy.func.lower(cls.localpart)).desc()).first() if alias_preserve_case and alias_lower_case: - if alias_preserve_case.wildcard: - return alias_lower_case - else: - return alias_preserve_case - elif alias_preserve_case and not alias_lower_case: - return alias_preserve_case - elif alias_lower_case and not alias_preserve_case: - return alias_lower_case - else: - return None + return alias_lower_case if alias_preserve_case.wildcard else alias_preserve_case + if alias_preserve_case and not alias_lower_case: + return alias_preserve_case + + if alias_lower_case and not alias_preserve_case: + return alias_lower_case + + return None + +# TODO: where are Tokens used / validated? +# TODO: what about API tokens? class Token(Base): """ A token is an application password for a given user. """ @@ -739,16 +603,20 @@ class Token(Base): user = db.relationship(User, backref=db.backref('tokens', cascade='all, delete-orphan')) password = db.Column(db.String(255), nullable=False) + # TODO: String(80) is too large? ip = db.Column(db.String(255)) def check_password(self, password): + """ verifies password against stored hash """ return passlib.hash.sha256_crypt.verify(password, self.password) + # TODO: use crypt context and default scheme from config? def set_password(self, password): + """ sets password using sha256_crypt(rounds=1000) """ self.password = passlib.hash.sha256_crypt.using(rounds=1000).hash(password) def __str__(self): - return self.comment or self.ip + return str(self.comment or self.ip) class Fetch(Base):