Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions back/boxtribute_server/graph_ql/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flask import current_app, jsonify, request

from ..exceptions import format_database_errors
from ..utils import activate_logging
from .loaders import (
BaseLoader,
BoxLoader,
Expand Down Expand Up @@ -33,6 +34,7 @@ def execute_async(*, schema, introspection=None):
managing the asyncio event loop, finalizing asynchronous generators, and closing
the threadpool.
"""
logger = activate_logging()

async def run():
# Create DataLoaders and persist them for the time of processing the request.
Expand Down Expand Up @@ -73,5 +75,6 @@ async def run():

success, result = asyncio.run(run())

logger.warning("done")
status_code = 200 if success else 400
return jsonify(result), status_code
98 changes: 66 additions & 32 deletions back/boxtribute_server/graph_ql/loaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial

Expand Down Expand Up @@ -38,6 +40,27 @@ def load(self, key):
return super().load(key)


executor = ThreadPoolExecutor(max_workers=5)


async def select(
model, /, *conditions, fields=None, join_kwargs=None, group_field=None
):
def utility_function():
query = model.select()
if fields is not None:
query = model.select(*fields)
if join_kwargs is not None:
query = query.join(**join_kwargs)
if conditions:
query = query.where(*conditions)
if group_field is not None:
query = query.group_by(group_field)
return list(query.iterator())

return await asyncio.get_running_loop().run_in_executor(executor, utility_function)


class SimpleDataLoader(DataLoader):
"""Custom implementation that batch-loads all requested rows of the specified data
model, optionally enforcing authorization for the resource.
Expand All @@ -55,7 +78,7 @@ async def batch_load_fn(self, ids):
permission = f"{resource}:read"
authorize(permission=permission)

rows = {r.id: r for r in self.model.select().where(self.model.id << ids)}
rows = {r.id: r for r in await select(self.model, self.model.id << ids)}
return [rows.get(i) for i in ids]


Expand Down Expand Up @@ -118,9 +141,12 @@ class ShipmentLoader(DataLoader):
async def batch_load_fn(self, keys):
shipments = {
s.id: s
for s in Shipment.select().orwhere(
authorized_bases_filter(Shipment, base_fk_field_name="source_base_id"),
authorized_bases_filter(Shipment, base_fk_field_name="target_base_id"),
for s in await select(
Shipment,
authorized_bases_filter(Shipment, base_fk_field_name="source_base_id")
| authorized_bases_filter(
Shipment, base_fk_field_name="target_base_id"
),
)
}
return [shipments.get(i) for i in keys]
Expand All @@ -131,7 +157,8 @@ async def batch_load_fn(self, agreement_ids):
# Select all shipments with given agreement IDs that the user is authorized for,
# and group them by agreement ID
shipments = defaultdict(list)
for shipment in Shipment.select().where(
for shipment in await select(
Shipment,
Shipment.transfer_agreement << agreement_ids,
authorized_bases_filter(Shipment, base_fk_field_name="source_base")
| authorized_bases_filter(Shipment, base_fk_field_name="target_base"),
Expand All @@ -145,16 +172,18 @@ class TagsForBoxLoader(DataLoader):
async def batch_load_fn(self, keys):
tags = defaultdict(list)
# maybe need different join type
for relation in TagsRelation.select(
TagsRelation.object_type, TagsRelation.object_id, Tag
).join(
Tag,
on=(
(TagsRelation.tag == Tag.id)
& (TagsRelation.object_type == TaggableObjectType.Box)
& (TagsRelation.object_id << keys)
& (TagsRelation.deleted_on.is_null())
& (authorized_bases_filter(Tag))
for relation in await select(
TagsRelation,
fields=(TagsRelation.object_type, TagsRelation.object_id, Tag),
join_kwargs=dict(
dest=Tag,
on=(
(TagsRelation.tag == Tag.id)
& (TagsRelation.object_type == TaggableObjectType.Box)
& (TagsRelation.object_id << keys)
& (TagsRelation.deleted_on.is_null())
& (authorized_bases_filter(Tag))
),
),
):
tags[relation.object_id].append(relation.tag)
Expand Down Expand Up @@ -462,10 +491,11 @@ async def batch_load_fn(self, shipment_ids):
# Join with Shipment model, such that authorization in ShipmentDetail resolvers
# (detail.shipment.source_base_id) don't create additional DB queries
details = defaultdict(list)
for detail in (
ShipmentDetail.select(ShipmentDetail, Shipment)
.join(Shipment)
.where(ShipmentDetail.shipment << shipment_ids)
for detail in await select(
ShipmentDetail,
ShipmentDetail.shipment << shipment_ids,
fields=[ShipmentDetail, Shipment],
join_kwargs={"dest": Shipment},
):
details[detail.shipment_id].append(detail)
# Return empty list if shipment has no details attached
Expand All @@ -476,7 +506,8 @@ class ShipmentDetailForBoxLoader(DataLoader):
async def batch_load_fn(self, keys):
details = {
detail.box_id: detail
for detail in ShipmentDetail.select().where(
for detail in await select(
ShipmentDetail,
ShipmentDetail.box << keys,
ShipmentDetail.removed_on.is_null(),
ShipmentDetail.lost_on.is_null(),
Expand All @@ -492,7 +523,7 @@ async def batch_load_fn(self, keys):
authorize(permission="size:read")
# Mapping of size range ID to list of sizes
sizes = defaultdict(list)
for size in Size.select():
for size in await select(Size):
sizes[size.size_range_id].append(size)
# Keys are in fact size range IDs. Return empty list if size range has no sizes
return [sizes.get(i, []) for i in keys]
Expand All @@ -502,23 +533,26 @@ class UnitsForDimensionLoader(DataLoader):
async def batch_load_fn(self, keys):
# Mapping of size range ID (dimension) to list of units
units = defaultdict(list)
for unit in Unit.select().iterator():
for unit in await select(Unit):
units[unit.dimension_id].append(unit)
return [units.get(i, []) for i in keys]


class EnabledBasesForStandardProductLoader(DataLoader):
async def batch_load_fn(self, standard_product_ids):
result = Product.select(Product.standard_product, Base).join(
Base,
on=(
(Product.base == Base.id)
& authorized_bases_filter(model=Product)
& (Product.standard_product << standard_product_ids)
& (Product.deleted_on.is_null())
),
)
standard_products = defaultdict(list)
for row in result:
for row in await select(
Product,
fields=(Product.standard_product, Base),
join_kwargs=dict(
dest=Base,
on=(
(Product.base == Base.id)
& authorized_bases_filter(model=Product)
& (Product.standard_product << standard_product_ids)
& (Product.deleted_on.is_null())
),
),
):
standard_products[row.standard_product_id].append(row.base)
return [standard_products.get(i, []) for i in standard_product_ids]
6 changes: 6 additions & 0 deletions back/boxtribute_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ def activate_logging(): # pragma: no cover
from flask import current_app

logger.parent = current_app.logger
handler = logging.FileHandler("back/peewee.log")
formatter = logging.Formatter("%(created)f | %(message)s")
handler.setFormatter(formatter)
if len(logger.handlers) == 1:
logger.addHandler(handler)
return logger