diff --git a/pytype/config.py b/pytype/config.py index 71c3c6ddb..297412508 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -295,6 +295,11 @@ def add_options(o, arglist): _flag( "--use-fiddle-overlay", False, "Support the third-party fiddle library." ), + _flag( + "--use-functools-partial-overlay", + False, + "Enable precise checks when calling functools.partial objects.", + ), ] + _OPT_IN_FEATURES diff --git a/pytype/overlays/CMakeLists.txt b/pytype/overlays/CMakeLists.txt index b1bd1118b..259ec624b 100644 --- a/pytype/overlays/CMakeLists.txt +++ b/pytype/overlays/CMakeLists.txt @@ -161,6 +161,8 @@ py_library( DEPS .overlay .special_builtins + pytype.abstract.abstract + pytype.typegraph.cfg ) py_library( diff --git a/pytype/overlays/functools_overlay.py b/pytype/overlays/functools_overlay.py index eb2cc0430..891258336 100644 --- a/pytype/overlays/functools_overlay.py +++ b/pytype/overlays/functools_overlay.py @@ -1,7 +1,21 @@ """Overlay for functools.""" +from __future__ import annotations + +from collections.abc import Mapping, Sequence +import threading +from typing import Any, Self, TYPE_CHECKING + +from pytype.abstract import abstract +from pytype.abstract import function +from pytype.abstract import mixin from pytype.overlays import overlay from pytype.overlays import special_builtins +from pytype.typegraph import cfg + +if TYPE_CHECKING: + from pytype import context # pylint: disable=g-import-not-at-top + _MODULE_NAME = "functools" @@ -15,5 +29,131 @@ def __init__(self, ctx): "cached_property", special_builtins.Property.make_alias ), } + if ctx.options.use_functools_partial_overlay: + member_map["partial"] = Partial ast = ctx.loader.import_name(_MODULE_NAME) super().__init__(ctx, _MODULE_NAME, member_map, ast) + + +class Partial(abstract.PyTDClass, mixin.HasSlots): + """Implementation of functools.partial.""" + + def __init__(self, ctx: "context.Context", module: str): + pytd_cls = ctx.loader.lookup_pytd(module, "partial") + super().__init__("partial", pytd_cls, ctx) + mixin.HasSlots.init_mixin(self) + + self._pytd_new = self.pytd_cls.Lookup("__new__") + + def new_slot( + self, node, cls, *args, **kwargs + ) -> tuple[cfg.CFGNode, cfg.Variable]: + # Make sure the call is well typed before binding the partial + new = self.ctx.convert.convert_pytd_function(self._pytd_new) + _, specialized_obj = function.call_function( + self.ctx, + node, + new.to_variable(node), + function.Args((cls, *args), kwargs), + fallback_to_unsolvable=False, + ) + [specialized_obj] = specialized_obj.data + type_arg = specialized_obj.get_formal_type_parameter("_T") + [cls] = cls.data + cls = abstract.ParameterizedClass(cls, {"_T": type_arg}, self.ctx) + obj = bind_partial(node, cls, args, kwargs, self.ctx) + return node, obj.to_variable(node) + + def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]: + new = abstract.NativeFunction("__new__", self.new_slot, self.ctx) + return node, new.to_variable(node) + + +def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial: + del node # Unused. + obj = BoundPartial(ctx, cls) + obj.underlying = args[0] + obj.args = args[1:] + obj.kwargs = kwargs + return obj + + +class CallContext(threading.local): + """A thread-local context for ``NativeFunction.call``.""" + + starargs: cfg.Variable | None = None + starstarargs: cfg.Variable | None = None + + def forward( + self, starargs: cfg.Variable | None, starstarargs: cfg.Variable | None + ) -> Self: + self.starargs = starargs + self.starstarargs = starstarargs + return self + + def __enter__(self) -> Self: + return self + + def __exit__(self, *exc_info) -> None: + self.starargs = None + self.starstarargs = None + + +call_context = CallContext() + + +class NativeFunction(abstract.NativeFunction): + """A native function that forwards *args and **kwargs to the underlying function.""" + + def call( + self, + node: cfg.CFGNode, + func: cfg.Binding, + args: function.Args, + alias_map: Any | None = None, + ) -> tuple[cfg.CFGNode, cfg.Variable]: + # ``NativeFunction.call`` does not forward *args and **kwargs to the + # underlying function, so we do it here to avoid changing core pytype APIs. + starargs = args.starargs + starstarargs = args.starstarargs + if starargs is not None: + starargs = starargs.AssignToNewVariable(node) + if starstarargs is not None: + starstarargs = starstarargs.AssignToNewVariable(node) + with call_context.forward(starargs, starstarargs): + return super().call(node, func, args, alias_map) + + +class BoundPartial(abstract.Instance, mixin.HasSlots): + """An instance of functools.partial.""" + + underlying: cfg.Variable + args: Sequence[cfg.Variable] + kwargs: Mapping[str, cfg.Variable] + + def __init__(self, ctx, cls, container=None): + super().__init__(cls, ctx, container) + mixin.HasSlots.init_mixin(self) + self.set_slot( + "__call__", NativeFunction("__call__", self.call_slot, self.ctx) + ) + + @property + def func(self) -> cfg.Variable: + # The ``func`` attribute marks this class as a wrapper for + # ``maybe_unwrap_decorated_function``. + return self.underlying + + def call_slot(self, node: cfg.CFGNode, *args, **kwargs): + return function.call_function( + self.ctx, + node, + self.underlying, + function.Args( + (*self.args, *args), + {**self.kwargs, **kwargs}, + call_context.starargs, + call_context.starstarargs, + ), + fallback_to_unsolvable=False, + ) diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index 48a1a9a5c..29b48902a 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -93,6 +93,7 @@ def setUp(self): strict_primitive_comparisons=True, strict_none_binding=True, use_fiddle_overlay=True, + use_functools_partial_overlay=True, use_rewrite=_USE_REWRITE, validate_version=False, ) diff --git a/pytype/tests/test_functions1.py b/pytype/tests/test_functions1.py index 03c81e2c3..e2fa18a75 100644 --- a/pytype/tests/test_functions1.py +++ b/pytype/tests/test_functions1.py @@ -1079,6 +1079,40 @@ def f(a, b=None): partial_f(0) """) + def test_functools_partial_overloaded(self): + self.Check(""" + import functools + from typing import overload + @overload + def f(x: int) -> int: ... + @overload + def f(x: str) -> str: ... + def f(x): + return x + partial_f = functools.partial(f) + # TODO(slebedev): This should be functools.partial[int | str]. + assert_type(partial_f, functools.partial) + assert_type(partial_f(1), int) + assert_type(partial_f("s"), str) + """) + + def test_functools_partial_overloaded_with_star(self): + self.Check(""" + import functools + from typing import overload + @overload + def f(x: int, y: int) -> int: ... + @overload + def f(x: str, y: str) -> str: ... + def f(x, y): + return x + partial_f = functools.partial(f, 42) + def test(*args): + # TODO(slebedev): This should be functools.partial[int]. + assert_type(partial_f, functools.partial) + assert_type(partial_f(*args), int) + """) + def test_functools_partial_class(self): self.Check(""" import functools