Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion fastapi_rest_jsonapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@
from fastapi_rest_jsonapi.data import DataLayer, SQLAlchemyDataLayer


__all__ = ["RestAPI", "Schema", "Resource", "ResourceList", "ResourceDetail", "DataLayer", "SQLAlchemyDataLayer"]
__all__ = [
"RestAPI",
"Schema",
"Resource",
"ResourceList",
"ResourceDetail",
"DataLayer",
"SQLAlchemyDataLayer",
]
6 changes: 6 additions & 0 deletions fastapi_rest_jsonapi/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, relationship: str):
self.message = f"Unknown relationship: {relationship}"


class UnprocessableEntityException(RestAPIException):
def __init__(self, entity: str):
self.status = status.HTTP_422_UNPROCESSABLE_ENTITY
self.message = f"Unprocessable entity: {entity}"


class UnknownTypeException(RestAPIException):
def __init__(self, type_: str):
self.status = status.HTTP_400_BAD_REQUEST
Expand Down
10 changes: 8 additions & 2 deletions fastapi_rest_jsonapi/data/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
class DataLayer(metaclass=ABCMeta):
@abstractmethod
def get(
self, sorts: list[Sort] = None, fields: list[Field] = None, page: Page = None, includes: list[Include] = None
self,
sorts: list[Sort] = None,
fields: list[Field] = None,
page: Page = None,
includes: list[Include] = None,
) -> list:
raise NotImplementedError

@abstractmethod
def get_one(self, id_: int, fields: list[Field] = None, includes: list[Include] = None) -> object:
def get_one(
self, id_: int, fields: list[Field] = None, includes: list[Include] = None
) -> object:
raise NotImplementedError

@abstractmethod
Expand Down
57 changes: 44 additions & 13 deletions fastapi_rest_jsonapi/data/sqlachemy_data_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from math import ceil
from fastapi_rest_jsonapi.common.exceptions import UnprocessableEntityException

