Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/call/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,18 @@ class Base:
class Valid(Base, arg=5, metaclass=object): ...
```

Class keyword arguments are inferred with type context from the corresponding `__init_subclass__`
parameters:

```py
class Base:
def __init_subclass__(cls, *, xs: list[int | None]) -> None:
pass

# No error here:
class Sub(Base, xs=[1, 2]): ...
```

## `@staticmethod`

### Basic
Expand Down
75 changes: 64 additions & 11 deletions crates/ty_python_semantic/src/types/infer/builder/class.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
semantic_index::{definition::Definition, scope::NodeWithScopeRef},
types::{
CallArguments, DataclassParams, KnownClass, KnownInstanceType, SpecialFormType,
StaticClassLiteral, Type, TypeContext,
CallArguments, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType,
MemberLookupPolicy, SpecialFormType, StaticClassLiteral, Type, TypeContext,
call::CallError,
function::KnownFunction,
infer::{
Expand All @@ -14,6 +14,7 @@ use crate::{
},
};
use ruff_python_ast::{self as ast, helpers::any_over_expr};
use rustc_hash::FxHashMap;
use ty_module_resolver::{KnownModule, file_to_module};

impl<'db> TypeInferenceBuilder<'db, '_> {
Expand Down Expand Up @@ -214,15 +215,6 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
// if there are type parameters, then the keywords and bases are within that scope
// and we don't need to run inference here
if type_params.is_none() {
// In stub files, keyword values may reference names that are defined later in the file.
let in_stub = self.in_stub();
let previous_deferred_state =
std::mem::replace(&mut self.deferred_state, in_stub.into());
for keyword in class_node.keywords() {
self.infer_expression(&keyword.value, TypeContext::default());
}
self.deferred_state = previous_deferred_state;

// Inference of bases deferred in stubs, or if any are string literals.
if self.in_stub()
|| class_node
Expand All @@ -239,6 +231,8 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
self.typevar_binding_context = previous_typevar_binding_context;
}

self.infer_class_keyword_arguments(inferred_ty, class_node);
}
}

Expand All @@ -261,4 +255,63 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
self.typevar_binding_context = previous_typevar_binding_context;
}

/// Infer class keyword argument types with type context from `__init_subclass__` parameters.
fn infer_class_keyword_arguments(
&mut self,
class_ty: Type<'db>,
class_node: &ast::StmtClassDef,
) {
if class_node.keywords().is_empty() {
return;
}

let db = self.db();

// Build a map from keyword name to the parameter's annotated type.
let mut param_types: FxHashMap<String, Type<'db>> = FxHashMap::default();
if let Type::ClassLiteral(ClassLiteral::Static(class)) = class_ty {
let init_subclass_type = class
.class_member_from_mro(
db,
"__init_subclass__",
MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
// skip(1) to skip the current class and only consider base classes.
class.iter_mro(db, None).skip(1),
)
.ignore_possibly_undefined();

if let Some(init_subclass) = init_subclass_type {
let bindings = init_subclass.bindings(db);
for binding in bindings.iter_flat() {
for overload in binding.overloads() {
let parameters = overload.signature.parameters();
for param in parameters {
if let Some(name) = param.name() {
param_types
.entry(name.to_string())
.or_insert(param.annotated_type());
}
}
}
}
}
}

// In stub files, keyword values may reference names that are defined later in the file.
let in_stub = self.in_stub();
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into());

for keyword in class_node.keywords() {
let tcx = keyword
.arg
.as_ref()
.and_then(|name| param_types.get(name.id.as_str()).copied())
.map(|param_ty| TypeContext::new(Some(param_ty)))
.unwrap_or_default();
self.infer_expression(&keyword.value, tcx);
}

self.deferred_state = previous_deferred_state;
}
}
Loading