Skip to content

Commit d7255e3

Browse files
committed
Refactor SDL generation and type handling
- Moved core type extraction and input type checking to a new `schema_helpers` module for better organization and reusability. - Removed the `TypeService` class, consolidating its functionality into the `TypeConverter` class. - Updated SDLGenerator to utilize the new helper functions for core type extraction and input type checking. - Added comprehensive tests for the new helper functions and updated existing tests to ensure compatibility with the refactored code. - Improved type handling in the `type_utils` module and added tests for various type utilities.
1 parent bcddf04 commit d7255e3

8 files changed

Lines changed: 946 additions & 377 deletions

File tree

src/sqlmodel_graphql/introspection.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlmodel import SQLModel
1111

1212
from sqlmodel_graphql.type_converter import TypeConverter
13+
from sqlmodel_graphql.utils.schema_helpers import get_core_types, is_input_type
1314

1415
if TYPE_CHECKING:
1516
pass
@@ -447,44 +448,9 @@ def _collect_enum_types(self) -> dict[str, type[Enum]]:
447448

448449
def _collect_input_types(self) -> dict[str, type]:
449450
"""Collect all Input types from query and mutation parameters."""
450-
import types as types_module
451-
from typing import Union
452-
453-
from pydantic import BaseModel
454-
455451
input_types: dict[str, type] = {}
456452
visited: set[str] = set()
457453

458-
def get_core_types(python_type: Any) -> list[type]:
459-
"""Extract core types from a type hint."""
460-
origin = get_origin(python_type)
461-
if origin is Union or origin is types_module.UnionType:
462-
args = get_args(python_type)
463-
result = []
464-
for arg in args:
465-
if arg is not type(None):
466-
result.extend(get_core_types(arg))
467-
return result
468-
if origin is list:
469-
args = get_args(python_type)
470-
if args:
471-
return get_core_types(args[0])
472-
return []
473-
if isinstance(python_type, type):
474-
return [python_type]
475-
return []
476-
477-
def is_input_type(python_type: type) -> bool:
478-
"""Check if a type should be treated as a GraphQL Input type."""
479-
if not isinstance(python_type, type):
480-
return False
481-
try:
482-
if issubclass(python_type, SQLModel) or issubclass(python_type, BaseModel):
483-
return True
484-
except TypeError:
485-
pass
486-
return False
487-
488454
def collect_from_type(param_type: Any) -> None:
489455
"""Recursively collect Input types from a type hint."""
490456
core_types = get_core_types(param_type)

src/sqlmodel_graphql/sdl_generator.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -11,58 +11,12 @@
1111

1212
from sqlmodel_graphql.type_converter import TypeConverter
1313
from sqlmodel_graphql.utils.naming import to_graphql_field_name
14+
from sqlmodel_graphql.utils.schema_helpers import get_core_types, is_input_type
1415

1516
if TYPE_CHECKING:
1617
pass
1718

1819

19-
def _get_core_types(python_type: Any) -> list[type]:
20-
"""Extract core types from a type hint, unwrapping Optional, Union, list, etc."""
21-
import types
22-
from typing import Union
23-
24-
origin = get_origin(python_type)
25-
26-
# Handle Union (including Optional)
27-
if origin is Union or origin is types.UnionType:
28-
args = get_args(python_type)
29-
result = []
30-
for arg in args:
31-
if arg is not type(None):
32-
result.extend(_get_core_types(arg))
33-
return result
34-
35-
# Handle list
36-
if origin is list:
37-
args = get_args(python_type)
38-
if args:
39-
return _get_core_types(args[0])
40-
return []
41-
42-
# Base type
43-
if isinstance(python_type, type):
44-
return [python_type]
45-
46-
return []
47-
48-
49-
def _is_input_type(python_type: type) -> bool:
50-
"""Check if a type should be treated as a GraphQL Input type.
51-
52-
Input types are SQLModel or BaseModel subclasses that are NOT in the entity list
53-
(i.e., they are used as mutation parameters, not as entity types).
54-
"""
55-
if not isinstance(python_type, type):
56-
return False
57-
# Check if it's a SQLModel or Pydantic BaseModel
58-
try:
59-
if issubclass(python_type, SQLModel) or issubclass(python_type, BaseModel):
60-
return True
61-
except TypeError:
62-
pass
63-
return False
64-
65-
6620
def _python_type_to_graphql(
6721
python_type: Any, converter: TypeConverter, entity_names: set[str] | None = None
6822
) -> str:
@@ -100,7 +54,7 @@ def _python_type_to_graphql_inner(
10054
return f"{entity_name}{'!' if not nullable else ''}"
10155

10256
# Check if it's an Input type (SQLModel or BaseModel not in entities)
103-
if entity_names is not None and _is_input_type(python_type) and python_type.__name__ not in entity_names:
57+
if entity_names is not None and is_input_type(python_type) and python_type.__name__ not in entity_names:
10458
return f"{python_type.__name__}{'!' if not nullable else ''}"
10559

10660
# Handle basic Python types
@@ -194,10 +148,10 @@ def _collect_input_types(self) -> set[type]:
194148

195149
def collect_from_type(param_type: Any) -> None:
196150
"""Recursively collect Input types from a type hint."""
197-
core_types = _get_core_types(param_type)
151+
core_types = get_core_types(param_type)
198152

199153
for core_type in core_types:
200-
if _is_input_type(core_type) and core_type.__name__ not in self._entity_names:
154+
if is_input_type(core_type) and core_type.__name__ not in self._entity_names:
201155
type_name = core_type.__name__
202156
if type_name not in visited:
203157
visited.add(type_name)
@@ -297,7 +251,7 @@ def _input_type_to_graphql(self, python_type: Any, field_info: Any = None, is_op
297251
return python_type.__name__
298252

299253
# Check if it's another Input type (SQLModel or BaseModel not in entities)
300-
if _is_input_type(python_type) and python_type.__name__ not in self._entity_names:
254+
if is_input_type(python_type) and python_type.__name__ not in self._entity_names:
301255
return f"{python_type.__name__}" if is_optional else f"{python_type.__name__}!"
302256

303257
# Check if it's an entity type

src/sqlmodel_graphql/type_service.py

Lines changed: 0 additions & 250 deletions
This file was deleted.

0 commit comments

Comments
 (0)