import sqlalchemy
from fastapi_rest_jsonapi.common.exceptions import (
Expand Down Expand Up @@ -30,23 +31,35 @@ def __get_model_for_type(self, type_: str):
return class_
raise UnknownTypeException(type_)

def __get_relationships_and_properties_for_model(self, model) -> tuple[list[str], list[str]]:
def __get_relationships_and_properties_for_model(
self, model
) -> tuple[list[str], list[str]]:
relationships = []
properties = []
for field_name, field_attr in model.__dict__.items():
if type(getattr(field_attr, "property", False)) is sqlalchemy.orm.relationships.RelationshipProperty:
if (
type(getattr(field_attr, "property", False))
is sqlalchemy.orm.relationships.RelationshipProperty
):
relationships.append(field_name)
elif type(getattr(field_attr, "property", False)):
properties.append(field_name)

return relationships, properties

def __get_fields_for_type(self, type_: str, fields: list[Field]) -> tuple[list[str], list[str]]:
def __get_fields_for_type(
self, type_: str, fields: list[Field]
) -> tuple[list[str], list[str]]:
type_model = self.__get_model_for_type(type_)
type_relationship_fields, type_properties_fields = self.__get_relationships_and_properties_for_model(type_model)
(
type_relationship_fields,
type_properties_fields,
) = self.__get_relationships_and_properties_for_model(type_model)
fields_ = list(filter(lambda f: f.type == type_, fields or []))
fields_properties = filter(lambda f: f.field in type_properties_fields, fields_)
fields_relationships = filter(lambda f: f.field in type_relationship_fields, fields_)
fields_relationships = filter(
lambda f: f.field in type_relationship_fields, fields_
)
fields_properties = map(lambda f: f.field, fields_properties)
fields_relationships = map(lambda f: f.field, fields_relationships)
return list(fields_properties), list(fields_relationships)
Expand Down Expand Up @@ -75,10 +88,16 @@ def __paginate_query(self, query: Query, page: Page) -> Query:
page.max_number = ceil(total / page.size)
return query.offset(page.size * (page.number - 1)).limit(page.size)

def __include_and_field_query(self, query: Query, includes: list[Include], fields: list[Field]) -> Query:
def __include_and_field_query(
self, query: Query, includes: list[Include], fields: list[Field]
) -> Query:
processed_fields = []
fields_properties, fields_relationship = self.__get_fields_for_type(self.current_tablename, fields)
query = query.options(load_only(*fields_properties)) if fields_properties else query
fields_properties, fields_relationship = self.__get_fields_for_type(
self.current_tablename, fields
)
query = (
query.options(load_only(*fields_properties)) if fields_properties else query
)
processed_fields.extend(fields_properties)
processed_fields.extend(fields_relationship)

Expand All @@ -91,20 +110,28 @@ def __include_and_field_query(self, query: Query, includes: list[Include], field
joined_load_func = joined_load_func.load_only(*fields_)
query = query.options(joined_load_func)

if processed_diff := set(processed_fields) ^ set([x.field for x in fields or []]):
if processed_diff := set(processed_fields) ^ set(
[x.field for x in fields or []]
):
raise UnknownRelationshipException(", ".join(processed_diff))
return query

def get(
self, sorts: list[Sort] = None, fields: list[Field] = None, page: Page = None, includes: list[Include] = None
self,
sorts: list[Sort] = None,
fields: list[Field] = None,
page: Page = None,
includes: list[Include] = None,
) -> list:
query: Query = self.session.query(self.model)
query = self.__include_and_field_query(query, includes, fields)
query = self.__sort_query(query, sorts)
query = self.__paginate_query(query, page)
return query.all()

def get_one(self, id_: int, fields: list[Field] = None, includes: list[Include] = None) -> object:
def get_one(
self, id_: int, fields: list[Field] = None, includes: list[Include] = None
) -> object:
query: Query = self.session.query(self.model)
query = self.__include_and_field_query(query, includes, fields)
query = query.filter(self.model.id == id_)
Expand All @@ -130,5 +157,9 @@ def update_one(self, id_: int, **kwargs) -> object:
def create_one(self, **kwargs) -> object:
obj = self.model(**kwargs)
self.session.add(obj)
self.session.commit()
return obj
try:
self.session.commit()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that someone could use def update_one(self, id_: int, **kwargs) -> object: and set some unique field to a value that already exists in another row which would throw up the same issue?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, I need to fix the update_one method as well :)

return obj
except Exception:
self.session.rollback()
raise UnprocessableEntityException(str(kwargs))
23 changes: 19 additions & 4 deletions fastapi_rest_jsonapi/request/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@


class Page:
def __init__(self, url: URL, query_params: Optional[dict], number: int, size: Optional[int] = None) -> None:
def __init__(
self,
url: URL,
query_params: Optional[dict],
number: int,
size: Optional[int] = None,
) -> None:
self.url = url
self.query_params = query_params
self.number = number
Expand All @@ -19,9 +25,18 @@ def is_paginated(self):
def __get_query_params_as_dict(self) -> dict:
if not self.query_params:
return {}
non_iterable_query_params = {k: v for k, v in self.query_params.items() if v and not isinstance(v, list)}
iterable_query_params = {k: v for k, v in self.query_params.items() if v and k not in non_iterable_query_params}
iterable_query_params = {v.split("=")[0]: v.split("=")[1] for v in chain(*iterable_query_params.values())}
non_iterable_query_params = {
k: v for k, v in self.query_params.items() if v and not isinstance(v, list)
}
iterable_query_params = {
k: v
for k, v in self.query_params.items()
if v and k not in non_iterable_query_params
}
iterable_query_params = {
v.split("=")[0]: v.split("=")[1]
for v in chain(*iterable_query_params.values())
}
return non_iterable_query_params | iterable_query_params

def get_self_link(self) -> str:
Expand Down
23 changes: 20 additions & 3 deletions fastapi_rest_jsonapi/resource/resource_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,31 @@

from fastapi import status
from fastapi_rest_jsonapi.common import Methods
from fastapi_rest_jsonapi.common.exceptions import UnprocessableEntityException
from fastapi_rest_jsonapi.resource import Resource
from fastapi_rest_jsonapi.request.request_context import RequestContext
from fastapi_rest_jsonapi.response.response import Response
from marshmallow import ValidationError


class ResourceDetail(Resource, ABC):
methods = [Methods.GET.value, Methods.PATCH.value, Methods.DELETE.value]

@staticmethod
def get(cls: Resource, request_ctx: RequestContext):
obj = cls.data_layer.get_one(request_ctx.path_parameters.id, request_ctx.fields, request_ctx.includes)
obj = cls.data_layer.get_one(
request_ctx.path_parameters.id, request_ctx.fields, request_ctx.includes
)
if obj is None:
return Response(request_ctx, status_code=status.HTTP_404_NOT_FOUND)
return Response(
request_ctx,
content=cls.schema().dump(includes=request_ctx.includes, fields=request_ctx.fields, obj=obj, many=False),
content=cls.schema().dump(
includes=request_ctx.includes,
fields=request_ctx.fields,
obj=obj,
many=False,
),
)

@staticmethod
Expand All @@ -30,7 +39,15 @@ def delete(cls: Resource, request_ctx: RequestContext):

@staticmethod
def patch(cls: Resource, request_ctx: RequestContext):
is_updated = cls.data_layer.update_one(request_ctx.path_parameters.id, **request_ctx.body) is not None
try:
data = cls.schema().load(request_ctx.body)
except ValidationError:
raise UnprocessableEntityException(str(request_ctx.body))

is_updated = (
cls.data_layer.update_one(request_ctx.path_parameters.id, **data)
is not None
)
if is_updated:
return Response(request_ctx, status_code=status.HTTP_204_NO_CONTENT)
return Response(request_ctx, status_code=status.HTTP_404_NOT_FOUND)
40 changes: 33 additions & 7 deletions fastapi_rest_jsonapi/resource/resource_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from fastapi import status
from fastapi_rest_jsonapi.resource import Resource
from fastapi_rest_jsonapi.common.methods import Methods
from fastapi_rest_jsonapi.common.exceptions import UnprocessableEntityException
from fastapi_rest_jsonapi.request.request_context import RequestContext
from fastapi_rest_jsonapi.response.response import Response
from marshmallow import ValidationError


DEFAULT_PAGE_SIZE = 30
Expand All @@ -20,15 +22,39 @@ def get(cls: Resource, request_ctx: RequestContext):
if request_ctx_page.size is None:
request_ctx_page.size = cls.page_size

objects = cls.data_layer.get(request_ctx.sorts, request_ctx.fields, request_ctx_page, request_ctx.includes)
content = cls.schema().dump(includes=request_ctx.includes, fields=request_ctx.fields, obj=objects, many=True)
objects = cls.data_layer.get(
request_ctx.sorts,
request_ctx.fields,
request_ctx_page,
request_ctx.includes,
)
content = cls.schema().dump(
includes=request_ctx.includes,
fields=request_ctx.fields,
obj=objects,
many=True,
)
return Response(request_ctx, content=content)

@staticmethod
def post(cls: Resource, request_ctx: RequestContext):
created = cls.data_layer.create_one(**request_ctx.body)
if created is None:
return Response(request_ctx, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
try:
data = cls.schema().load(request_ctx.body)
except ValidationError:
raise UnprocessableEntityException(str(request_ctx.body))

content = cls.schema().dump(includes=request_ctx.fields, fields=request_ctx.fields, obj=created, many=False)
return Response(request_ctx, content=content, status_code=status.HTTP_201_CREATED)
created = cls.data_layer.create_one(**data)
if created is None:
return Response(
request_ctx, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)

content = cls.schema().dump(
includes=request_ctx.fields,
fields=request_ctx.fields,
obj=created,
many=False,
)
return Response(
request_ctx, content=content, status_code=status.HTTP_201_CREATED
)
36 changes: 26 additions & 10 deletions fastapi_rest_jsonapi/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def __get_response_model(self, resource: Resource, method: str) -> BaseModel:
is_detail_resource_ = is_detail_resource(resource)
model_suffix = "detail" if is_detail_resource_ else "list"
# For some reasons, FastAPI does not allow to use the same name for the response model
response_model = create_model(f"{schema.__type__}-{method}-{model_suffix}", **fields)
response_model = create_model(
f"{schema.__type__}-{method}-{model_suffix}", **fields
)
if is_detail_resource_:
return response_model
return List[response_model]
Expand All @@ -74,9 +76,7 @@ def __get_path_parameters_model(self, resource: Resource, method: str) -> BaseMo
def __get_endpoint_summary(self, resource: Resource, method: str) -> str:
is_detail_resource_ = is_detail_resource(resource)
schema_type = resource.schema.__type__
return (
f"{method} {'a' if is_detail_resource_ else 'multiple'} {schema_type}{'' if is_detail_resource_ else 's'}"
)
return f"{method} {'a' if is_detail_resource_ else 'multiple'} {schema_type}{'' if is_detail_resource_ else 's'}"

def __get_query_parameters_dict(self, request: Request) -> dict:
request_query_params_dict = request.query_params._dict
Expand All @@ -89,9 +89,16 @@ def __get_query_params_with_brackets(query_param_name: str):
if request_query_param := request_query_params_dict.get(query_param_name):
return request_query_param.split("&")

return [f"{k}={v}" for k, v in request_query_params_dict.items() if query_param_name in k]
return [
f"{k}={v}"
for k, v in request_query_params_dict.items()
if query_param_name in k
]

return {parameter: __get_query_params_with_brackets(parameter) for parameter in RestAPI.QUERY_PARAMETER_KEYS}
return {
parameter: __get_query_params_with_brackets(parameter)
for parameter in RestAPI.QUERY_PARAMETER_KEYS
}

def __override_swagger_doc(self):
def __generate_field_parameter(field_name: str) -> dict:
Expand All @@ -104,10 +111,15 @@ def __generate_field_parameter(field_name: str) -> dict:

openapi = self.app.openapi()
for resource, resource_url in self.registered_resources:
if Methods.GET.value not in resource.methods or is_detail_resource(resource):
if Methods.GET.value not in resource.methods or is_detail_resource(
resource
):
continue
openapi["paths"][resource_url]["get"]["parameters"] = [
*[__generate_field_parameter(field_name) for field_name in RestAPI.QUERY_PARAMETER_KEYS]
*[
__generate_field_parameter(field_name)
for field_name in RestAPI.QUERY_PARAMETER_KEYS
]
]

def endpoint_wrapper(self, resource: Resource, method: str):
Expand All @@ -131,7 +143,9 @@ def endpoint(request: Request, path_parameters, body):

def wrapper(
request: Request,
path_parameters: self.__get_path_parameters_model(resource, method) = Depends(),
path_parameters: self.__get_path_parameters_model(
resource, method
) = Depends(),
body: Optional[dict] = Body(default=None),
):
return endpoint(request, path_parameters, body)
Expand All @@ -140,7 +154,9 @@ def wrapper(

def wrapper(
request: Request,
path_parameters: self.__get_path_parameters_model(resource, method) = Depends(),
path_parameters: self.__get_path_parameters_model(
resource, method
) = Depends(),
):
return endpoint(request, path_parameters, None)

Expand Down
Loading