-
Notifications
You must be signed in to change notification settings - Fork 2.1k
[ty] Enable narrowing for unions of TypedDict #25188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
31c7fd4
a1893c3
bac745f
b8e85df
b27ecd1
8802453
08145e0
356d7aa
9a9d64b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,8 +10,9 @@ use crate::types::typed_dict::{ | |
| }; | ||
| use crate::types::{ | ||
| CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass, | ||
| KnownInstanceType, LiteralValueTypeKind, SpecialFormType, SubclassOfInner, SubclassOfType, | ||
| Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, | ||
| KnownInstanceType, LiteralValueTypeKind, Parameter, Parameters, Signature, SpecialFormType, | ||
| SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, | ||
| UnionBuilder, infer_expression_types, | ||
| }; | ||
| use ty_python_core::expression::Expression; | ||
| use ty_python_core::place::{PlaceExpr, PlaceTable, ScopedPlaceId}; | ||
|
|
@@ -1395,8 +1396,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | |
| } | ||
| } | ||
|
|
||
| // Narrow unions and intersections of `TypedDict` in cases where required keys are | ||
| // excluded: | ||
| // Narrow types when a key membership test proves that a key is present, and narrow unions | ||
| // and intersections of `TypedDict` when a key membership test proves that a required key is | ||
| // absent: | ||
| // | ||
| // class Foo(TypedDict): | ||
| // foo: int | ||
|
|
@@ -1412,11 +1414,38 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | |
| && let rhs_type = inference.expression_type(&comparators[0]) | ||
| && is_or_contains_typeddict(self.db, rhs_type) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This isn't necessary, but I'm retaining it to avoid showing the protocol in the type display where it isn't required.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can obviously gate the logic further (e.g., avoid intersecting with a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I did that part too because it was bothering me :)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, worryingly, adding this
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should try broadening this, it would be much more principled. But let's see about improving our subtyping and disjointness for these synthesized protocols first. And it could also have performance implications, so best done as a followup. |
||
| { | ||
| let key = key.value(self.db); | ||
| let apply_constraint = | ||
| |constraints: &mut NarrowingConstraints<'db>, | ||
| constraint: NarrowingConstraint<'db>| { | ||
| let comparator_place = PlaceExpr::try_from_expr(&comparators[0]) | ||
| .and_then(|place_expr| self.places().place_id(&place_expr)); | ||
| if let Some(place) = comparator_place { | ||
| constraints.insert(place, constraint.clone()); | ||
| } | ||
|
|
||
| let value_place = PlaceExpr::try_from_expr(rhs_expr) | ||
| .and_then(|place_expr| self.places().place_id(&place_expr)); | ||
| if value_place != comparator_place | ||
| && let Some(place) = value_place | ||
| { | ||
| constraints.insert(place, constraint); | ||
| } | ||
| }; | ||
|
|
||
| let is_positive_key_membership = is_positive == (ops[0] == ast::CmpOp::In); | ||
| if is_positive_key_membership { | ||
| let narrowed = self.narrow_with_present_key(rhs_type, key); | ||
| if narrowed != rhs_type.resolve_type_alias(self.db) { | ||
| apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed)); | ||
| } | ||
| } | ||
|
|
||
| let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn); | ||
| if is_negative_check { | ||
| let requires_key = |td: TypedDictType<'db>| -> bool { | ||
| td.items(self.db) | ||
| .get(key.value(self.db)) | ||
| .get(key) | ||
| .is_some_and(TypedDictField::is_required) | ||
| }; | ||
|
|
||
|
|
@@ -1460,21 +1489,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | |
| }; | ||
|
|
||
| if narrowed != resolved_rhs_type { | ||
| let constraint = NarrowingConstraint::replacement(narrowed); | ||
|
|
||
| let comparator_place = PlaceExpr::try_from_expr(&comparators[0]) | ||
| .and_then(|place_expr| self.places().place_id(&place_expr)); | ||
| if let Some(place) = comparator_place { | ||
| constraints.insert(place, constraint.clone()); | ||
| } | ||
|
|
||
| let value_place = PlaceExpr::try_from_expr(rhs_expr) | ||
| .and_then(|place_expr| self.places().place_id(&place_expr)); | ||
| if value_place != comparator_place | ||
| && let Some(place) = value_place | ||
| { | ||
| constraints.insert(place, constraint); | ||
| } | ||
| apply_constraint(&mut constraints, NarrowingConstraint::replacement(narrowed)); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -2031,6 +2046,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | |
| Some((place, NarrowingConstraint::intersection(intersection))) | ||
| } | ||
|
|
||
| fn narrow_with_present_key(&self, ty: Type<'db>, key: &str) -> Type<'db> { | ||
| let db = self.db; | ||
| let constrain = |ty, key_presence_constraint| { | ||
| IntersectionType::from_two_elements(db, ty, key_presence_constraint) | ||
| }; | ||
|
|
||
| match ty.resolve_type_alias(self.db) { | ||
| Type::Union(union) => UnionType::from_elements( | ||
| self.db, | ||
| union | ||
| .elements(self.db) | ||
| .iter() | ||
| .map(|element| self.narrow_with_present_key(*element, key)), | ||
| ), | ||
| resolved if is_or_contains_typeddict(self.db, resolved) => { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (I believe |
||
| constrain(ty, typeddict_key_getitem_protocol(self.db, key)) | ||
| } | ||
| _ => constrain(ty, key_membership_contains_protocol(self.db, key)), | ||
| } | ||
| } | ||
|
Comment on lines
+2049
to
+2068
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can simplify this. You don't need to manually map over the union elements (the diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs
index bfcdc0c2a0..acccb1b7b7 100644
--- a/crates/ty_python_semantic/src/types/narrow.rs
+++ b/crates/ty_python_semantic/src/types/narrow.rs
@@ -2047,21 +2047,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
fn narrow_with_present_key(&self, ty: Type<'db>, key: &str) -> Type<'db> {
- let key_presence_constraint = typeddict_key_getitem_protocol(self.db, key);
-
- let db = self.db;
- let constrain = |ty| IntersectionType::from_two_elements(db, ty, key_presence_constraint);
-
- match ty.resolve_type_alias(self.db) {
- Type::Union(union) => UnionType::from_elements(
+ let signature = Signature::new(
+ Parameters::new(
self.db,
- union
- .elements(self.db)
- .iter()
- .map(|element| self.narrow_with_present_key(*element, key)),
+ [
+ Parameter::positional_only(Some(Name::new_static("self"))),
+ Parameter::positional_only(Some(Name::new_static("key")))
+ .with_annotated_type(Type::string_literal(self.db, key)),
+ ],
),
- _ => constrain(ty),
- }
+ Type::object(),
+ );
+
+ let key_presence_constraint = Type::protocol_with_methods(
+ self.db,
+ [(
+ "__getitem__",
+ CallableType::function_like(self.db, signature),
+ )],
+ );
+
+ IntersectionType::from_two_elements(self.db, ty, key_presence_constraint)
}
/// Narrow tagged unions of tuples with `Literal` elements.
@@ -2187,25 +2193,6 @@ fn is_or_contains_typeddict<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
}
}
-fn typeddict_key_getitem_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> {
- let signature = Signature::new(
- Parameters::new(
- db,
- [
- Parameter::positional_only(Some(Name::new_static("self"))),
- Parameter::positional_only(Some(Name::new_static("key")))
- .with_annotated_type(Type::string_literal(db, key)),
- ],
- ),
- Type::object(),
- );
-
- Type::protocol_with_methods(
- db,
- [("__getitem__", CallableType::function_like(db, signature))],
- )
-}
-
fn is_supported_tag_literal(ty: Type) -> bool {
matches!(
ty.as_literal_value_kind(),
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. having realised #25188 (comment), I no longer think that applying this patch directly would be a good idea :-) |
||
|
|
||
| /// Narrow tagged unions of tuples with `Literal` elements. | ||
| /// | ||
| /// Given a subscript expression like `t[0]` where `t` is a union of tuple types, and a | ||
|
|
@@ -2154,6 +2190,76 @@ fn is_or_contains_typeddict<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { | |
| } | ||
| } | ||
|
|
||
| /// Return a synthesized protocol that represents safe subscript access for a present key on a | ||
| /// `TypedDict`-containing type. | ||
| /// | ||
| /// For `TypedDict`s, a positive key-membership test proves more than containment: it also makes | ||
| /// string-literal subscript access with that key valid. In the `if` branch below, the `Bar` arm | ||
| /// keeps its original shape but is intersected with this protocol so `u["foo"]` is accepted: | ||
| /// | ||
| /// ```python | ||
| /// class Foo(TypedDict): | ||
| /// foo: int | ||
| /// | ||
| /// class Bar(TypedDict): | ||
| /// bar: int | ||
| /// | ||
| /// def f(u: Foo | Bar): | ||
| /// if "foo" in u: | ||
| /// reveal_type(u["foo"]) # object | ||
| /// ``` | ||
| fn typeddict_key_getitem_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> { | ||
| let signature = Signature::new( | ||
| Parameters::new( | ||
| db, | ||
| [ | ||
| Parameter::positional_only(Some(Name::new_static("self"))), | ||
| Parameter::positional_only(Some(Name::new_static("key"))) | ||
| .with_annotated_type(Type::string_literal(db, key)), | ||
| ], | ||
| ), | ||
| Type::object(), | ||
| ); | ||
|
|
||
| Type::protocol_with_methods( | ||
| db, | ||
| [("__getitem__", CallableType::function_like(db, signature))], | ||
| ) | ||
| } | ||
|
|
||
| /// Return a synthesized protocol that records a true key-membership test without implying | ||
| /// subscript access. | ||
| /// | ||
| /// For non-`TypedDict` types, `"key" in value` only proves that membership is true. It does not | ||
| /// prove that `value["key"]` is valid: | ||
| /// | ||
| /// ```python | ||
| /// def f(s: Literal["abc"]): | ||
| /// if "a" in s: | ||
| /// s["a"] # Runtime `TypeError` | ||
| /// ``` | ||
| /// | ||
| /// Non-`TypedDict` union arms therefore receive this `__contains__` protocol instead of the | ||
| /// `__getitem__` protocol used for `TypedDict` arms. | ||
| fn key_membership_contains_protocol<'db>(db: &'db dyn Db, key: &str) -> Type<'db> { | ||
| let signature = Signature::new( | ||
| Parameters::new( | ||
| db, | ||
| [ | ||
| Parameter::positional_only(Some(Name::new_static("self"))), | ||
| Parameter::positional_only(Some(Name::new_static("key"))) | ||
| .with_annotated_type(Type::string_literal(db, key)), | ||
| ], | ||
| ), | ||
| Type::bool_literal(true), | ||
| ); | ||
|
|
||
| Type::protocol_with_methods( | ||
| db, | ||
| [("__contains__", CallableType::function_like(db, signature))], | ||
| ) | ||
| } | ||
|
|
||
| fn is_supported_tag_literal(ty: Type) -> bool { | ||
| matches!( | ||
| ty.as_literal_value_kind(), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay, this does reveal a problem with the approach I was suggesting... just because `"a" in obj` resolves to `True` doesn't necessarily mean that `obj["a"]` is a valid operation. `"abc"["a"]` fails at runtime even though `"a" in "abc"` resolves to `True`.

I think maybe that does suggest that something more similar to one of your earlier approaches might be better... for a `TypedDict` (as a standalone type or a union member), we want to intersect with a protocol that has a `__getitem__(self, key: Literal["a"]) -> object` method. But for any non-`TypedDict`, we probably just want to intersect with a protocol that has a `__contains__(self, key: Literal["a"]) -> Literal[True]`. Intersecting with the `__getitem__` protocol is only safe for `TypedDict` types, not for other types.

In retrospect this feels obvious. Sorry for sending you round the houses on this, that only occurred to me now :-(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lol how did we not notice this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's also unsafe for `dict` and other types due to subclassing?