Skip to content

Commit 208a668

Browse files
superbobrycopybara-github
authored andcommitted
Forward *args/**kwargs from functools.partial to the wrapper
PiperOrigin-RevId: 821720894
1 parent 268c04b commit 208a668

4 files changed

Lines changed: 124 additions & 8 deletions

File tree

pytype/abstract/_function_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def call(
192192
self.func.__self__ # pytype: disable=attribute-error
193193
)
194194
args = args.simplify(node, self.ctx, match_signature=sig)
195+
del sig
195196
posargs = [u.AssignToNewVariable(node) for u in args.posargs]
196197
namedargs = {
197198
k: u.AssignToNewVariable(node) for k, u in args.namedargs.items()

pytype/overlays/functools_overlay.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ def new_slot(
5555
self.ctx,
5656
node,
5757
new.to_variable(node),
58-
function.Args((cls, *args), kwargs),
58+
function.Args(
59+
(cls, *args),
60+
kwargs,
61+
call_context.starargs,
62+
call_context.starstarargs,
63+
),
5964
fallback_to_unsolvable=False,
6065
)
6166
[specialized_obj] = specialized_obj.data
@@ -66,7 +71,7 @@ def new_slot(
6671
return node, obj.to_variable(node)
6772

6873
def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]:
69-
new = abstract.NativeFunction("__new__", self.new_slot, self.ctx)
74+
new = NativeFunction("__new__", self.new_slot, self.ctx)
7075
return node, new.to_variable(node)
7176

7277

@@ -76,6 +81,8 @@ def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial:
7681
obj.underlying = args[0]
7782
obj.args = args[1:]
7883
obj.kwargs = kwargs
84+
obj.starargs = call_context.starargs
85+
obj.starstarargs = call_context.starstarargs
7986
return obj
8087

8188

@@ -115,6 +122,21 @@ def call(
115122
) -> tuple[cfg.CFGNode, cfg.Variable]:
116123
# ``NativeFunction.call`` does not forward *args and **kwargs to the
117124
# underlying function, so we do it here to avoid changing core pytype APIs.
125+
#
126+
# The simplification below ensures that the *args/**kwargs cannot in fact
127+
# be split into individual arguments. This logic follow the implementation
128+
# in the base class.
129+
sig = None
130+
if isinstance(
131+
self.func.__self__, # pytype: disable=attribute-error
132+
abstract.CallableClass,
133+
):
134+
sig = function.Signature.from_callable(
135+
self.func.__self__ # pytype: disable=attribute-error
136+
)
137+
args = args.simplify(node, self.ctx, match_signature=sig)
138+
del sig
139+
118140
starargs = args.starargs
119141
starstarargs = args.starstarargs
120142
if starargs is not None:
@@ -131,6 +153,8 @@ class BoundPartial(abstract.Instance, mixin.HasSlots):
131153
underlying: cfg.Variable
132154
args: tuple[cfg.Variable, ...]
133155
kwargs: dict[str, cfg.Variable]
156+
starargs: cfg.Variable | None
157+
starstarargs: cfg.Variable | None
134158

135159
def __init__(self, ctx, cls, container=None):
136160
super().__init__(cls, ctx, container)
@@ -141,28 +165,72 @@ def __init__(self, ctx, cls, container=None):
141165

142166
def get_signatures(self) -> Sequence[function.Signature]:
143167
sigs = []
144-
args = function.Args(posargs=self.args, namedargs=self.kwargs)
168+
args = function.Args(
169+
self.args, self.kwargs, self.starargs, self.starstarargs
170+
)
145171
for data in self.underlying.data:
146172
for sig in function.get_signatures(data):
147173
# Use the partial arguments as defaults in the signature, making them
148174
# optional but overwritable.
149175
defaults = sig.defaults.copy()
176+
kwonly_params = [*sig.kwonly_params]
177+
bound_param_names = set()
178+
pos_only_count = sig.posonly_count
150179
for name, value, _ in sig.iter_args(args):
151-
if value is not None:
152-
defaults[name] = value
153-
sigs.append(sig._replace(defaults=defaults))
180+
if value is None:
181+
continue
182+
if sig.param_names.index(name) < sig.posonly_count:
183+
# The parameter is positional-only, meaning that it cannot be
184+
# overwritten via a keyword argument. Remove it.
185+
bound_param_names.add(name)
186+
sig.posonly_count -= 1
187+
continue
188+
if name not in sig.kwonly_params:
189+
# The parameter can be overwritten via a keyword argument. Note
190+
# that we still have to remove it from ``param_names`` to make
191+
# sure it cannot be bound by position.
192+
bound_param_names.add(name)
193+
kwonly_params.append(name)
194+
195+
defaults[name] = value
196+
197+
sigs.append(
198+
sig._replace(
199+
param_names=tuple(
200+
n for n in sig.param_names if n not in bound_param_names
201+
),
202+
posonly_count=pos_only_count,
203+
kwonly_params=tuple(kwonly_params),
204+
defaults=defaults,
205+
)
206+
)
154207
return sigs
155208

156209
def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
210+
if self.starargs and call_context.starargs:
211+
combined_starargs = self.ctx.convert.build_tuple(
212+
node, (self.starargs, call_context.starargs)
213+
)
214+
else:
215+
combined_starargs = call_context.starargs or self.starargs
216+
217+
if self.starstarargs and call_context.starstarargs:
218+
d = abstract.Dict(self.ctx)
219+
d.update(node, self.starstarargs.data[0])
220+
d.update(node, call_context.starstarargs.data[0])
221+
combined_starstarargs = d.to_variable(node)
222+
else:
223+
combined_starstarargs = call_context.starstarargs or self.starstarargs
224+
157225
return function.call_function(
158226
self.ctx,
159227
node,
160228
self.underlying,
161229
function.Args(
162230
(*self.args, *args),
163231
{**self.kwargs, **kwargs},
164-
call_context.starargs,
165-
call_context.starstarargs,
232+
combined_starargs,
233+
combined_starstarargs,
166234
),
167235
fallback_to_unsolvable=False,
168236
)

pytype/tests/test_attr2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,20 @@ class Foo:
230230
assert_type(foo.x, str)
231231
""")
232232

233+
def test_partial_with_positional_args_as_converter(self):
234+
self.Check("""
235+
import attr
236+
import functools
237+
def f(x: str, y: int) -> int:
238+
del x
239+
return y
240+
@attr.s
241+
class Foo:
242+
x = attr.ib(converter=functools.partial(f, "foo"))
243+
foo = Foo(x=0)
244+
assert_type(foo.x, int)
245+
""")
246+
233247
def test_partial_overloaded_as_converter(self):
234248
self.Check("""
235249
import attr

pytype/tests/test_functions1.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,39 @@ def f(a, b=None):
10791079
partial_f(0)
10801080
""")
10811081

1082+
def test_functools_partial_starstar(self):
1083+
self.Check("""
1084+
import functools
1085+
def f(*, a: str, b: int):
1086+
pass
1087+
1088+
def test(**kwargs):
1089+
partial_f = functools.partial(f, **kwargs)
1090+
partial_f()
1091+
""")
1092+
1093+
def test_functools_partial_called_with_starstar(self):
1094+
self.Check("""
1095+
import functools
1096+
def f(a: str, b: int, c: list):
1097+
pass
1098+
partial_f = functools.partial(f, "foo")
1099+
1100+
def test(**kwargs):
1101+
partial_f(42, **kwargs)
1102+
""")
1103+
1104+
def test_functools_starstar_everywhere(self):
1105+
self.Check("""
1106+
import functools
1107+
def f(*, a: str, b: int):
1108+
pass
1109+
1110+
def test(**kwargs):
1111+
partial_f = functools.partial(f, **kwargs)
1112+
partial_f(**kwargs)
1113+
""")
1114+
10821115
def test_functools_partial_overloaded(self):
10831116
self.Check("""
10841117
import functools

0 commit comments

Comments
 (0)