diff --git a/ereuse_devicehub/resources/lot/__init__.py b/ereuse_devicehub/resources/lot/__init__.py index 176ed4d5..46d17120 100644 --- a/ereuse_devicehub/resources/lot/__init__.py +++ b/ereuse_devicehub/resources/lot/__init__.py @@ -1,10 +1,12 @@ import pathlib +from typing import Callable, Iterable, Tuple from teal.resource import Converters, Resource from ereuse_devicehub.db import db from ereuse_devicehub.resources.lot import schemas -from ereuse_devicehub.resources.lot.views import LotView +from ereuse_devicehub.resources.lot.views import LotBaseChildrenView, LotChildrenView, \ + LotDeviceView, LotView class LotDef(Resource): @@ -13,6 +15,20 @@ class LotDef(Resource): AUTH = True ID_CONVERTER = Converters.uuid + def __init__(self, app, import_name=__package__, static_folder=None, static_url_path=None, + template_folder=None, url_prefix=None, subdomain=None, url_defaults=None, + root_path=None, cli_commands: Iterable[Tuple[Callable, str or None]] = tuple()): + super().__init__(app, import_name, static_folder, static_url_path, template_folder, + url_prefix, subdomain, url_defaults, root_path, cli_commands) + children = LotChildrenView.as_view('lot-children', definition=self, auth=app.auth) + self.add_url_rule('/<{}:{}>/children'.format(self.ID_CONVERTER.value, self.ID_NAME), + view_func=children, + methods={'POST', 'DELETE'}) + children = LotDeviceView.as_view('lot-device', definition=self, auth=app.auth) + self.add_url_rule('/<{}:{}>/devices'.format(self.ID_CONVERTER.value, self.ID_NAME), + view_func=children, + methods={'POST', 'DELETE'}) + def init_db(self, db: 'db.SQLAlchemy'): # Create functions with pathlib.Path(__file__).parent.joinpath('dag.sql').open() as f: diff --git a/ereuse_devicehub/resources/lot/models.py b/ereuse_devicehub/resources/lot/models.py index b11ac785..9da9420c 100644 --- a/ereuse_devicehub/resources/lot/models.py +++ b/ereuse_devicehub/resources/lot/models.py @@ -1,9 +1,9 @@ import uuid from datetime import datetime -from typing import Set from flask import g from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import aliased from sqlalchemy.sql import expression from sqlalchemy_utils import LtreeType from sqlalchemy_utils.types.ltree import LQUERY @@ -26,6 +26,12 @@ class Lot(Thing): backref=db.backref('parents', lazy=True, collection_class=set), secondary=lambda: LotDevice.__table__, collection_class=set) + """ + The **children** devices that the lot has. + + Note that the lot can have more devices, if they are inside + descendant lots. + """ def __init__(self, name: str, closed: bool = closed.default.arg) -> None: """ @@ -36,15 +42,34 @@ class Lot(Thing): super().__init__(id=uuid.uuid4(), name=name, closed=closed) Path(self) # Lots have always one edge per default. - def add_child(self, child: 'Lot'): + def add_child(self, child): """Adds a child to this lot.""" - Path.add(self.id, child.id) - db.session.refresh(self) # todo is this useful? - db.session.refresh(child) + if isinstance(child, Lot): + Path.add(self.id, child.id) + db.session.refresh(self) # todo is this useful? + db.session.refresh(child) + else: + assert isinstance(child, uuid.UUID) + Path.add(self.id, child) + db.session.refresh(self) # todo is this useful? def remove_child(self, child: 'Lot'): Path.delete(self.id, child.id) + @property + def children(self): + """The children lots.""" + # From https://stackoverflow.com/a/41158890 + # todo test + cls = self.__class__ + exp = '*.{}.*{{1}}'.format(UUIDLtree.convert(self.id)) + child_lots = aliased(Lot) + + return self.query \ + .join(cls.paths) \ + .filter(Path.path.lquery(expression.cast(exp, LQUERY))) \ + .join(child_lots, Path.lot) + def __contains__(self, child: 'Lot'): return Path.has_lot(self.id, child.id) @@ -96,16 +121,6 @@ class Path(db.Model): super().__init__(lot=lot) self.path = UUIDLtree(lot.id) - def children(self) -> Set['Path']: - """Get the children edges.""" - # todo is it useful? test it when first usage - # From https://stackoverflow.com/a/41158890 - exp = '*.{}.*{{1}}'.format(self.lot_id) - return set(self.query - .filter(self.path.lquery(expression.cast(exp, LQUERY))) - .distinct(self.__class__.lot_id) - .all()) - @classmethod def add(cls, parent_id: uuid.UUID, child_id: uuid.UUID): """Creates an edge between parent and child.""" @@ -118,7 +133,10 @@ class Path(db.Model): @classmethod def has_lot(cls, parent_id: uuid.UUID, child_id: uuid.UUID) -> bool: - return bool(db.session.execute( - "SELECT 1 from path where path ~ '*.{}.*.{}.*'".format( - str(parent_id).replace('-', '_'), str(child_id).replace('-', '_')) - ).first()) + parent_id = UUIDLtree.convert(parent_id) + child_id = UUIDLtree.convert(child_id) + return bool( + db.session.execute( + "SELECT 1 from path where path ~ '*.{}.*.{}.*'".format(parent_id, child_id) + ).first() + ) diff --git a/ereuse_devicehub/resources/lot/models.pyi b/ereuse_devicehub/resources/lot/models.pyi index 21aa00b3..386e8b35 100644 --- a/ereuse_devicehub/resources/lot/models.pyi +++ b/ereuse_devicehub/resources/lot/models.pyi @@ -1,5 +1,6 @@ +import uuid from datetime import datetime -from typing import Set +from typing import Set, Union from uuid import UUID from sqlalchemy import Column @@ -25,7 +26,7 @@ class Lot(Thing): self.devices = ... # type: Set[Device] self.paths = ... # type: Set[Path] - def add_child(self, child: 'Lot'): + def add_child(self, child: Union['Lot', uuid.UUID]): pass def remove_child(self, child: 'Lot'): @@ -35,6 +36,10 @@ class Lot(Thing): def roots(cls): pass + @property + def children(self) -> Set['Lot']: + pass + class Path: id = ... # type: Column diff --git a/ereuse_devicehub/resources/lot/schemas.py b/ereuse_devicehub/resources/lot/schemas.py index 59529954..c148e537 100644 --- a/ereuse_devicehub/resources/lot/schemas.py +++ b/ereuse_devicehub/resources/lot/schemas.py @@ -9,6 +9,9 @@ from ereuse_devicehub.resources.schemas import Thing class Lot(Thing): id = f.UUID(dump_only=True) - name = f.String(validate=f.validate.Length(max=STR_SIZE)) - closed = f.String(required=True, missing=False, description=m.Lot.closed.comment) - devices = f.String(NestedOn(Device, many=True, collection_class=set, only_query='id')) + name = f.String(validate=f.validate.Length(max=STR_SIZE), required=True) + closed = f.Boolean(missing=False, description=m.Lot.closed.comment) + devices = NestedOn(Device, many=True, dump_only=True) + children = NestedOn('Lot', + many=True, + dump_only=True) diff --git a/ereuse_devicehub/resources/lot/views.py b/ereuse_devicehub/resources/lot/views.py index f6ebfb9f..d8cb8f81 100644 --- a/ereuse_devicehub/resources/lot/views.py +++ b/ereuse_devicehub/resources/lot/views.py @@ -1,6 +1,8 @@ import uuid +from typing import Set -from flask import current_app as app, request +import marshmallow as ma +from flask import request from teal.resource import View from ereuse_devicehub.db import db @@ -9,10 +11,8 @@ from ereuse_devicehub.resources.lot.models import Lot class LotView(View): def post(self): - json = request.get_json(validate=False) - e = app.resources[json['type']].schema.load(json) - Model = db.Model._decl_class_registry.data[json['type']]() - lot = Model(**e) + l = request.get_json() + lot = Lot(**l) db.session.add(lot) db.session.commit() ret = self.schema.jsonify(lot) @@ -21,5 +21,74 @@ class LotView(View): def one(self, id: uuid.UUID): """Gets one event.""" - event = Lot.query.filter_by(id=id).one() - return self.schema.jsonify(event) + lot = Lot.query.filter_by(id=id).one() # type: Lot + return self.schema.jsonify(lot) + + +class LotBaseChildrenView(View): + """Base class for adding / removing children devices and + lots from a lot. + """ + + class ListArgs(ma.Schema): + id = ma.fields.List(ma.fields.UUID()) + + def __init__(self, definition: 'Resource', **kw) -> None: + super().__init__(definition, **kw) + self.list_args = self.ListArgs() + + def get_ids(self) -> Set[uuid.UUID]: + args = self.QUERY_PARSER.parse(self.list_args, request, locations=('querystring',)) + return set(args['id']) + + def get_lot(self, id: uuid.UUID) -> Lot: + return Lot.query.filter_by(id=id).one() + + # noinspection PyMethodOverriding + def post(self, id: uuid.UUID): + lot = self.get_lot(id) + self._post(lot, self.get_ids()) + db.session.commit() + ret = self.schema.jsonify(lot) + ret.status_code = 201 + return ret + + def delete(self, id: uuid.UUID): + lot = self.get_lot(id) + self._delete(lot, self.get_ids()) + db.session.commit() + return self.schema.jsonify(lot) + + def _post(self, lot: Lot, ids: Set[uuid.UUID]): + raise NotImplementedError + + def _delete(self, lot: Lot, ids: Set[uuid.UUID]): + raise NotImplementedError + + +class LotChildrenView(LotBaseChildrenView): + """View for adding and removing child lots from a lot. + + Ex. ``lot//children/id=X&id=Y``. + """ + + def _post(self, lot: Lot, ids: Set[uuid.UUID]): + for id in ids: + lot.add_child(id) # todo what to do if child exists already? + + def _delete(self, lot: Lot, ids: Set[uuid.UUID]): + for id in ids: + lot.remove_child(id) + + +class LotDeviceView(LotBaseChildrenView): + """View for adding and removing child devices from a lot. + + Ex. ``lot//devices/id=X&id=Y``. + """ + + def _post(self, lot: Lot, ids: Set[uuid.UUID]): + lot.devices |= self.get_ids() + + def _delete(self, lot: Lot, ids: Set[uuid.UUID]): + lot.devices -= self.get_ids() diff --git a/tests/test_basic.py b/tests/test_basic.py index 616f9af7..fb3ec0e4 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -25,7 +25,9 @@ def test_api_docs(client: Client): '/snapshots/', '/users/login', '/events/', - '/lots/' + '/lots/', + '/lots/{id}/children', + '/lots/{id}/devices' } assert docs['info'] == {'title': 'Devicehub', 'version': '0.2'} assert docs['components']['securitySchemes']['bearerAuth'] == { diff --git a/tests/test_lot.py b/tests/test_lot.py index 1fa50a93..384bc787 100644 --- a/tests/test_lot.py +++ b/tests/test_lot.py @@ -1,6 +1,7 @@ import pytest from flask import g +from ereuse_devicehub.client import UserClient from ereuse_devicehub.db import db from ereuse_devicehub.resources.device.models import Desktop from ereuse_devicehub.resources.enums import ComputerChassis @@ -181,3 +182,40 @@ def test_lot_roots(): assert Lot.roots() == {l1, l2, l3} l1.add_child(l2) assert Lot.roots() == {l1, l3} + + +@pytest.mark.usefixtures(conftest.auth_app_context.__name__) +def test_lot_model_children(): + """Tests the property Lot.children""" + lots = Lot('1'), Lot('2'), Lot('3') + l1, l2, l3 = lots + db.session.add_all(lots) + db.session.flush() + + l1.add_child(l2) + db.session.flush() + + children = l1.children + assert list(children) == [l2] + + +def test_post_get_lot(user: UserClient): + """Tests submitting and retreiving a basic lot.""" + l, _ = user.post({'name': 'Foo'}, res=Lot) + assert l['name'] == 'Foo' + l, _ = user.get(res=Lot, item=l['id']) + assert l['name'] == 'Foo' + assert not l['children'] + + +def test_post_add_children_view(user: UserClient): + """Tests adding children lots to a lot through the view.""" + l, _ = user.post(({'name': 'Parent'}), res=Lot) + child, _ = user.post(({'name': 'Child'}), res=Lot) + l, _ = user.post({}, res=Lot, item='{}/children'.format(l['id']), query=[('id', child['id'])]) + assert l['children'][0]['id'] == child['id'] + + +@pytest.mark.xfail(reason='Just develop the test') +def test_post_add_device_view(user: UserClient): + pass