From aea4ed756b068ad40d38125a5aff16d31863594f Mon Sep 17 00:00:00 2001 From: David Peter Date: Thu, 2 Apr 2026 14:49:03 +0200 Subject: [PATCH 1/2] [ty] Initial attempt --- .../resources/mdtest/call/methods.md | 12 ++++++++ .../src/types/infer/builder.rs | 4 +++ .../src/types/infer/deferred/static_class.rs | 29 +++++++++++++++++-- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/methods.md b/crates/ty_python_semantic/resources/mdtest/call/methods.md index 986f9ce21d7c3..3d531f55dc8e6 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/methods.md +++ b/crates/ty_python_semantic/resources/mdtest/call/methods.md @@ -659,6 +659,18 @@ class Base: class Valid(Base, arg=5, metaclass=object): ... ``` +Class keyword arguments are inferred with type context from the corresponding `__init_subclass__` +parameters: + +```py +class Base: + def __init_subclass__(cls, *, xs: list[int | None]) -> None: + pass + +# No error here: +class Sub(Base, xs=[1, 2]): ... +``` + ## `@staticmethod` ### Basic diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index a78fc100caa1e..2e2ce6319e4e6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -706,6 +706,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { class_node.node(self.module()), self.index, &|expr| self.file_expression_type(expr), + &|expr, tcx| { + let mut speculative = self.speculate(); + speculative.infer_expression(expr, tcx) + }, ); } _ => {} diff --git a/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs b/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs index fe45e34c9cb85..9d06e6953f89a 100644 --- a/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs +++ b/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs @@ -19,7 +19,7 @@ use crate::{ types::{ CallArguments, ClassBase, ClassLiteral, ClassType, GenericAlias, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, Parameters, Signature, SpecialFormType, - StaticClassLiteral, Type, + StaticClassLiteral, Type, TypeContext, call::{Argument, CallError, CallErrorKind}, class::{AbstractMethod, CodeGeneratorKind, FieldKind, MetaclassErrorKind}, context::InferContext, @@ -71,6 +71,7 @@ pub(crate) fn check_static_class_definitions<'db>( class_node: &ast::StmtClassDef, index: &SemanticIndex<'db>, file_expression_type: &impl Fn(&ast::Expr) -> Type<'db>, + infer_expression_with_context: &impl Fn(&ast::Expr, TypeContext<'db>) -> Type<'db>, ) { let db = context.db(); @@ -711,7 +712,7 @@ pub(crate) fn check_static_class_definitions<'db>( } } } else { - let call_args: CallArguments = args + let mut call_args: CallArguments = args .keywords .iter() .filter_map(|keyword| match keyword.arg.as_ref() { @@ -739,6 +740,30 @@ pub(crate) fn check_static_class_definitions<'db>( .ignore_possibly_undefined(); if let Some(init_subclass) = init_subclass_type { + // Re-infer keyword arguments with type context from corresponding `__init_subclass__` parameters. + for binding in init_subclass.bindings(db).iter_flat() { + for overload in binding.overloads() { + let parameters = overload.signature.parameters(); + for (keyword, (argument, argument_types)) in args + .keywords + .iter() + .filter(|kw| kw.arg.as_ref().is_some_and(|name| name != "metaclass")) + .zip(call_args.iter_mut()) + { + let Argument::Keyword(name) = argument else { + continue; + }; + let Some((_, param)) = parameters.keyword_by_name(name) else { + continue; + }; + let param_ty = param.annotated_type(); + let tcx = TypeContext::new(Some(param_ty)); + let inferred = infer_expression_with_context(&keyword.value, tcx); + argument_types.insert(param_ty, inferred); + } + } + } + let call_args = call_args.with_self(Some(Type::from(class))); if let Err(CallError(CallErrorKind::BindingError, bindings)) = init_subclass.try_call(db, &call_args) From 94f9ec77353523fa3f4fd680c8b9343a8fa872bf Mon Sep 17 00:00:00 2001 From: David Peter Date: Thu, 2 Apr 2026 15:20:46 +0200 Subject: [PATCH 2/2] Defer inference of class keyword arguments --- .../src/types/infer/builder.rs | 4 - .../src/types/infer/builder/class.rs | 75 ++++++++++++++++--- .../src/types/infer/deferred/static_class.rs | 29 +------ 3 files changed, 66 insertions(+), 42 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 2e2ce6319e4e6..a78fc100caa1e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -706,10 +706,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { class_node.node(self.module()), self.index, &|expr| self.file_expression_type(expr), - &|expr, tcx| { - let mut speculative = self.speculate(); - speculative.infer_expression(expr, tcx) - }, ); } _ => {} diff --git a/crates/ty_python_semantic/src/types/infer/builder/class.rs b/crates/ty_python_semantic/src/types/infer/builder/class.rs index a5563ddda4a03..7203604dcdb9c 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/class.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/class.rs @@ -1,8 +1,8 @@ use crate::{ semantic_index::{definition::Definition, scope::NodeWithScopeRef}, types::{ - CallArguments, DataclassParams, KnownClass, KnownInstanceType, SpecialFormType, - StaticClassLiteral, Type, TypeContext, + CallArguments, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType, + MemberLookupPolicy, SpecialFormType, StaticClassLiteral, Type, TypeContext, call::CallError, function::KnownFunction, infer::{ @@ -14,6 +14,7 @@ use crate::{ }, }; use ruff_python_ast::{self as ast, helpers::any_over_expr}; +use rustc_hash::FxHashMap; use ty_module_resolver::{KnownModule, file_to_module}; impl<'db> TypeInferenceBuilder<'db, '_> { @@ -214,15 +215,6 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // if there are type parameters, then the keywords and bases are within that scope // and we don't need to run inference here if type_params.is_none() { - // In stub files, keyword values may reference names that are defined later in the file. - let in_stub = self.in_stub(); - let previous_deferred_state = - std::mem::replace(&mut self.deferred_state, in_stub.into()); - for keyword in class_node.keywords() { - self.infer_expression(&keyword.value, TypeContext::default()); - } - self.deferred_state = previous_deferred_state; - // Inference of bases deferred in stubs, or if any are string literals. if self.in_stub() || class_node @@ -239,6 +231,8 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } self.typevar_binding_context = previous_typevar_binding_context; } + + self.infer_class_keyword_arguments(inferred_ty, class_node); } } @@ -261,4 +255,63 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } self.typevar_binding_context = previous_typevar_binding_context; } + + /// Infer class keyword argument types with type context from `__init_subclass__` parameters. + fn infer_class_keyword_arguments( + &mut self, + class_ty: Type<'db>, + class_node: &ast::StmtClassDef, + ) { + if class_node.keywords().is_empty() { + return; + } + + let db = self.db(); + + // Build a map from keyword name to the parameter's annotated type. + let mut param_types: FxHashMap> = FxHashMap::default(); + if let Type::ClassLiteral(ClassLiteral::Static(class)) = class_ty { + let init_subclass_type = class + .class_member_from_mro( + db, + "__init_subclass__", + MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + // skip(1) to skip the current class and only consider base classes. + class.iter_mro(db, None).skip(1), + ) + .ignore_possibly_undefined(); + + if let Some(init_subclass) = init_subclass_type { + let bindings = init_subclass.bindings(db); + for binding in bindings.iter_flat() { + for overload in binding.overloads() { + let parameters = overload.signature.parameters(); + for param in parameters { + if let Some(name) = param.name() { + param_types + .entry(name.to_string()) + .or_insert(param.annotated_type()); + } + } + } + } + } + } + + // In stub files, keyword values may reference names that are defined later in the file. + let in_stub = self.in_stub(); + let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into()); + + for keyword in class_node.keywords() { + let tcx = keyword + .arg + .as_ref() + .and_then(|name| param_types.get(name.id.as_str()).copied()) + .map(|param_ty| TypeContext::new(Some(param_ty))) + .unwrap_or_default(); + self.infer_expression(&keyword.value, tcx); + } + + self.deferred_state = previous_deferred_state; + } } diff --git a/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs b/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs index 9d06e6953f89a..fe45e34c9cb85 100644 --- a/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs +++ b/crates/ty_python_semantic/src/types/infer/deferred/static_class.rs @@ -19,7 +19,7 @@ use crate::{ types::{ CallArguments, ClassBase, ClassLiteral, ClassType, GenericAlias, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, Parameters, Signature, SpecialFormType, - StaticClassLiteral, Type, TypeContext, + StaticClassLiteral, Type, call::{Argument, CallError, CallErrorKind}, class::{AbstractMethod, CodeGeneratorKind, FieldKind, MetaclassErrorKind}, context::InferContext, @@ -71,7 +71,6 @@ pub(crate) fn check_static_class_definitions<'db>( class_node: &ast::StmtClassDef, index: &SemanticIndex<'db>, file_expression_type: &impl Fn(&ast::Expr) -> Type<'db>, - infer_expression_with_context: &impl Fn(&ast::Expr, TypeContext<'db>) -> Type<'db>, ) { let db = context.db(); @@ -712,7 +711,7 @@ pub(crate) fn check_static_class_definitions<'db>( } } } else { - let mut call_args: CallArguments = args + let call_args: CallArguments = args .keywords .iter() .filter_map(|keyword| match keyword.arg.as_ref() { @@ -740,30 +739,6 @@ pub(crate) fn check_static_class_definitions<'db>( .ignore_possibly_undefined(); if let Some(init_subclass) = init_subclass_type { - // Re-infer keyword arguments with type context from corresponding `__init_subclass__` parameters. - for binding in init_subclass.bindings(db).iter_flat() { - for overload in binding.overloads() { - let parameters = overload.signature.parameters(); - for (keyword, (argument, argument_types)) in args - .keywords - .iter() - .filter(|kw| kw.arg.as_ref().is_some_and(|name| name != "metaclass")) - .zip(call_args.iter_mut()) - { - let Argument::Keyword(name) = argument else { - continue; - }; - let Some((_, param)) = parameters.keyword_by_name(name) else { - continue; - }; - let param_ty = param.annotated_type(); - let tcx = TypeContext::new(Some(param_ty)); - let inferred = infer_expression_with_context(&keyword.value, tcx); - argument_types.insert(param_ty, inferred); - } - } - } - let call_args = call_args.with_self(Some(Type::from(class))); if let Err(CallError(CallErrorKind::BindingError, bindings)) = init_subclass.try_call(db, &call_args)