Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,10 @@ def mutable_global_rhs(x: str | None, unavailable: set[str | None]) -> None:
reveal_type(x) # revealed: str | None
```

## No narrowing for the right-hand side (currently)
## No present-key narrowing without a `TypedDict`

No narrowing is done for the right-hand side currently, even if the right-hand side is a valid
"target" (name/attribute/subscript) that could potentially be narrowed. We may change this in the
future:
We only synthesize a key-access protocol for string membership tests on right-hand-side values that
include a `TypedDict`. Other membership tests can mean substring or element containment instead:

```py
from typing import Literal
Expand Down
72 changes: 56 additions & 16 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -4509,10 +4509,10 @@ def _(u: Foo | Bar | NotADict):
reveal_type(u) # revealed: Bar | NotADict
```

It would be nice if we could also narrow `TypedDict` unions by checking whether a key (which only
shows up in a subset of the union members) is present, but that isn't generally correct, because
"extra items" are allowed by default. For example, even though `Bar` here doesn't define a `"foo"`
field, it could be _assigned to_ with another `TypedDict` that does:
We can also narrow `TypedDict` unions by checking whether a key (which only shows up in a subset of
the union members) is present. We can't filter the union down to just the `TypedDict`s that declare
the key, because "extra items" are allowed by default. For example, even though `Bar` here doesn't
define a `"foo"` field, it could be _assigned to_ with another `TypedDict` that does:

```py
from typing_extensions import Literal
Expand All @@ -4525,14 +4525,16 @@ class Bar(TypedDict):

def disappointment(u: Foo | Bar, v: Literal["foo"]):
if "foo" in u:
# We can't narrow the union here...
reveal_type(u) # revealed: Foo | Bar
# We don't narrow to just `Foo` here...
reveal_type(u) # revealed: Foo | (Bar & <Protocol with members '__getitem__'>)
reveal_type(u["foo"]) # revealed: object
else:
# ...(even though we *can* narrow it here)...
reveal_type(u) # revealed: Bar

if v in u:
reveal_type(u) # revealed: Foo | Bar
reveal_type(u) # revealed: Foo | (Bar & <Protocol with members '__getitem__'>)
reveal_type(u["foo"]) # revealed: object
else:
reveal_type(u) # revealed: Bar

Expand All @@ -4543,6 +4545,42 @@ class FooBar(TypedDict):

static_assert(is_assignable_to(FooBar, Foo))
static_assert(is_assignable_to(FooBar, Bar))

def dictionary_union(u: Foo | dict[Literal["a", "b"], int]):
if "c" in u:
# TODO: This should stop erroring if we prove that the `dict` arm cannot contain `"c"`.
# error: [invalid-argument-type]
reveal_type(u["c"]) # revealed: object

def literal_union(u: Foo | Literal["abc"]):
if "a" in u:
# revealed: (Foo & <Protocol with members '__getitem__'>) | (Literal["abc"] & <Protocol with members '__contains__'>)
reveal_type(u)
Comment on lines +4555 to +4558
Copy link
Copy Markdown
Member

@AlexWaygood AlexWaygood May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, this does reveal a problem with the approach I was suggesting... just because "a" in obj resolves to True doesn't necessarily mean that obj["a"] is a valid operation. "abc"["a"] fails at runtime even though "a" in "abc" resolves to True.

I think maybe that does suggest that something more similar to one of your earlier approaches might be better... for a TypedDict (as a standalone type or a union member), we want to intersect with a protocol that has a __getitem__(self, key: Literal["a"]) -> object method. But for any non-TypedDict, we probably just want to intersect with a protocol that has a __contains__(self, key: Literal["a"]) -> Literal[True]. Intersecting with the __getitem__ protocol is only safe for TypedDict types, not for other types.

