diff --git a/core/admin/mailu/models.py b/core/admin/mailu/models.py index fde4d6f1..cbafc6a4 100644 --- a/core/admin/mailu/models.py +++ b/core/admin/mailu/models.py @@ -69,12 +69,12 @@ class CommaSeparatedList(db.TypeDecorator): impl = db.String def process_bind_param(self, value, dialect): - if type(value) is not list: + if not isinstance(value, (list, set)): raise TypeError("Should be a list") for item in value: if "," in item: raise ValueError("Item must not contain a comma") - return ",".join(value) + return ",".join(sorted(value)) def process_result_value(self, value, dialect): return list(filter(bool, value.split(","))) if value else [] @@ -205,13 +205,13 @@ class Base(db.Model): for key, value in data.items(): # check key - if not hasattr(model, key): + if not hasattr(model, key) and not key in model.__mapper__.relationships: raise KeyError(f'unknown key {model.__table__}.{key}', model, key, data) # check value type col = model.__mapper__.columns.get(key) if col is not None: - if not type(value) is col.type.python_type: + if not ((value is None and col.nullable) or (type(value) is col.type.python_type)): raise TypeError(f'{model.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', model, key, data) else: rel = model.__mapper__.relationships.get(key) @@ -229,99 +229,115 @@ class Base(db.Model): if not isinstance(rel_model, sqlalchemy.orm.Mapper): add = rel_model.from_dict(value, delete) assert len(add) == 1 - item, updated = add[0] - changed.append((item, updated)) - data[key] = item + rel_item, updated = add[0] + changed.append((rel_item, updated)) + data[key] = rel_item - # create or update item? + # create item if necessary + created = False item = model.query.get(data[pkey]) if pkey in data else None if item is None: - # create item # check for mandatory keys missing = getattr(model, '_dict_mandatory', set()) - set(data.keys()) if missing: raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {model.__table__} missing', model, missing, data) - changed.append((model(**data), True)) - - else: - # update item - - updated = [] - for key, value in data.items(): - - # skip primary key - if key == pkey: - continue - + # remove mapped relationships from data + mapped = {} + for key in list(data.keys()): if key in model.__mapper__.relationships: - # update relationship - rel_model = model.__mapper__.relationships[key].argument - if isinstance(rel_model, sqlalchemy.orm.Mapper): - rel_model = rel_model.class_ - # add (and create) referenced items - cur = getattr(item, key) - old = sorted(cur, key=lambda i:id(i)) - new = [] - for rel_data in value: - # get or create related item - add = rel_model.from_dict(rel_data, delete) - assert len(add) == 1 - rel_item, rel_updated = add[0] - changed.append((rel_item, rel_updated)) - if rel_item not in cur: - cur.append(rel_item) - new.append(rel_item) + if isinstance(model.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper): + mapped[key] = data[key] + del data[key] - # delete referenced items missing in yaml - rel_pkey = rel_model._dict_pkey() - new_data = list([i.to_dict(True, True, True, [rel_pkey]) for i in new]) - for rel_item in old: - if rel_item not in new: - # check if item with same data exists to stabilze import without primary key - rel_data = rel_item.to_dict(True, True, True, [rel_pkey]) - try: - same_idx = new_data.index(rel_data) - except ValueError: - same = None - else: - same = new[same_idx] + # create new item + item = model(**data) + created = True - if same is None: - # delete items missing in new - if delete: - cur.remove(rel_item) - else: - new.append(rel_item) + # and update mapped relationships (below) + data = mapped + + # update item + updated = [] + for key, value in data.items(): + + # skip primary key + if key == pkey: + continue + + if key in model.__mapper__.relationships: + # update relationship + rel_model = model.__mapper__.relationships[key].argument + if isinstance(rel_model, sqlalchemy.orm.Mapper): + rel_model = rel_model.class_ + # add (and create) referenced items + cur = getattr(item, key) + old = sorted(cur, key=lambda i:id(i)) + new = [] + for rel_data in value: + # get or create related item + add = rel_model.from_dict(rel_data, delete) + assert len(add) == 1 + rel_item, rel_updated = add[0] + changed.append((rel_item, rel_updated)) + if rel_item not in cur: + cur.append(rel_item) + new.append(rel_item) + + # delete referenced items missing in yaml + rel_pkey = rel_model._dict_pkey() + new_data = list([i.to_dict(True, True, True, [rel_pkey]) for i in new]) + for rel_item in old: + if rel_item not in new: + # check if item with same data exists to stabilze import without primary key + rel_data = rel_item.to_dict(True, True, True, [rel_pkey]) + try: + same_idx = new_data.index(rel_data) + except ValueError: + same = None + else: + same = new[same_idx] + + if same is None: + # delete items missing in new + if delete: + cur.remove(rel_item) else: - # swap found item with same data with newly created item new.append(rel_item) - new_data.append(rel_data) - new.remove(same) - del new_data[same_idx] - for i, (ch_item, ch_update) in enumerate(changed): - if ch_item is same: - changed[i] = (rel_item, []) - db.session.flush() - db.session.delete(ch_item) - break + else: + # swap found item with same data with newly created item + new.append(rel_item) + new_data.append(rel_data) + new.remove(same) + del new_data[same_idx] + for i, (ch_item, ch_update) in enumerate(changed): + if ch_item is same: + changed[i] = (rel_item, []) + db.session.flush() + db.session.delete(ch_item) + break - # remember changes - new = sorted(new, key=lambda i:id(i)) - if new != old: - updated.append((key, old, new)) + # remember changes + new = sorted(new, key=lambda i:id(i)) + if new != old: + updated.append((key, old, new)) - else: - # update key - old = getattr(item, key) - if type(old) is list and not delete: - value = old + value - if value != old: - updated.append((key, old, value)) - setattr(item, key, value) + else: + # update key + old = getattr(item, key) + if type(old) is list: + # deduplicate list value + assert type(value) is list + value = set(value) + old = set(old) + if not delete: + value = old | value + if value != old: + updated.append((key, old, value)) + setattr(item, key, value) - changed.append((item, updated)) + changed.append((item, created if created else updated)) return changed @@ -353,19 +369,21 @@ class Domain(Base): _dict_output = {'dkim_key': lambda v: v.decode('utf-8').strip().split('\n')[1:-1]} @staticmethod def _dict_input(data): - key = data.get('dkim_key') - if key is not None: + if 'dkim_key' in data: key = data['dkim_key'] - if type(key) is list: - key = ''.join(key) - if type(key) is str: - key = ''.join(key.strip().split()) - if key.startswith('-----BEGIN PRIVATE KEY-----'): - key = key[25:] - if key.endswith('-----END PRIVATE KEY-----'): - key = key[:-23] - key = '\n'.join(wrap(key, 64)) - data['dkim_key'] = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii') + if key is None: + del data['dkim_key'] + else: + if type(key) is list: + key = ''.join(key) + if type(key) is str: + key = ''.join(key.strip().split()) + if key.startswith('-----BEGIN PRIVATE KEY-----'): + key = key[25:] + if key.endswith('-----END PRIVATE KEY-----'): + key = key[:-23] + key = '\n'.join(wrap(key, 64)) + data['dkim_key'] = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii') name = db.Column(IdnaDomain, primary_key=True, nullable=False) managers = db.relationship('User', secondary=managers, @@ -580,6 +598,8 @@ class User(Base, Email): if data['hash_scheme'] not in cls.scheme_dict: raise ValueError(f'invalid password scheme {scheme!r}') data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash'] + del data['hash_scheme'] + del data['password_hash'] domain = db.relationship(Domain, backref=db.backref('users', cascade='all, delete-orphan')) @@ -709,6 +729,7 @@ class Alias(Base, Email): _dict_hide = {'domain_name', 'domain', 'localpart'} @staticmethod def _dict_input(data): + Email._dict_input(data) # handle comma delimited string for backwards compability dst = data.get('destination') if type(dst) is str: