Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pytype/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ def add_options(o, arglist):
_flag(
"--use-fiddle-overlay", False, "Support the third-party fiddle library."
),
_flag(
"--use-functools-partial-overlay",
False,
"Enable precise checks when calling functools.partial objects.",
),
] + _OPT_IN_FEATURES


Expand Down
2 changes: 2 additions & 0 deletions pytype/overlays/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ py_library(
DEPS
.overlay
.special_builtins
pytype.abstract.abstract
pytype.typegraph.cfg
)

py_library(
Expand Down
140 changes: 140 additions & 0 deletions pytype/overlays/functools_overlay.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
"""Overlay for functools."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
import threading
from typing import Any, Self, TYPE_CHECKING

from pytype.abstract import abstract
from pytype.abstract import function
from pytype.abstract import mixin
from pytype.overlays import overlay
from pytype.overlays import special_builtins
from pytype.typegraph import cfg

if TYPE_CHECKING:
from pytype import context # pylint: disable=g-import-not-at-top


_MODULE_NAME = "functools"

Expand All @@ -15,5 +29,131 @@ def __init__(self, ctx):
"cached_property", special_builtins.Property.make_alias
),
}
if ctx.options.use_functools_partial_overlay:
member_map["partial"] = Partial
ast = ctx.loader.import_name(_MODULE_NAME)
super().__init__(ctx, _MODULE_NAME, member_map, ast)


class Partial(abstract.PyTDClass, mixin.HasSlots):
"""Implementation of functools.partial."""

def __init__(self, ctx: "context.Context", module: str):
pytd_cls = ctx.loader.lookup_pytd(module, "partial")
super().__init__("partial", pytd_cls, ctx)
mixin.HasSlots.init_mixin(self)

self._pytd_new = self.pytd_cls.Lookup("__new__")

def new_slot(
self, node, cls, *args, **kwargs
) -> tuple[cfg.CFGNode, cfg.Variable]:
# Make sure the call is well typed before binding the partial
new = self.ctx.convert.convert_pytd_function(self._pytd_new)
_, specialized_obj = function.call_function(
self.ctx,
node,
new.to_variable(node),
function.Args((cls, *args), kwargs),
fallback_to_unsolvable=False,
)
[specialized_obj] = specialized_obj.data
type_arg = specialized_obj.get_formal_type_parameter("_T")
[cls] = cls.data
cls = abstract.ParameterizedClass(cls, {"_T": type_arg}, self.ctx)
obj = bind_partial(node, cls, args, kwargs, self.ctx)
return node, obj.to_variable(node)

def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]:
new = abstract.NativeFunction("__new__", self.new_slot, self.ctx)
return node, new.to_variable(node)


def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial:
del node # Unused.
obj = BoundPartial(ctx, cls)
obj.underlying = args[0]
obj.args = args[1:]
obj.kwargs = kwargs
return obj


class CallContext(threading.local):
"""A thread-local context for ``NativeFunction.call``."""

starargs: cfg.Variable | None = None
starstarargs: cfg.Variable | None = None

def forward(
self, starargs: cfg.Variable | None, starstarargs: cfg.Variable | None
) -> Self:
self.starargs = starargs
self.starstarargs = starstarargs
return self

def __enter__(self) -> Self:
return self

def __exit__(self, *exc_info) -> None:
self.starargs = None
self.starstarargs = None


call_context = CallContext()


class NativeFunction(abstract.NativeFunction):
"""A native function that forwards *args and **kwargs to the underlying function."""

def call(
self,
node: cfg.CFGNode,
func: cfg.Binding,
args: function.Args,
alias_map: Any | None = None,
) -> tuple[cfg.CFGNode, cfg.Variable]:
# ``NativeFunction.call`` does not forward *args and **kwargs to the
# underlying function, so we do it here to avoid changing core pytype APIs.
starargs = args.starargs
starstarargs = args.starstarargs
if starargs is not None:
starargs = starargs.AssignToNewVariable(node)
if starstarargs is not None:
starstarargs = starstarargs.AssignToNewVariable(node)
with call_context.forward(starargs, starstarargs):
return super().call(node, func, args, alias_map)


class BoundPartial(abstract.Instance, mixin.HasSlots):
"""An instance of functools.partial."""

underlying: cfg.Variable
args: Sequence[cfg.Variable]
kwargs: Mapping[str, cfg.Variable]

def __init__(self, ctx, cls, container=None):
super().__init__(cls, ctx, container)
mixin.HasSlots.init_mixin(self)
self.set_slot(
"__call__", NativeFunction("__call__", self.call_slot, self.ctx)
)

@property
def func(self) -> cfg.Variable:
# The ``func`` attribute marks this class as a wrapper for
# ``maybe_unwrap_decorated_function``.
return self.underlying

def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
return function.call_function(
self.ctx,
node,
self.underlying,
function.Args(
(*self.args, *args),
{**self.kwargs, **kwargs},
call_context.starargs,
call_context.starstarargs,
),
fallback_to_unsolvable=False,
)
1 change: 1 addition & 0 deletions pytype/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def setUp(self):
strict_primitive_comparisons=True,
strict_none_binding=True,
use_fiddle_overlay=True,
use_functools_partial_overlay=True,
use_rewrite=_USE_REWRITE,
validate_version=False,
)
Expand Down
34 changes: 34 additions & 0 deletions pytype/tests/test_functions1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,40 @@ def f(a, b=None):
partial_f(0)
""")

def test_functools_partial_overloaded(self):
self.Check("""
import functools
from typing import overload
@overload
def f(x: int) -> int: ...
@overload
def f(x: str) -> str: ...
def f(x):
return x
partial_f = functools.partial(f)
# TODO(slebedev): This should be functools.partial[int | str].
assert_type(partial_f, functools.partial)
assert_type(partial_f(1), int)
assert_type(partial_f("s"), str)
""")

def test_functools_partial_overloaded_with_star(self):
self.Check("""
import functools
from typing import overload
@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: str, y: str) -> str: ...
def f(x, y):
return x
partial_f = functools.partial(f, 42)
def test(*args):
# TODO(slebedev): This should be functools.partial[int].
assert_type(partial_f, functools.partial)
assert_type(partial_f(*args), int)
""")

def test_functools_partial_class(self):
self.Check("""
import functools
Expand Down
Loading