In retrospect this feels obvious. Sorry for sending you round the houses on this, that only occurred to me now :-(

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol how did we not notice this

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's also unsafe for dict and other types due to subclassing?


def literal_union_key_access(obj: Foo | Literal["a"]):
if "a" in obj:
# Membership in a string does not imply that the string supports subscripting with that key.
# error: [invalid-argument-type]
reveal_type(obj["a"]) # revealed: object
```

This still accepts guarded key access in the branch, without pretending that an open `TypedDict`
must be one of the union members that explicitly declares the key:

```py
from typing import TypedDict

class FileWithBytes(TypedDict):
bytes: bytes

class FileWithUri(TypedDict):
uri: str

def get_bytes(file_content: FileWithBytes | FileWithUri) -> object:
if "bytes" in file_content:
reveal_type(file_content["bytes"]) # revealed: object
return file_content["bytes"]
raise ValueError
```

`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow
Expand All @@ -4565,7 +4603,9 @@ def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Lit
if "bar" not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
# TODO: This should simplify to `Foo | (Bar & Any)`, since `Foo` is a
# subtype of the synthesized protocol.
reveal_type(u) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any)
Comment thread
charliermarsh marked this conversation as resolved.

if "bar" not in v:
reveal_type(v) # revealed: Never
Expand All @@ -4575,12 +4615,12 @@ def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Lit
if w not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
reveal_type(u) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any)

