diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md index ec071bb6879a3..4d9247f72f919 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -218,11 +218,10 @@ def mutable_global_rhs(x: str | None, unavailable: set[str | None]) -> None: reveal_type(x) # revealed: str | None ``` -## No narrowing for the right-hand side (currently) +## No present-key narrowing without a `TypedDict` -No narrowing is done for the right-hand side currently, even if the right-hand side is a valid -"target" (name/attribute/subscript) that could potentially be narrowed. We may change this in the -future: +We only synthesize a key-access protocol for string membership tests on right-hand-side values that +include a `TypedDict`. Other membership tests can mean substring or element containment instead: ```py from typing import Literal diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index b33d786a00406..d36de9046d3b3 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -4509,10 +4509,10 @@ def _(u: Foo | Bar | NotADict): reveal_type(u) # revealed: Bar | NotADict ``` -It would be nice if we could also narrow `TypedDict` unions by checking whether a key (which only -shows up in a subset of the union members) is present, but that isn't generally correct, because -"extra items" are allowed by default. For example, even though `Bar` here doesn't define a `"foo"` -field, it could be _assigned to_ with another `TypedDict` that does: +We can also narrow `TypedDict` unions by checking whether a key (which only shows up in a subset of +the union members) is present. 
We can't filter the union down to just the `TypedDict`s that declare +the key, because "extra items" are allowed by default. For example, even though `Bar` here doesn't +define a `"foo"` field, it could be _assigned to_ with another `TypedDict` that does: ```py from typing_extensions import Literal @@ -4525,14 +4525,16 @@ class Bar(TypedDict): def disappointment(u: Foo | Bar, v: Literal["foo"]): if "foo" in u: - # We can't narrow the union here... - reveal_type(u) # revealed: Foo | Bar + # We don't narrow to just `Foo` here... + reveal_type(u) # revealed: Foo | (Bar & <Protocol with members '__getitem__'>) + reveal_type(u["foo"]) # revealed: object else: # ...(even though we *can* narrow it here)... reveal_type(u) # revealed: Bar if v in u: - reveal_type(u) # revealed: Foo | Bar + reveal_type(u) # revealed: Foo | (Bar & <Protocol with members '__getitem__'>) + reveal_type(u["foo"]) # revealed: object else: reveal_type(u) # revealed: Bar @@ -4543,6 +4545,42 @@ class FooBar(TypedDict): static_assert(is_assignable_to(FooBar, Foo)) static_assert(is_assignable_to(FooBar, Bar)) + +def dictionary_union(u: Foo | dict[Literal["a", "b"], int]): + if "c" in u: + # TODO: This should stop erroring if we prove that the `dict` arm cannot contain `"c"`. + # error: [invalid-argument-type] + reveal_type(u["c"]) # revealed: object + +def literal_union(u: Foo | Literal["abc"]): + if "a" in u: + # revealed: (Foo & <Protocol with members '__getitem__'>) | (Literal["abc"] & <Protocol with members '__contains__'>) + reveal_type(u) + +def literal_union_key_access(obj: Foo | Literal["a"]): + if "a" in obj: + # Membership in a string does not imply that the string supports subscripting with that key.
+ # error: [invalid-argument-type] + reveal_type(obj["a"]) # revealed: object +``` + +This still accepts guarded key access in the branch, without pretending that an open `TypedDict` +must be one of the union members that explicitly declares the key: + +```py +from typing import TypedDict + +class FileWithBytes(TypedDict): + bytes: bytes + +class FileWithUri(TypedDict): + uri: str + +def get_bytes(file_content: FileWithBytes | FileWithUri) -> object: + if "bytes" in file_content: + reveal_type(file_content["bytes"]) # revealed: object + return file_content["bytes"] + raise ValueError ``` `not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow @@ -4565,7 +4603,9 @@ def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Lit if "bar" not in u: reveal_type(u) # revealed: Foo else: - reveal_type(u) # revealed: Foo | (Bar & Any) + # TODO: This should simplify to `Foo | (Bar & Any)`, since `Foo` is a + # subtype of the synthesized protocol. 
+ reveal_type(u) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any) if "bar" not in v: reveal_type(v) # revealed: Never @@ -4575,12 +4615,12 @@ def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Lit if w not in u: reveal_type(u) # revealed: Foo else: - reveal_type(u) # revealed: Foo | (Bar & Any) + reveal_type(u) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any) if "bar" not in (u2 := u): reveal_type(u2) # revealed: Foo else: - reveal_type(u2) # revealed: Foo | (Bar & Any) + reveal_type(u2) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any) ``` With `closed=True`, the narrowing that we couldn't do above becomes possible, because a [closed] @@ -4598,13 +4638,13 @@ class ClosedBar(TypedDict, closed=True): def _(u: ClosedFoo | ClosedBar, v: Literal["foo"]): if "foo" in u: # TODO: should be `ClosedFoo` - reveal_type(u) # revealed: ClosedFoo | ClosedBar + reveal_type(u) # revealed: ClosedFoo | (ClosedBar & <Protocol with members '__getitem__'>) else: reveal_type(u) # revealed: ClosedBar if v in u: # TODO: should be `ClosedFoo` - reveal_type(u) # revealed: ClosedFoo | ClosedBar + reveal_type(u) # revealed: ClosedFoo | (ClosedBar & <Protocol with members '__getitem__'>) else: reveal_type(u) # revealed: ClosedBar ``` @@ -4625,7 +4665,7 @@ def _( reveal_type(u) # revealed: ClosedFoo else: # TODO: should be `ClosedBar & Any` - reveal_type(u) # revealed: ClosedFoo | (ClosedBar & Any) + reveal_type(u) # revealed: (ClosedFoo & <Protocol with members '__getitem__'>) | (ClosedBar & Any) if "bar" not in v: reveal_type(v) # revealed: Never @@ -4636,7 +4676,7 @@ def _( reveal_type(u) # revealed: ClosedFoo else: # TODO: should be `ClosedBar & Any` - reveal_type(u) # revealed: ClosedFoo | (ClosedBar & Any) + reveal_type(u) # revealed: (ClosedFoo & <Protocol with members '__getitem__'>) | (ClosedBar & Any) ``` ## Narrowing tagged unions of `TypedDict`s with `match` statements @@ -4792,7 +4832,7 @@ def test_in(x: ThingWithBaz): if "baz" not in x: reveal_type(x) # revealed: Foo else: - reveal_type(x) # revealed: Foo | Baz + reveal_type(x) # revealed: (Foo & <Protocol with members '__getitem__'>) | Baz ``` Nested PEP 695 type aliases (an alias referring to another alias) also work: @@
-4821,7 +4861,7 @@ def test_nested_in(x: OuterWithBaz): if "baz" not in x: reveal_type(x) # revealed: Foo else: - reveal_type(x) # revealed: Foo | Baz + reveal_type(x) # revealed: (Foo & <Protocol with members '__getitem__'>) | Baz ``` ## Only annotated declarations are allowed in the class body diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 8ff4549aa4748..6d11f79bb8fa5 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -10,8 +10,9 @@ use crate::types::typed_dict::{ }; use crate::types::{ CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass, - KnownInstanceType, LiteralValueTypeKind, SpecialFormType, SubclassOfInner, SubclassOfType, - Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, + KnownInstanceType, LiteralValueTypeKind, Parameter, Parameters, Signature, SpecialFormType, + SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, + UnionBuilder, infer_expression_types, }; use ty_python_core::expression::Expression; use ty_python_core::place::{PlaceExpr, PlaceTable, ScopedPlaceId}; @@ -1395,8 +1396,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - // Narrow unions and intersections of `TypedDict` in cases where required keys are - // excluded: + // Narrow types when a key membership test proves that a key is present, and narrow unions + // and intersections of `TypedDict` when a key membership test proves that a required key is + // absent: // // class Foo(TypedDict): // foo: int @@ -1412,11 +1414,38 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { && let rhs_type = inference.expression_type(&comparators[0]) && is_or_contains_typeddict(self.db, rhs_type) { + let key = key.value(self.db); + let apply_constraint = + |constraints: &mut NarrowingConstraints<'db>, + constraint: NarrowingConstraint<'db>| { + let comparator_place =
PlaceExpr::try_from_expr(&comparators[0]) + .and_then(|place_expr| self.places().place_id(&place_expr)); + if let Some(place) = comparator_place { + constraints.insert(place, constraint.clone()); + } + + let value_place = PlaceExpr::try_from_expr(rhs_expr) + .and_then(|place_expr| self.places().place_id(&place_expr)); + if value_place != comparator_place + && let Some(place) = value_place + { + constraints.insert(place, constraint); + } + }; + + let is_positive_key_membership = is_positive == (ops[0] == ast::CmpOp::In); + if is_positive_key_membership { + let narrowed = self.narrow_with_present_key(rhs_type, key); + if narrowed != rhs_type.resolve_type_alias(self.db) { + apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed)); + } + } + let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn); if is_negative_check { let requires_key = |td: TypedDictType<'db>| -> bool { td.items(self.db) - .get(key.value(self.db)) + .get(key) .is_some_and(TypedDictField::is_required) }; @@ -1460,21 +1489,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; if narrowed != resolved_rhs_type { - let constraint = NarrowingConstraint::replacement(narrowed); - - let comparator_place = PlaceExpr::try_from_expr(&comparators[0]) - .and_then(|place_expr| self.places().place_id(&place_expr)); - if let Some(place) = comparator_place { - constraints.insert(place, constraint.clone()); - } - - let value_place = PlaceExpr::try_from_expr(rhs_expr) - .and_then(|place_expr| self.places().place_id(&place_expr)); - if value_place != comparator_place - && let Some(place) = value_place - { - constraints.insert(place, constraint); - } + apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed)); } } } @@ -2031,6 +2046,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { Some((place, NarrowingConstraint::intersection(intersection))) } + fn narrow_with_present_key(&self, ty: Type<'db>, key: &str) -> Type<'db> { + let db = self.db; + 
let constrain = |ty, key_presence_constraint| { + IntersectionType::from_two_elements(db, ty, key_presence_constraint) + }; + + match ty.resolve_type_alias(self.db) { + Type::Union(union) => UnionType::from_elements( + self.db, + union + .elements(self.db) + .iter() + .map(|element| self.narrow_with_present_key(*element, key)), + ), + resolved if is_or_contains_typeddict(self.db, resolved) => { + constrain(ty, typeddict_key_getitem_protocol(self.db, key)) + } + _ => constrain(ty, key_membership_contains_protocol(self.db, key)), + } + } + /// Narrow tagged unions of tuples with `Literal` elements. /// /// Given a subscript expression like `t[0]` where `t` is a union of tuple types, and a @@ -2154,6 +2190,76 @@ fn is_or_contains_typeddict<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { } } +/// Return a synthesized protocol that represents safe subscript access for a present key on a +/// `TypedDict`-containing type. +/// +/// For `TypedDict`s, a positive key-membership test proves more than containment: it also makes +/// string-literal subscript access with that key valid. 
In the `if` branch below, the `Bar` arm +/// keeps its original shape but is intersected with this protocol so `u["foo"]` is accepted: +/// +/// ```python +/// class Foo(TypedDict): +/// foo: int +/// +/// class Bar(TypedDict): +/// bar: int +/// +/// def f(u: Foo | Bar): +/// if "foo" in u: +/// reveal_type(u["foo"]) # object +/// ``` +fn typeddict_key_getitem_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> { + let signature = Signature::new( + Parameters::new( + db, + [ + Parameter::positional_only(Some(Name::new_static("self"))), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(Type::string_literal(db, key)), + ], + ), + Type::object(), + ); + + Type::protocol_with_methods( + db, + [("__getitem__", CallableType::function_like(db, signature))], + ) +} + +/// Return a synthesized protocol that records a true key-membership test without implying +/// subscript access. +/// +/// For non-`TypedDict` types, `"key" in value` only proves that membership is true. It does not +/// prove that `value["key"]` is valid: +/// +/// ```python +/// def f(s: Literal["abc"]): +/// if "a" in s: +/// s["a"] # Runtime `TypeError` +/// ``` +/// +/// Non-`TypedDict` union arms therefore receive this `__contains__` protocol instead of the +/// `__getitem__` protocol used for `TypedDict` arms. +fn key_membership_contains_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> { + let signature = Signature::new( + Parameters::new( + db, + [ + Parameter::positional_only(Some(Name::new_static("self"))), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(Type::string_literal(db, key)), + ], + ), + Type::bool_literal(true), + ); + + Type::protocol_with_methods( + db, + [("__contains__", CallableType::function_like(db, signature))], + ) +} + fn is_supported_tag_literal(ty: Type) -> bool { matches!( ty.as_literal_value_kind(),