Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,12 @@ def _match_heterogeneous_tuple_instance(
# accidentally violate _satisfies_common_superclass.
new_substs = []
for instance_param in instance.pyval:
if copy_params_directly and instance_param.bindings:
if abstract_utils.is_var_splat(instance_param):
instance_param = abstract_utils.unwrap_splat(instance_param)
new_subst = self._match_all_bindings(
instance_param, class_param, subst, view
)
elif copy_params_directly and instance_param.bindings:
new_subst = {
class_param.full_name: view[instance_param].AssignToNewVariable(
self._node
Expand All @@ -1528,7 +1533,8 @@ def _match_heterogeneous_tuple_instance(
)
if new_subst is None:
return None
new_substs.append(new_subst)
if new_subst is not None:
new_substs.append(new_subst)
if new_substs:
subst = self._merge_substs(subst, new_substs)
if not instance.pyval:
Expand Down
5 changes: 4 additions & 1 deletion pytype/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,12 @@ def _value_to_parameter_types(self, node, v, instance, template, seen, view):
type_arguments = []
for t in template:
if isinstance(instance, abstract.Tuple):
elem_var = instance.pyval[t]
if abstract_utils.is_var_splat(elem_var):
elem_var = abstract_utils.unwrap_splat(elem_var)
param_values = {
val: view
for val in self._get_values(node, instance.pyval[t], view)
for val in self._get_values(node, elem_var, view)
}
elif instance.has_instance_type_parameter(t):
param_values = {
Expand Down
31 changes: 31 additions & 0 deletions pytype/tests/test_functions2.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,37 @@ def test_unpack_str(self):
""",
)

def test_unpack_tuple(self):
# The **kwargs unpacking in the wrapper seems to prevent pytype from
# eagerly expanding the splat in the tuple literal.
ty = self.Infer("""
def f(*, xs: tuple[int, ...], **kwargs: object):
def wrapper():
out = f(
xs=(42, *kwargs.pop("xs", ())),
**kwargs,
)()
return wrapper
""")
self.assertTypesMatchPytd(
ty,
"""
from typing import Any, Callable
def f(*, xs: tuple[int, ...], **kwargs: object) -> Callable[[], Any]: ...
""",
)

def test_unpack_tuple_invalid(self):
self.InferWithErrors("""
def f(*, xs: tuple[int, ...], **kwargs: object):
def wrapper():
out = f( # wrong-arg-types
xs=(object(), *kwargs.pop("xs", ())),
**kwargs,
)()
return wrapper
""")

def test_unpack_nonliteral(self):
ty = self.Infer("""
def f(x, **kwargs):
Expand Down
Loading