fixed query lots

This commit is contained in:
Cayo Puigdefabregas 2022-05-31 15:14:56 +02:00
parent c36bf77ab7
commit 01f2996c48

View file

@ -1,20 +1,22 @@
import uuid import uuid
from sqlalchemy.util import OrderedSet
from collections import deque from collections import deque
from enum import Enum from enum import Enum
from typing import Dict, List, Set, Union from typing import Dict, List, Set, Union
import marshmallow as ma import marshmallow as ma
from flask import Response, jsonify, request, g from flask import Response, g, jsonify, request
from marshmallow import Schema as MarshmallowSchema, fields as f from marshmallow import Schema as MarshmallowSchema
from marshmallow import fields as f
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.util import OrderedSet
from teal.marshmallow import EnumField from teal.marshmallow import EnumField
from teal.resource import View from teal.resource import View
from ereuse_devicehub.db import db from ereuse_devicehub.db import db
from ereuse_devicehub.inventory.models import Transfer
from ereuse_devicehub.query import things_response from ereuse_devicehub.query import things_response
from ereuse_devicehub.resources.device.models import Device, Computer from ereuse_devicehub.resources.action.models import Confirm, Revoke, Trade
from ereuse_devicehub.resources.action.models import Trade, Confirm, Revoke from ereuse_devicehub.resources.device.models import Computer, Device
from ereuse_devicehub.resources.lot.models import Lot, Path from ereuse_devicehub.resources.lot.models import Lot, Path
@ -27,6 +29,7 @@ class LotView(View):
"""Allowed arguments for the ``find`` """Allowed arguments for the ``find``
method (GET collection) endpoint method (GET collection) endpoint
""" """
format = EnumField(LotFormat, missing=None) format = EnumField(LotFormat, missing=None)
search = f.Str(missing=None) search = f.Str(missing=None)
type = f.Str(missing=None) type = f.Str(missing=None)
@ -42,12 +45,26 @@ class LotView(View):
return ret return ret
def patch(self, id): def patch(self, id):
patch_schema = self.resource_def.SCHEMA(only=( patch_schema = self.resource_def.SCHEMA(
'name', 'description', 'transfer_state', 'receiver_address', 'amount', 'devices', only=(
'owner_address'), partial=True) 'name',
'description',
'transfer_state',
'receiver_address',
'amount',
'devices',
'owner_address',
),
partial=True,
)
l = request.get_json(schema=patch_schema) l = request.get_json(schema=patch_schema)
lot = Lot.query.filter_by(id=id).one() lot = Lot.query.filter_by(id=id).one()
device_fields = ['transfer_state', 'receiver_address', 'amount', 'owner_address'] device_fields = [
'transfer_state',
'receiver_address',
'amount',
'owner_address',
]
computers = [x for x in lot.all_devices if isinstance(x, Computer)] computers = [x for x in lot.all_devices if isinstance(x, Computer)]
for key, value in l.items(): for key, value in l.items():
setattr(lot, key, value) setattr(lot, key, value)
@ -84,7 +101,7 @@ class LotView(View):
ret = { ret = {
'items': {l['id']: l for l in lots}, 'items': {l['id']: l for l in lots},
'tree': self.ui_tree(), 'tree': self.ui_tree(),
'url': request.path 'url': request.path,
} }
else: else:
query = Lot.query query = Lot.query
@ -95,15 +112,28 @@ class LotView(View):
lots = query.paginate(per_page=6 if args['search'] else query.count()) lots = query.paginate(per_page=6 if args['search'] else query.count())
return things_response( return things_response(
self.schema.dump(lots.items, many=True, nested=2), self.schema.dump(lots.items, many=True, nested=2),
lots.page, lots.per_page, lots.total, lots.prev_num, lots.next_num lots.page,
lots.per_page,
lots.total,
lots.prev_num,
lots.next_num,
) )
return jsonify(ret) return jsonify(ret)
def visibility_filter(self, query): def visibility_filter(self, query):
query = query.outerjoin(Trade) \ query = (
.filter(or_(Trade.user_from == g.user, query.outerjoin(Trade)
Trade.user_to == g.user, .outerjoin(Transfer)
Lot.owner_id == g.user.id)) .filter(
or_(
Trade.user_from == g.user,
Trade.user_to == g.user,
Lot.owner_id == g.user.id,
Transfer.user_from == g.user,
Transfer.user_to == g.user,
)
)
)
return query return query
def type_filter(self, query, args): def type_filter(self, query, args):
@ -111,13 +141,23 @@ class LotView(View):
# temporary # temporary
if lot_type == "temporary": if lot_type == "temporary":
return query.filter(Lot.trade == None) return query.filter(Lot.trade == None).filter(Lot.transfer == None)
if lot_type == "incoming": if lot_type == "incoming":
return query.filter(Lot.trade and Trade.user_to == g.user) return query.filter(
or_(
Lot.trade and Trade.user_to == g.user,
Lot.transfer and Transfer.user_to == g.user,
)
).all()
if lot_type == "outgoing": if lot_type == "outgoing":
return query.filter(Lot.trade and Trade.user_from == g.user) return query.filter(
or_(
Lot.trade and Trade.user_from == g.user,
Lot.transfer and Transfer.user_from == g.user,
)
).all()
return query return query
@ -152,10 +192,7 @@ class LotView(View):
# does lot_id exist already in node? # does lot_id exist already in node?
node = next(part for part in nodes if lot_id == part['id']) node = next(part for part in nodes if lot_id == part['id'])
except StopIteration: except StopIteration:
node = { node = {'id': lot_id, 'nodes': []}
'id': lot_id,
'nodes': []
}
nodes.append(node) nodes.append(node)
if path: if path:
cls._p(node['nodes'], path) cls._p(node['nodes'], path)
@ -175,15 +212,17 @@ class LotView(View):
class LotBaseChildrenView(View): class LotBaseChildrenView(View):
"""Base class for adding / removing children devices and """Base class for adding / removing children devices and
lots from a lot. lots from a lot.
""" """
def __init__(self, definition: 'Resource', **kw) -> None: def __init__(self, definition: 'Resource', **kw) -> None:
super().__init__(definition, **kw) super().__init__(definition, **kw)
self.list_args = self.ListArgs() self.list_args = self.ListArgs()
def get_ids(self) -> Set[uuid.UUID]: def get_ids(self) -> Set[uuid.UUID]:
args = self.QUERY_PARSER.parse(self.list_args, request, locations=('querystring',)) args = self.QUERY_PARSER.parse(
self.list_args, request, locations=('querystring',)
)
return set(args['id']) return set(args['id'])
def get_lot(self, id: uuid.UUID) -> Lot: def get_lot(self, id: uuid.UUID) -> Lot:
@ -247,8 +286,9 @@ class LotDeviceView(LotBaseChildrenView):
if not ids: if not ids:
return return
devices = set(Device.query.filter(Device.id.in_(ids)).filter( devices = set(
Device.owner == g.user)) Device.query.filter(Device.id.in_(ids)).filter(Device.owner == g.user)
)
lot.devices.update(devices) lot.devices.update(devices)
@ -271,8 +311,9 @@ class LotDeviceView(LotBaseChildrenView):
txt = 'This is not your lot' txt = 'This is not your lot'
raise ma.ValidationError(txt) raise ma.ValidationError(txt)
devices = set(Device.query.filter(Device.id.in_(ids)).filter( devices = set(
Device.owner_id == g.user.id)) Device.query.filter(Device.id.in_(ids)).filter(Device.owner_id == g.user.id)
)
lot.devices.difference_update(devices) lot.devices.difference_update(devices)
@ -311,9 +352,7 @@ def delete_from_trade(lot: Lot, devices: List):
phantom = lot.trade.user_from phantom = lot.trade.user_from
phantom_revoke = Revoke( phantom_revoke = Revoke(
action=lot.trade, action=lot.trade, user=phantom, devices=set(without_confirms)
user=phantom,
devices=set(without_confirms)
) )
db.session.add(phantom_revoke) db.session.add(phantom_revoke)