Skip to content

Commit 88eb9bc

Browse files
superbobrycopybara-github
authored andcommitted
Added an overlay for functools.partial
The goal of the overlay is to delay return type inference until the partial object is called, allowing pytype to use all available arguments instead of just the ones provided to `functools.partial`. PiperOrigin-RevId: 748656406
1 parent 762b86d commit 88eb9bc

3 files changed

Lines changed: 120 additions & 0 deletions

File tree

pytype/overlays/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ py_library(
161161
DEPS
162162
.overlay
163163
.special_builtins
164+
pytype.abstract.abstract
165+
pytype.typegraph.cfg
164166
)
165167

166168
py_library(

pytype/overlays/functools_overlay.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
"""Overlay for functools."""
22

3+
from __future__ import annotations
4+
5+
from collections.abc import Mapping, Sequence
6+
from typing import TYPE_CHECKING
7+
8+
from pytype.abstract import abstract
9+
from pytype.abstract import function
10+
from pytype.abstract import mixin
311
from pytype.overlays import overlay
412
from pytype.overlays import special_builtins
13+
from pytype.typegraph import cfg
14+
15+
if TYPE_CHECKING:
16+
from pytype import context # pylint: disable=g-import-not-at-top
17+
518

619
_MODULE_NAME = "functools"
720

@@ -14,6 +27,80 @@ def __init__(self, ctx):
1427
"cached_property": overlay.add_name(
1528
"cached_property", special_builtins.Property.make_alias
1629
),
30+
"partial": Partial,
1731
}
1832
ast = ctx.loader.import_name(_MODULE_NAME)
1933
super().__init__(ctx, _MODULE_NAME, member_map, ast)
34+
35+
36+
class Partial(abstract.PyTDClass, mixin.HasSlots):
37+
"""Implementation of functools.partial."""
38+
39+
def __init__(self, ctx: "context.Context", module: str):
40+
pytd_cls = ctx.loader.lookup_pytd(module, "partial")
41+
super().__init__("partial", pytd_cls, ctx)
42+
mixin.HasSlots.init_mixin(self)
43+
44+
self._pytd_new = self.pytd_cls.Lookup("__new__")
45+
46+
def new_slot(
47+
self, node, cls, *args, **kwargs
48+
) -> tuple[cfg.CFGNode, cfg.Variable]:
49+
# Make sure the call is well typed before binding the partial
50+
new = self.ctx.convert.convert_pytd_function(self._pytd_new)
51+
_, specialized_obj = function.call_function(
52+
self.ctx,
53+
node,
54+
new.to_variable(node),
55+
function.Args(posargs=(cls, *args), namedargs=kwargs),
56+
fallback_to_unsolvable=False,
57+
)
58+
[specialized_obj] = specialized_obj.data
59+
type_arg = specialized_obj.get_formal_type_parameter("_T")
60+
[cls] = cls.data
61+
cls = abstract.ParameterizedClass(cls, {"_T": type_arg}, self.ctx)
62+
obj = bind_partial(node, cls, args, kwargs, self.ctx)
63+
return node, obj.to_variable(node)
64+
65+
def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]:
66+
new = abstract.NativeFunction("__new__", self.new_slot, self.ctx)
67+
return node, new.to_variable(node)
68+
69+
70+
def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial:
71+
del node # Unused.
72+
obj = BoundPartial(ctx, cls)
73+
obj.underlying = args[0]
74+
obj.args = args[1:]
75+
obj.kwargs = kwargs
76+
return obj
77+
78+
79+
class BoundPartial(abstract.Instance, mixin.HasSlots):
80+
"""An instance of functools.partial."""
81+
82+
underlying: cfg.Variable
83+
args: Sequence[cfg.Variable]
84+
kwargs: Mapping[str, cfg.Variable]
85+
86+
def __init__(self, ctx, cls, container=None):
87+
super().__init__(cls, ctx, container)
88+
mixin.HasSlots.init_mixin(self)
89+
self.set_native_slot("__call__", self.call_slot)
90+
91+
@property
92+
def func(self) -> cfg.Variable:
93+
# The ``func`` attribute marks this class as a wrapper for
94+
# ``maybe_unwrap_decorated_function``.
95+
return self.underlying
96+
97+
def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
98+
return function.call_function(
99+
self.ctx,
100+
node,
101+
self.underlying,
102+
function.Args(
103+
posargs=(*self.args, *args), namedargs={**self.kwargs, **kwargs}
104+
),
105+
fallback_to_unsolvable=False,
106+
)

pytype/tests/test_functions1.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,37 @@ def f(a, b=None):
10791079
partial_f(0)
10801080
""")
10811081

1082+
def test_functools_partial_overloaded1(self):
1083+
self.Check("""
1084+
import functools
1085+
from typing import overload
1086+
@overload
1087+
def f(x: int) -> int: ...
1088+
@overload
1089+
def f(x: str) -> str: ...
1090+
def f(x):
1091+
return x
1092+
partial_f = functools.partial(f)
1093+
# This should be functools.partial[int | str], but pytype does not
1094+
# seem to infer that.
1095+
assert_type(partial_f, functools.partial)
1096+
assert_type(partial_f(1), int)
1097+
assert_type(partial_f("s"), str)
1098+
""")
1099+
1100+
def test_functools_partial_overloaded2(self):
1101+
self.Check("""
1102+
import functools
1103+
from typing import overload
1104+
@overload
1105+
def f(x: int, y: int) -> int: ...
1106+
@overload
1107+
def f(x: str, y: str) -> str: ...
1108+
def f(x, y):
1109+
return x
1110+
partial_f = functools.partial(f, 42)
1111+
""")
1112+
10821113
def test_functools_partial_class(self):
10831114
self.Check("""
10841115
import functools

0 commit comments

Comments
 (0)