if "bar" not in (u2 := u):
reveal_type(u2) # revealed: Foo
else:
reveal_type(u2) # revealed: Foo | (Bar & Any)
reveal_type(u2) # revealed: (Foo & <Protocol with members '__getitem__'>) | (Bar & Any)
```

With `closed=True`, the narrowing that we couldn't do above becomes possible, because a [closed]
Expand All @@ -4598,13 +4638,13 @@ class ClosedBar(TypedDict, closed=True):
def _(u: ClosedFoo | ClosedBar, v: Literal["foo"]):
if "foo" in u:
# TODO: should be `ClosedFoo`
reveal_type(u) # revealed: ClosedFoo | ClosedBar
reveal_type(u) # revealed: ClosedFoo | (ClosedBar & <Protocol with members '__getitem__'>)
else:
reveal_type(u) # revealed: ClosedBar

if v in u:
# TODO: should be `ClosedFoo`
reveal_type(u) # revealed: ClosedFoo | ClosedBar
reveal_type(u) # revealed: ClosedFoo | (ClosedBar & <Protocol with members '__getitem__'>)
else:
reveal_type(u) # revealed: ClosedBar
```
Expand All @@ -4625,7 +4665,7 @@ def _(
reveal_type(u) # revealed: ClosedFoo
else:
# TODO: should be `ClosedBar & Any`
reveal_type(u) # revealed: ClosedFoo | (ClosedBar & Any)
reveal_type(u) # revealed: (ClosedFoo & <Protocol with members '__getitem__'>) | (ClosedBar & Any)

if "bar" not in v:
reveal_type(v) # revealed: Never
Expand All @@ -4636,7 +4676,7 @@ def _(
reveal_type(u) # revealed: ClosedFoo
else:
# TODO: should be `ClosedBar & Any`
reveal_type(u) # revealed: ClosedFoo | (ClosedBar & Any)
reveal_type(u) # revealed: (ClosedFoo & <Protocol with members '__getitem__'>) | (ClosedBar & Any)
```

## Narrowing tagged unions of `TypedDict`s with `match` statements
Expand Down Expand Up @@ -4792,7 +4832,7 @@ def test_in(x: ThingWithBaz):
if "baz" not in x:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo | Baz
reveal_type(x) # revealed: (Foo & <Protocol with members '__getitem__'>) | Baz
```

Nested PEP 695 type aliases (an alias referring to another alias) also work:
Expand Down Expand Up @@ -4821,7 +4861,7 @@ def test_nested_in(x: OuterWithBaz):
if "baz" not in x:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo | Baz
reveal_type(x) # revealed: (Foo & <Protocol with members '__getitem__'>) | Baz
```

## Only annotated declarations are allowed in the class body
Expand Down
146 changes: 126 additions & 20 deletions crates/ty_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use crate::types::typed_dict::{
};
use crate::types::{
CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass,
KnownInstanceType, LiteralValueTypeKind, SpecialFormType, SubclassOfInner, SubclassOfType,
Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
KnownInstanceType, LiteralValueTypeKind, Parameter, Parameters, Signature, SpecialFormType,
SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints,
UnionBuilder, infer_expression_types,
};
use ty_python_core::expression::Expression;
use ty_python_core::place::{PlaceExpr, PlaceTable, ScopedPlaceId};
Expand Down Expand Up @@ -1395,8 +1396,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}

// Narrow unions and intersections of `TypedDict` in cases where required keys are
// excluded:
// Narrow types when a key membership test proves that a key is present, and narrow unions
// and intersections of `TypedDict` when a key membership test proves that a required key is
// absent:
//
// class Foo(TypedDict):
// foo: int
Expand All @@ -1412,11 +1414,38 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&& let rhs_type = inference.expression_type(&comparators[0])
&& is_or_contains_typeddict(self.db, rhs_type)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't necessary, but retaining it to avoid protocol in the type display where it isn't required.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can obviously gate the logic further (e.g., avoid intersecting with a Literal within a Union, etc.) if we want.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I did that part too because it was bothering me :)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, worryingly, adding this is_or_contains_typeddict gate (I removed it in a prior commit) materially changed the diagnostics:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should try broadening this, it would be much more principled. But let's see about improving our subtyping and disjointness for these synthesized protocols first. And it could also have performance implications, so best done as a followup.

{
let key = key.value(self.db);
let apply_constraint =
|constraints: &mut NarrowingConstraints<'db>,
constraint: NarrowingConstraint<'db>| {
let comparator_place = PlaceExpr::try_from_expr(&comparators[0])
.and_then(|place_expr| self.places().place_id(&place_expr));
if let Some(place) = comparator_place {
constraints.insert(place, constraint.clone());
}

let value_place = PlaceExpr::try_from_expr(rhs_expr)
.and_then(|place_expr| self.places().place_id(&place_expr));
if value_place != comparator_place
&& let Some(place) = value_place
{
constraints.insert(place, constraint);
}
};

let is_positive_key_membership = is_positive == (ops[0] == ast::CmpOp::In);
if is_positive_key_membership {
let narrowed = self.narrow_with_present_key(rhs_type, key);
if narrowed != rhs_type.resolve_type_alias(self.db) {
apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed));
}
}

let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn);
if is_negative_check {
let requires_key = |td: TypedDictType<'db>| -> bool {
td.items(self.db)
.get(key.value(self.db))
.get(key)
.is_some_and(TypedDictField::is_required)
};

Expand Down Expand Up @@ -1460,21 +1489,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
};

if narrowed != resolved_rhs_type {
let constraint = NarrowingConstraint::replacement(narrowed);

let comparator_place = PlaceExpr::try_from_expr(&comparators[0])
.and_then(|place_expr| self.places().place_id(&place_expr));
if let Some(place) = comparator_place {
constraints.insert(place, constraint.clone());
}

let value_place = PlaceExpr::try_from_expr(rhs_expr)
.and_then(|place_expr| self.places().place_id(&place_expr));
if value_place != comparator_place
&& let Some(place) = value_place
{
constraints.insert(place, constraint);
}
apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed));
}
}
}
Expand Down Expand Up @@ -2031,6 +2046,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
Some((place, NarrowingConstraint::intersection(intersection)))
}

fn narrow_with_present_key(&self, ty: Type<'db>, key: &str) -> Type<'db> {
let db = self.db;
let constrain = |ty, key_presence_constraint| {
IntersectionType::from_two_elements(db, ty, key_presence_constraint)
};

match ty.resolve_type_alias(self.db) {
Type::Union(union) => UnionType::from_elements(
self.db,
union
.elements(self.db)
.iter()
.map(|element| self.narrow_with_present_key(*element, key)),
),
resolved if is_or_contains_typeddict(self.db, resolved) => {
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I believe is_or_contains_typeddict in this case is just checking for Type::TypedDict and intersections, since unions should be covered by the above. We could use a separate check but doesn't feel critical.)

constrain(ty, typeddict_key_getitem_protocol(self.db, key))
}
_ => constrain(ty, key_membership_contains_protocol(self.db, key)),
}
}
Comment on lines +2049 to +2068
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify this. You don't need to manually map over the union elements (the IntersectionBuilder already does that). And I'd just inline the typeddict_key_getitem_protocol function, which is only called from this single callsite:

diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs
index bfcdc0c2a0..acccb1b7b7 100644
--- a/crates/ty_python_semantic/src/types/narrow.rs
+++ b/crates/ty_python_semantic/src/types/narrow.rs
@@ -2047,21 +2047,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
     }

     fn narrow_with_present_key(&self, ty: Type<'db>, key: &str) -> Type<'db> {
-        let key_presence_constraint = typeddict_key_getitem_protocol(self.db, key);
-
-        let db = self.db;
-        let constrain = |ty| IntersectionType::from_two_elements(db, ty, key_presence_constraint);
-
-        match ty.resolve_type_alias(self.db) {
-            Type::Union(union) => UnionType::from_elements(
+        let signature = Signature::new(
+            Parameters::new(
                 self.db,
-                union
-                    .elements(self.db)
-                    .iter()
-                    .map(|element| self.narrow_with_present_key(*element, key)),
+                [
+                    Parameter::positional_only(Some(Name::new_static("self"))),
+                    Parameter::positional_only(Some(Name::new_static("key")))
+                        .with_annotated_type(Type::string_literal(self.db, key)),
+                ],
             ),
-            _ => constrain(ty),
-        }
+            Type::object(),
+        );
+
+        let key_presence_constraint = Type::protocol_with_methods(
+            self.db,
+            [(
+                "__getitem__",
+                CallableType::function_like(self.db, signature),
+            )],
+        );
+
+        IntersectionType::from_two_elements(self.db, ty, key_presence_constraint)
     }

     /// Narrow tagged unions of tuples with `Literal` elements.
@@ -2187,25 +2193,6 @@ fn is_or_contains_typeddict<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
     }
 }

-fn typeddict_key_getitem_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> {
-    let signature = Signature::new(
-        Parameters::new(
-            db,
-            [
-                Parameter::positional_only(Some(Name::new_static("self"))),
-                Parameter::positional_only(Some(Name::new_static("key")))
-                    .with_annotated_type(Type::string_literal(db, key)),
-            ],
-        ),
-        Type::object(),
-    );
-
-    Type::protocol_with_methods(
-        db,
-        [("__getitem__", CallableType::function_like(db, signature))],
-    )
-}
-
 fn is_supported_tag_literal(ty: Type) -> bool {
     matches!(
         ty.as_literal_value_kind(),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having realised #25188 (comment), I no longer think that applying this patch directly would be a good idea :-)


/// Narrow tagged unions of tuples with `Literal` elements.
///
/// Given a subscript expression like `t[0]` where `t` is a union of tuple types, and a
Expand Down Expand Up @@ -2154,6 +2190,76 @@ fn is_or_contains_typeddict<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
}
}

/// Return a synthesized protocol that represents safe subscript access for a present key on a
/// `TypedDict`-containing type.
///
/// For `TypedDict`s, a positive key-membership test proves more than containment: it also makes
/// string-literal subscript access with that key valid. In the `if` branch below, the `Bar` arm
/// keeps its original shape but is intersected with this protocol so `u["foo"]` is accepted:
///
/// ```python
/// class Foo(TypedDict):
/// foo: int
///
/// class Bar(TypedDict):
/// bar: int
///
/// def f(u: Foo | Bar):
/// if "foo" in u:
/// reveal_type(u["foo"]) # object
/// ```
fn typeddict_key_getitem_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> {
let signature = Signature::new(
Parameters::new(
db,
[
Parameter::positional_only(Some(Name::new_static("self"))),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(Type::string_literal(db, key)),
],
),
Type::object(),
);

Type::protocol_with_methods(
db,
[("__getitem__", CallableType::function_like(db, signature))],
)
}

/// Return a synthesized protocol that records a true key-membership test without implying
/// subscript access.
///
/// For non-`TypedDict` types, `"key" in value` only proves that membership is true. It does not
/// prove that `value["key"]` is valid:
///
/// ```python
/// def f(s: Literal["abc"]):
/// if "a" in s:
/// s["a"] # Runtime `TypeError`
/// ```
///
/// Non-`TypedDict` union arms therefore receive this `__contains__` protocol instead of the
/// `__getitem__` protocol used for `TypedDict` arms.
fn key_membership_contains_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> {
let signature = Signature::new(
Parameters::new(
db,
[
Parameter::positional_only(Some(Name::new_static("self"))),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(Type::string_literal(db, key)),
],
),
Type::bool_literal(true),
);

Type::protocol_with_methods(
db,
[("__contains__", CallableType::function_like(db, signature))],
)
}

fn is_supported_tag_literal(ty: Type) -> bool {
matches!(
ty.as_literal_value_kind(),
Expand Down
Loading