From 9121672c8d31e3a0696f6d209c7a6f03f7f6c626 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 30 Sep 2025 01:36:08 -0700 Subject: [PATCH] Added an overlay for `functools.partial` The goal of the overlay is to delay return type inference until the partial object is called, allowing pytype to use all available arguments instead of just the ones provided to `functools.partial`. The overlay is currently gated by a feature flag to avoid breaking existing (but ill-typed) code. PiperOrigin-RevId: 813149022 --- pytype/config.py | 5 + pytype/overlays/CMakeLists.txt | 2 + pytype/overlays/functools_overlay.py | 140 +++++++++++++++++++++++++++ pytype/tests/test_base.py | 1 + pytype/tests/test_functions1.py | 34 +++++++ 5 files changed, 182 insertions(+) diff --git a/pytype/config.py b/pytype/config.py index 71c3c6ddb..297412508 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -295,6 +295,11 @@ def add_options(o, arglist): _flag( "--use-fiddle-overlay", False, "Support the third-party fiddle library." ), + _flag( + "--use-functools-partial-overlay", + False, + "Enable precise checks when calling functools.partial objects.", + ), ] + _OPT_IN_FEATURES diff --git a/pytype/overlays/CMakeLists.txt b/pytype/overlays/CMakeLists.txt index b1bd1118b..259ec624b 100644 --- a/pytype/overlays/CMakeLists.txt +++ b/pytype/overlays/CMakeLists.txt @@ -161,6 +161,8 @@ py_library( DEPS .overlay .special_builtins + pytype.abstract.abstract + pytype.typegraph.cfg ) py_library( diff --git a/pytype/overlays/functools_overlay.py b/pytype/overlays/functools_overlay.py index eb2cc0430..891258336 100644 --- a/pytype/overlays/functools_overlay.py +++ b/pytype/overlays/functools_overlay.py @@ -1,7 +1,21 @@ """Overlay for functools.""" +from __future__ import annotations + +from collections.abc import Mapping, Sequence +import threading +from typing import Any, Self, TYPE_CHECKING + +from pytype.abstract import abstract +from pytype.abstract import function +from pytype.abstract import mixin from pytype.overlays import overlay from pytype.overlays import special_builtins +from pytype.typegraph import cfg + +if TYPE_CHECKING: + from pytype import context # pylint: disable=g-import-not-at-top + _MODULE_NAME = "functools" @@ -15,5 +29,131 @@ def __init__(self, ctx): "cached_property", special_builtins.Property.make_alias ), } + if ctx.options.use_functools_partial_overlay: + member_map["partial"] = Partial ast = ctx.loader.import_name(_MODULE_NAME) super().__init__(ctx, _MODULE_NAME, member_map, ast) + + +class Partial(abstract.PyTDClass, mixin.HasSlots): + """Implementation of functools.partial.""" + + def __init__(self, ctx: "context.Context", module: str): + pytd_cls = ctx.loader.lookup_pytd(module, "partial") + super().__init__("partial", pytd_cls, ctx) + mixin.HasSlots.init_mixin(self) + + self._pytd_new = self.pytd_cls.Lookup("__new__") + + def new_slot( + self, node, cls, *args, **kwargs + ) -> tuple[cfg.CFGNode, cfg.Variable]: + # Make sure the call is well typed before binding the partial + new = self.ctx.convert.convert_pytd_function(self._pytd_new) + _, specialized_obj = function.call_function( + self.ctx, + node, + new.to_variable(node), + function.Args((cls, *args), kwargs), + fallback_to_unsolvable=False, + ) + [specialized_obj] = specialized_obj.data + type_arg = specialized_obj.get_formal_type_parameter("_T") + [cls] = cls.data + cls = abstract.ParameterizedClass(cls, {"_T": type_arg}, self.ctx) + obj = bind_partial(node, cls, args, kwargs, self.ctx) + return node, obj.to_variable(node) + + def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]: + new = abstract.NativeFunction("__new__", self.new_slot, self.ctx) + return node, new.to_variable(node) + + +def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial: + del node # Unused. + obj = BoundPartial(ctx, cls) + obj.underlying = args[0] + obj.args = args[1:] + obj.kwargs = kwargs + return obj + + +class CallContext(threading.local): + """A thread-local context for ``NativeFunction.call``.""" + + starargs: cfg.Variable | None = None + starstarargs: cfg.Variable | None = None + + def forward( + self, starargs: cfg.Variable | None, starstarargs: cfg.Variable | None + ) -> Self: + self.starargs = starargs + self.starstarargs = starstarargs + return self + + def __enter__(self) -> Self: + return self + + def __exit__(self, *exc_info) -> None: + self.starargs = None + self.starstarargs = None + + +call_context = CallContext() + + +class NativeFunction(abstract.NativeFunction): + """A native function that forwards *args and **kwargs to the underlying function.""" + + def call( + self, + node: cfg.CFGNode, + func: cfg.Binding, + args: function.Args, + alias_map: Any | None = None, + ) -> tuple[cfg.CFGNode, cfg.Variable]: + # ``NativeFunction.call`` does not forward *args and **kwargs to the + # underlying function, so we do it here to avoid changing core pytype APIs. + starargs = args.starargs + starstarargs = args.starstarargs + if starargs is not None: + starargs = starargs.AssignToNewVariable(node) + if starstarargs is not None: + starstarargs = starstarargs.AssignToNewVariable(node) + with call_context.forward(starargs, starstarargs): + return super().call(node, func, args, alias_map) + + +class BoundPartial(abstract.Instance, mixin.HasSlots): + """An instance of functools.partial.""" + + underlying: cfg.Variable + args: Sequence[cfg.Variable] + kwargs: Mapping[str, cfg.Variable] + + def __init__(self, ctx, cls, container=None): + super().__init__(cls, ctx, container) + mixin.HasSlots.init_mixin(self) + self.set_slot( + "__call__", NativeFunction("__call__", self.call_slot, self.ctx) + ) + + @property + def func(self) -> cfg.Variable: + # The ``func`` attribute marks this class as a wrapper for + # ``maybe_unwrap_decorated_function``. + return self.underlying + + def call_slot(self, node: cfg.CFGNode, *args, **kwargs): + return function.call_function( + self.ctx, + node, + self.underlying, + function.Args( + (*self.args, *args), + {**self.kwargs, **kwargs}, + call_context.starargs, + call_context.starstarargs, + ), + fallback_to_unsolvable=False, + ) diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index 48a1a9a5c..29b48902a 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -93,6 +93,7 @@ def setUp(self): strict_primitive_comparisons=True, strict_none_binding=True, use_fiddle_overlay=True, + use_functools_partial_overlay=True, use_rewrite=_USE_REWRITE, validate_version=False, ) diff --git a/pytype/tests/test_functions1.py b/pytype/tests/test_functions1.py index 03c81e2c3..e2fa18a75 100644 --- a/pytype/tests/test_functions1.py +++ b/pytype/tests/test_functions1.py @@ -1079,6 +1079,40 @@ def f(a, b=None): partial_f(0) """) + def test_functools_partial_overloaded(self): + self.Check(""" + import functools + from typing import overload + @overload + def f(x: int) -> int: ... + @overload + def f(x: str) -> str: ... + def f(x): + return x + partial_f = functools.partial(f) + # TODO(slebedev): This should be functools.partial[int | str]. + assert_type(partial_f, functools.partial) + assert_type(partial_f(1), int) + assert_type(partial_f("s"), str) + """) + + def test_functools_partial_overloaded_with_star(self): + self.Check(""" + import functools + from typing import overload + @overload + def f(x: int, y: int) -> int: ... + @overload + def f(x: str, y: str) -> str: ... + def f(x, y): + return x + partial_f = functools.partial(f, 42) + def test(*args): + # TODO(slebedev): This should be functools.partial[int]. + assert_type(partial_f, functools.partial) + assert_type(partial_f(*args), int) + """) + def test_functools_partial_class(self): self.Check(""" import functools