Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions taskiq_dependencies/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,15 @@ def __eq__(self, rhs: object) -> bool:
if not isinstance(rhs, Dependency):
return False
return self._id == rhs._id

def __repr__(self) -> str:
func_name = str(self.dependency)
if self.dependency is not None and hasattr(self.dependency, "__name__"):
func_name = self.dependency.__name__
return (
f"Dependency({func_name}, "
f"use_cache={self.use_cache}, "
f"kwargs={self.kwargs}, "
f"parent={self.parent}"
")"
)
57 changes: 52 additions & 5 deletions taskiq_dependencies/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
import sys
import warnings
from collections import defaultdict, deque
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints

from graphlib import TopologicalSorter
Expand Down Expand Up @@ -171,19 +173,64 @@ def _build_graph(self) -> None: # noqa: C901
if inspect.isclass(origin):
# If this is a class, we need to get signature of
# an __init__ method.
hints = get_type_hints(origin.__init__)
try:
hints = get_type_hints(origin.__init__)
except NameError:
_, src_lineno = inspect.getsourcelines(origin)
src_file = Path(inspect.getfile(origin)).relative_to(
Path.cwd(),
)
warnings.warn(
"Cannot resolve type hints for "
f"a class {origin.__name__} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(
origin.__init__,
**signature_kwargs,
)
elif inspect.isfunction(dep.dependency):
# If this is function or an instance of a class, we get it's type hints.
hints = get_type_hints(dep.dependency)
try:
hints = get_type_hints(dep.dependency)
except NameError:
_, src_lineno = inspect.getsourcelines(dep.dependency) # type: ignore
src_file = Path(inspect.getfile(dep.dependency)).relative_to(
Path.cwd(),
)
warnings.warn(
"Cannot resolve type hints for "
f"a function {dep.dependency.__name__} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(origin, **signature_kwargs) # type: ignore
else:
hints = get_type_hints(
dep.dependency.__call__, # type: ignore
)
try:
hints = get_type_hints(
dep.dependency.__call__, # type: ignore
)
except NameError:
_, src_lineno = inspect.getsourcelines(dep.dependency.__class__)
src_file = Path(
inspect.getfile(dep.dependency.__class__),
).relative_to(
Path.cwd(),
)
cls_name = dep.dependency.__class__.__name__
warnings.warn(
"Cannot resolve type hints for "
f"an object of class {cls_name} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(origin, **signature_kwargs) # type: ignore

# Now we need to iterate over parameters, to
Expand Down
63 changes: 62 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import re
import uuid
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator, Generic, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Tuple,
TypeVar,
)

import pytest

Expand Down Expand Up @@ -891,3 +899,56 @@ def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
assert info.name == ""
assert info.definition is None
assert info.graph == graph


def test_skip_type_checking_function() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

def target(unknown: "A") -> None:
pass

with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*function target.*"):
graph = DependencyGraph(target=target)
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()


def test_skip_type_checking_class() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

class Target:
def __init__(self, unknown: "A") -> None:
pass

with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*class Target.*"):
graph = DependencyGraph(target=Target)
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()


def test_skip_type_checking_object() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

class Target:
def __call__(self, unknown: "A") -> None:
pass

with pytest.warns(
RuntimeWarning,
match=r"Cannot resolve.*object of class Target.*",
):
graph = DependencyGraph(target=Target())
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()
Loading