diff --git a/crates/ty_python_semantic/resources/mdtest/call/methods.md b/crates/ty_python_semantic/resources/mdtest/call/methods.md index 986f9ce21d7c37..3d531f55dc8e62 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/methods.md +++ b/crates/ty_python_semantic/resources/mdtest/call/methods.md @@ -659,6 +659,18 @@ class Base: class Valid(Base, arg=5, metaclass=object): ... ``` +Class keyword arguments are inferred with type context from the corresponding `__init_subclass__` +parameters: + +```py +class Base: + def __init_subclass__(cls, *, xs: list[int | None]) -> None: + pass + +# No error here: +class Sub(Base, xs=[1, 2]): ... +``` + ## `@staticmethod` ### Basic diff --git a/crates/ty_python_semantic/src/types/infer/builder/class.rs b/crates/ty_python_semantic/src/types/infer/builder/class.rs index a5563ddda4a035..7203604dcdb9c1 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/class.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/class.rs @@ -1,8 +1,8 @@ use crate::{ semantic_index::{definition::Definition, scope::NodeWithScopeRef}, types::{ - CallArguments, DataclassParams, KnownClass, KnownInstanceType, SpecialFormType, - StaticClassLiteral, Type, TypeContext, + CallArguments, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType, + MemberLookupPolicy, SpecialFormType, StaticClassLiteral, Type, TypeContext, call::CallError, function::KnownFunction, infer::{ @@ -14,6 +14,7 @@ use crate::{ }, }; use ruff_python_ast::{self as ast, helpers::any_over_expr}; +use rustc_hash::FxHashMap; use ty_module_resolver::{KnownModule, file_to_module}; impl<'db> TypeInferenceBuilder<'db, '_> { @@ -214,15 +215,6 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // if there are type parameters, then the keywords and bases are within that scope // and we don't need to run inference here if type_params.is_none() { - // In stub files, keyword values may reference names that are defined later in the file. - let in_stub = self.in_stub(); - let previous_deferred_state = - std::mem::replace(&mut self.deferred_state, in_stub.into()); - for keyword in class_node.keywords() { - self.infer_expression(&keyword.value, TypeContext::default()); - } - self.deferred_state = previous_deferred_state; - // Inference of bases deferred in stubs, or if any are string literals. if self.in_stub() || class_node @@ -239,6 +231,8 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } self.typevar_binding_context = previous_typevar_binding_context; } + + self.infer_class_keyword_arguments(inferred_ty, class_node); } } @@ -261,4 +255,63 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } self.typevar_binding_context = previous_typevar_binding_context; } + + /// Infer class keyword argument types with type context from `__init_subclass__` parameters. + fn infer_class_keyword_arguments( + &mut self, + class_ty: Type<'db>, + class_node: &ast::StmtClassDef, + ) { + if class_node.keywords().is_empty() { + return; + } + + let db = self.db(); + + // Build a map from keyword name to the parameter's annotated type. + let mut param_types: FxHashMap> = FxHashMap::default(); + if let Type::ClassLiteral(ClassLiteral::Static(class)) = class_ty { + let init_subclass_type = class + .class_member_from_mro( + db, + "__init_subclass__", + MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + // skip(1) to skip the current class and only consider base classes. + class.iter_mro(db, None).skip(1), + ) + .ignore_possibly_undefined(); + + if let Some(init_subclass) = init_subclass_type { + let bindings = init_subclass.bindings(db); + for binding in bindings.iter_flat() { + for overload in binding.overloads() { + let parameters = overload.signature.parameters(); + for param in parameters { + if let Some(name) = param.name() { + param_types + .entry(name.to_string()) + .or_insert(param.annotated_type()); + } + } + } + } + } + } + + // In stub files, keyword values may reference names that are defined later in the file. + let in_stub = self.in_stub(); + let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into()); + + for keyword in class_node.keywords() { + let tcx = keyword + .arg + .as_ref() + .and_then(|name| param_types.get(name.id.as_str()).copied()) + .map(|param_ty| TypeContext::new(Some(param_ty))) + .unwrap_or_default(); + self.infer_expression(&keyword.value, tcx); + } + + self.deferred_state = previous_deferred_state; + } }