diff --git a/back/boxtribute_server/graph_ql/execution.py b/back/boxtribute_server/graph_ql/execution.py index 5fe7c52e3e..f05dae60a2 100644 --- a/back/boxtribute_server/graph_ql/execution.py +++ b/back/boxtribute_server/graph_ql/execution.py @@ -4,6 +4,7 @@ from flask import current_app, jsonify, request from ..exceptions import format_database_errors +from ..utils import activate_logging from .loaders import ( BaseLoader, BoxLoader, @@ -33,6 +34,7 @@ def execute_async(*, schema, introspection=None): managing the asyncio event loop, finalizing asynchronous generators, and closing the threadpool. """ + logger = activate_logging() async def run(): # Create DataLoaders and persist them for the time of processing the request. @@ -73,5 +75,6 @@ async def run(): success, result = asyncio.run(run()) + logger.warning("done") status_code = 200 if success else 400 return jsonify(result), status_code diff --git a/back/boxtribute_server/graph_ql/loaders.py b/back/boxtribute_server/graph_ql/loaders.py index b8f7f5e938..99929587b7 100644 --- a/back/boxtribute_server/graph_ql/loaders.py +++ b/back/boxtribute_server/graph_ql/loaders.py @@ -1,4 +1,6 @@ +import asyncio from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from functools import partial @@ -38,6 +40,27 @@ def load(self, key): return super().load(key) +executor = ThreadPoolExecutor(max_workers=5) + + +async def select( + model, /, *conditions, fields=None, join_kwargs=None, group_field=None +): + def utility_function(): + query = model.select() + if fields is not None: + query = model.select(*fields) + if join_kwargs is not None: + query = query.join(**join_kwargs) + if conditions: + query = query.where(*conditions) + if group_field is not None: + query = query.group_by(group_field) + return list(query.iterator()) + + return await asyncio.get_running_loop().run_in_executor(executor, utility_function) + + class SimpleDataLoader(DataLoader): """Custom implementation that batch-loads all requested rows of the specified data model, optionally enforcing authorization for the resource. @@ -55,7 +78,7 @@ async def batch_load_fn(self, ids): permission = f"{resource}:read" authorize(permission=permission) - rows = {r.id: r for r in self.model.select().where(self.model.id << ids)} + rows = {r.id: r for r in await select(self.model, self.model.id << ids)} return [rows.get(i) for i in ids] @@ -118,9 +141,12 @@ class ShipmentLoader(DataLoader): async def batch_load_fn(self, keys): shipments = { s.id: s - for s in Shipment.select().orwhere( - authorized_bases_filter(Shipment, base_fk_field_name="source_base_id"), - authorized_bases_filter(Shipment, base_fk_field_name="target_base_id"), + for s in await select( + Shipment, + authorized_bases_filter(Shipment, base_fk_field_name="source_base_id") + | authorized_bases_filter( + Shipment, base_fk_field_name="target_base_id" + ), ) } return [shipments.get(i) for i in keys] @@ -131,7 +157,8 @@ async def batch_load_fn(self, agreement_ids): # Select all shipments with given agreement IDs that the user is authorized for, # and group them by agreement ID shipments = defaultdict(list) - for shipment in Shipment.select().where( + for shipment in await select( + Shipment, Shipment.transfer_agreement << agreement_ids, authorized_bases_filter(Shipment, base_fk_field_name="source_base") | authorized_bases_filter(Shipment, base_fk_field_name="target_base"), @@ -145,16 +172,18 @@ class TagsForBoxLoader(DataLoader): async def batch_load_fn(self, keys): tags = defaultdict(list) # maybe need different join type - for relation in TagsRelation.select( - TagsRelation.object_type, TagsRelation.object_id, Tag - ).join( - Tag, - on=( - (TagsRelation.tag == Tag.id) - & (TagsRelation.object_type == TaggableObjectType.Box) - & (TagsRelation.object_id << keys) - & (TagsRelation.deleted_on.is_null()) - & (authorized_bases_filter(Tag)) + for relation in await select( + TagsRelation, + fields=(TagsRelation.object_type, TagsRelation.object_id, Tag), + join_kwargs=dict( + dest=Tag, + on=( + (TagsRelation.tag == Tag.id) + & (TagsRelation.object_type == TaggableObjectType.Box) + & (TagsRelation.object_id << keys) + & (TagsRelation.deleted_on.is_null()) + & (authorized_bases_filter(Tag)) + ), ), ): tags[relation.object_id].append(relation.tag) @@ -462,10 +491,11 @@ async def batch_load_fn(self, shipment_ids): # Join with Shipment model, such that authorization in ShipmentDetail resolvers # (detail.shipment.source_base_id) don't create additional DB queries details = defaultdict(list) - for detail in ( - ShipmentDetail.select(ShipmentDetail, Shipment) - .join(Shipment) - .where(ShipmentDetail.shipment << shipment_ids) + for detail in await select( + ShipmentDetail, + ShipmentDetail.shipment << shipment_ids, + fields=[ShipmentDetail, Shipment], + join_kwargs={"dest": Shipment}, ): details[detail.shipment_id].append(detail) # Return empty list if shipment has no details attached @@ -476,7 +506,8 @@ class ShipmentDetailForBoxLoader(DataLoader): async def batch_load_fn(self, keys): details = { detail.box_id: detail - for detail in ShipmentDetail.select().where( + for detail in await select( + ShipmentDetail, ShipmentDetail.box << keys, ShipmentDetail.removed_on.is_null(), ShipmentDetail.lost_on.is_null(), @@ -492,7 +523,7 @@ async def batch_load_fn(self, keys): authorize(permission="size:read") # Mapping of size range ID to list of sizes sizes = defaultdict(list) - for size in Size.select(): + for size in await select(Size): sizes[size.size_range_id].append(size) # Keys are in fact size range IDs. Return empty list if size range has no sizes return [sizes.get(i, []) for i in keys] @@ -502,23 +533,26 @@ class UnitsForDimensionLoader(DataLoader): async def batch_load_fn(self, keys): # Mapping of size range ID (dimension) to list of units units = defaultdict(list) - for unit in Unit.select().iterator(): + for unit in await select(Unit): units[unit.dimension_id].append(unit) return [units.get(i, []) for i in keys] class EnabledBasesForStandardProductLoader(DataLoader): async def batch_load_fn(self, standard_product_ids): - result = Product.select(Product.standard_product, Base).join( - Base, - on=( - (Product.base == Base.id) - & authorized_bases_filter(model=Product) - & (Product.standard_product << standard_product_ids) - & (Product.deleted_on.is_null()) - ), - ) standard_products = defaultdict(list) - for row in result: + for row in await select( + Product, + fields=(Product.standard_product, Base), + join_kwargs=dict( + dest=Base, + on=( + (Product.base == Base.id) + & authorized_bases_filter(model=Product) + & (Product.standard_product << standard_product_ids) + & (Product.deleted_on.is_null()) + ), + ), + ): standard_products[row.standard_product_id].append(row.base) return [standard_products.get(i, []) for i in standard_product_ids] diff --git a/back/boxtribute_server/utils.py b/back/boxtribute_server/utils.py index 38943d409e..4e1052ef1f 100644 --- a/back/boxtribute_server/utils.py +++ b/back/boxtribute_server/utils.py @@ -52,3 +52,9 @@ def activate_logging(): # pragma: no cover from flask import current_app logger.parent = current_app.logger + handler = logging.FileHandler("back/peewee.log") + formatter = logging.Formatter("%(created)f | %(message)s") + handler.setFormatter(formatter) + if len(logger.handlers) == 1: + logger.addHandler(handler) + return logger