diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index 876e858be7fea..440179e59c17d 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -1745,6 +1745,36 @@ Model(x=1) reveal_type(Model.__init__) # revealed: (self: Model, x: int) -> None ``` +### Decorator return types are still metadata-only in decorator position + +When a `@dataclass_transform()`-decorated function is used as a class decorator, we currently use it +to shape the class like a dataclass but do not yet let an explicit non-class return annotation +replace the public class binding. + +```py +from typing import Protocol, TypeVar +from typing_extensions import dataclass_transform + +class Wrapped(Protocol): + def f(self) -> int: ... + +T = TypeVar("T", bound=object) + +@dataclass_transform() +def model(cls: type[T]) -> Wrapped: + raise NotImplementedError + +@model +class C: + x: int + +reveal_type(C) # revealed: +reveal_type(C.__init__) # revealed: (self: C, x: int) -> None + +# TODO: Decide whether the explicit `Wrapped` return type should replace the public binding here. +C.f() # error: [unresolved-attribute] +``` + ## `__dataclass_transform__` compatibility For backwards compatibility with pre-3.11 Python, ty recognizes any function named diff --git a/crates/ty_python_semantic/resources/mdtest/decorators.md b/crates/ty_python_semantic/resources/mdtest/decorators.md index 2adf9b7be3740..0a47fdb56af79 100644 --- a/crates/ty_python_semantic/resources/mdtest/decorators.md +++ b/crates/ty_python_semantic/resources/mdtest/decorators.md @@ -307,12 +307,10 @@ class AcceptsType: def __init__(self, cls: type) -> None: self.cls = cls -# Decorator call is validated, but the type transformation isn't applied yet. -# TODO: Class decorator return types should transform the class binding type. @AcceptsType class MyClass: ... -reveal_type(MyClass) # revealed: +reveal_type(MyClass) # revealed: AcceptsType ``` ### Generic class, used as a decorator @@ -378,6 +376,340 @@ def decorator(cls: type[int]) -> type[int]: @decorator class Baz: ... -# TODO: the revealed type should ideally be `type[int]` (the decorator's return type) -reveal_type(Baz) # revealed: +reveal_type(Baz) # revealed: type[int] +``` + +Class decorators can also replace the class object with an instance: + +```py +from dataclasses import dataclass +from typing import Callable, Generic, Protocol, TypeVar, overload +from typing_extensions import Self + +T = TypeVar("T") + +class Backend(Protocol): + def get(self, key: str) -> bytes | None: ... + +class WrapBackend: + def __init__(self, cls: type[object]) -> None: + self.cls = cls + + def get(self, key: str) -> bytes | None: + return None + +@WrapBackend +class CacheClient: + def clone(self) -> Self: + reveal_type(self) # revealed: Self@clone + return self + + @classmethod + def make(cls) -> Self: + reveal_type(cls) # revealed: type[Self@make] + return cls() + +reveal_type(CacheClient) # revealed: WrapBackend +reveal_type(CacheClient.get("x")) # revealed: bytes | None + +@WrapBackend +@dataclass +class DataclassThenWrapped: + value: int + +reveal_type(DataclassThenWrapped) # revealed: WrapBackend + +# error: [no-matching-overload] +@dataclass +@WrapBackend +class WrappedThenDataclass: + value: int + +reveal_type(WrappedThenDataclass) # revealed: Unknown + +def int_decorator_factory() -> Callable[[type[object]], int]: + def decorator(cls: type[object]) -> int: + return 1 + return decorator + +# error: [no-matching-overload] +@dataclass +@int_decorator_factory() +class IntThenDataclass: + value: int + +reveal_type(IntThenDataclass) # revealed: Unknown + +@WrapBackend +class InvalidWrappedBase(1): ... # error: [invalid-base] + +reveal_type(InvalidWrappedBase) # revealed: WrapBackend + +@WrapBackend +class GenericCacheClient(Generic[T]): + value: T + + def get_value(self) -> T: + return self.value + +reveal_type(GenericCacheClient) # revealed: WrapBackend + +@WrapBackend +class OverloadedCacheClient: + @overload + def get(self, key: str) -> bytes: ... + @overload + def get(self, key: bytes) -> bytes: ... + def get(self, key: str | bytes) -> bytes: + return b"" +``` + +Unannotated class decorators are assumed to preserve the class binding. We do not infer returned +classes from decorator bodies: + +```py +def personify(cls): + class Wrapped(cls): + full_name: str + + def set_full_name(self, full_name: str) -> None: + self.full_name = full_name + + return Wrapped + +@personify +class Animal: ... + +reveal_type(Animal) # revealed: +reveal_type(Animal()) # revealed: Animal + +Animal().set_full_name("John") # error: [unresolved-attribute] +``` + +This also applies to unannotated callables that are not function definitions: + +```py +lambda_decorator = lambda cls: cls + +@lambda_decorator +class LambdaDecorated: ... + +reveal_type(LambdaDecorated) # revealed: + +class DecoratorFactory: + def decorator(self, cls): + return cls + +decorator_factory = DecoratorFactory() + +@decorator_factory.decorator +class BoundMethodDecorated: ... + +reveal_type(BoundMethodDecorated) # revealed: + +class CallableDecorator: + def __call__(self, cls): + return cls + +callable_decorator = CallableDecorator() + +@callable_decorator +class CallableInstanceDecorated: ... + +reveal_type(CallableInstanceDecorated) # revealed: + +class ExplicitReturnDecorator(Generic[T]): + def __call__(self, cls) -> T: + raise NotImplementedError + +explicit_return_decorator = ExplicitReturnDecorator() + +@explicit_return_decorator +class ExplicitReturnCallableInstanceDecorated: ... + +reveal_type(ExplicitReturnCallableInstanceDecorated) # revealed: Unknown + +specialized_explicit_return_decorator = ExplicitReturnDecorator[int]() + +@specialized_explicit_return_decorator +class SpecializedExplicitReturnCallableInstanceDecorated: ... + +reveal_type(SpecializedExplicitReturnCallableInstanceDecorated) # revealed: int +``` + +An unknown class decorator still makes the class binding unknown: + +```py +# error: [unresolved-reference] "Name `unknown_class_decorator` used when not defined" +@unknown_class_decorator +class UnknownDecorated: ... + +reveal_type(UnknownDecorated) # revealed: Unknown +``` + +An unannotated class decorator preserves the result of earlier decorators: + +```py +def unannotated_identity(cls): + return cls + +@unannotated_identity +@WrapBackend +class WrappedThenUnannotated: ... + +reveal_type(WrappedThenUnannotated) # revealed: WrapBackend +``` + +Metadata decorators still apply above an unannotated class-preserving decorator: + +```py +from typing_extensions import deprecated + +def unannotated_identity(cls): + return cls + +@deprecated("use OtherClass") +@unannotated_identity +class DeprecatedThenUnannotated: ... + +DeprecatedThenUnannotated() # error: [deprecated] "use OtherClass" +``` + +If a class decorator returns the original class object, we preserve the class binding so it can +still be used in annotations and as a base class: + +```py +from typing import TypeVar + +T = TypeVar("T", bound=object) + +def identity_class_decorator(cls: type[T]) -> type[T]: + return cls + +@identity_class_decorator +class PreservedClass: ... + +reveal_type(PreservedClass) # revealed: + +class DerivedPreservedClass(PreservedClass): + value: PreservedClass +``` + +Class decorator factories that preserve the original class object also preserve the class binding: + +```py +from collections.abc import Callable +from typing import Any, TypeVar, overload + +DecoratorT = TypeVar("DecoratorT", bound=object) +DecoratedClass = type[DecoratorT] + +@overload +def identity_class_decorator_factory(cls: DecoratedClass, **kwargs: Any) -> DecoratedClass: ... +@overload +def identity_class_decorator_factory( + **kwargs: Any, +) -> Callable[[DecoratedClass], DecoratedClass]: ... +def identity_class_decorator_factory( + cls: DecoratedClass | None = None, **kwargs: Any +) -> DecoratedClass | Callable[[DecoratedClass], DecoratedClass]: + def decorator(inner_cls: DecoratedClass) -> DecoratedClass: + return inner_cls + + if cls is not None: + return decorator(cls) + return decorator + +@identity_class_decorator_factory(frozen=True) +class FactoryPreservedClass: ... + +reveal_type(FactoryPreservedClass) # revealed: + +class DerivedFactoryPreservedClass(FactoryPreservedClass): + value: FactoryPreservedClass +``` + +Class decorators can return intersections that expose attributes added to the decorated class +object: + +```py +from ty_extensions import Intersection +from typing import Protocol, TypeVar + +class Resource: + def fetch(self) -> str: + return "data" + +class ResourceEnabled(Protocol): + resource: Resource + +SchemaT = TypeVar("SchemaT") + +def register(cls: type[SchemaT]) -> Intersection[type[SchemaT], ResourceEnabled]: + return cls + +@register +class UserSchema: + id: int + +reveal_type(UserSchema.resource.fetch()) # revealed: str +``` + +Metadata decorators stacked above an intersection-returning class decorator still apply to the +original class object, while preserving the extra intersection members: + +```py +from dataclasses import dataclass +from ty_extensions import Intersection +from typing import Protocol, TypeVar + +class Resource: + def fetch(self) -> str: + return "data" + +class ResourceEnabled(Protocol): + resource: Resource + +SchemaT = TypeVar("SchemaT") + +def register(cls: type[SchemaT]) -> Intersection[type[SchemaT], ResourceEnabled]: + return cls + +@dataclass +@register +class RegisteredDataclass: + id: int + +reveal_type(RegisteredDataclass.resource.fetch()) # revealed: str +reveal_type(RegisteredDataclass(1)) # revealed: RegisteredDataclass +``` + +Class-preserving decorators stacked above an intersection-returning class decorator preserve the +existing intersection members: + +```py +from ty_extensions import Intersection +from typing import Protocol, TypeVar + +class Resource: + def fetch(self) -> str: + return "data" + +class ResourceEnabled(Protocol): + resource: Resource + +SchemaT = TypeVar("SchemaT") + +def register(cls: type[SchemaT]) -> Intersection[type[SchemaT], ResourceEnabled]: + return cls + +def identity(cls: type[SchemaT]) -> type[SchemaT]: + return cls + +@identity +@register +class RegisteredIdentity: + id: int + +reveal_type(RegisteredIdentity.resource.fetch()) # revealed: str ``` diff --git a/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md index 108c4b15ea3ad..d95a953e30290 100644 --- a/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md +++ b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md @@ -38,6 +38,35 @@ reveal_type(s1 > s2) # revealed: bool reveal_type(s1 >= s2) # revealed: bool ``` +## Stacked with class-preserving decorator + +Metadata from `@total_ordering` still applies above an explicitly typed class-preserving decorator: + +```py +from functools import total_ordering +from typing import TypeVar + +T = TypeVar("T", bound=object) + +def identity(cls: type[T]) -> type[T]: + return cls + +@total_ordering +@identity +class OrderedIdentity: + def __eq__(self, other: object) -> bool: + return isinstance(other, OrderedIdentity) + + def __lt__(self, other: "OrderedIdentity") -> bool: + return True + +left = OrderedIdentity() +right = OrderedIdentity() + +reveal_type(left <= right) # revealed: bool +reveal_type(left >= right) # revealed: bool +``` + ## Signature derived from source ordering method When the source ordering method accepts a broader type (like `object`) for its `other` parameter, diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index d7fd9ad2850b7..3932aa094a574 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -37,6 +37,22 @@ reveal_type(generic_context(SingleTypeVarTuple)) reveal_type(generic_context(TypeVarAndTypeVarTuple)) ``` +Decorated generic classes still use the original class for their class-body generic context: + +```py +class Wrap: + def __init__(self, cls: type[object]) -> None: ... + +@Wrap +class DecoratedGeneric[T]: + value: T + + def get_value(self) -> T: + return self.value + +reveal_type(DecoratedGeneric) # revealed: Wrap +``` + You cannot use the same typevar more than once. ```py diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index b33d786a00406..86580f9145833 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -4979,6 +4979,21 @@ class Foo(TypedDict("Foo", {"x": int, "y": str})): pass ``` +Other class decorators can replace the public TypedDict binding: + +```py +from typing import TypedDict + +class ReplacesClass: + def __init__(self, cls: type[object]) -> None: ... + +@ReplacesClass +class Decorated(TypedDict): + name: str + +reveal_type(Decorated) # revealed: ReplacesClass +``` + ## Class header validation diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 46c1ebb1023fc..09d9c1828be74 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1411,6 +1411,7 @@ impl<'db> Type<'db> { } } + #[cfg(test)] #[track_caller] pub(crate) const fn expect_class_literal(self) -> ClassLiteral<'db> { self.as_class_literal() diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 73dcdd3d4c28d..dca3357b5c201 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -76,7 +76,7 @@ use crate::types::diagnostic::{ }; use crate::types::display::DisplaySettings; use crate::types::generics::{ApplySpecialization, GenericContext, typing_self}; -use crate::types::infer::nearest_enclosing_class; +use crate::types::infer::{nearest_enclosing_class, original_class_type}; use crate::types::known_instance::DeprecatedInstance; use crate::types::list_members::all_members; use crate::types::narrow::ClassInfoConstraintFunction; @@ -89,7 +89,7 @@ use crate::types::{ ClassLiteral, ClassType, DynamicType, FindLegacyTypeVarsVisitor, IntersectionBuilder, KnownClass, KnownInstanceType, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeMapping, TypeVarBoundOrConstraints, UnionBuilder, UnionType, - binding_type, definition_expression_type, infer_definition_types, walk_signature, + definition_expression_type, walk_signature, }; use crate::{Db, FxOrderSet}; use ty_python_core::ast_ids::HasScopedUseId; @@ -273,6 +273,9 @@ pub struct OverloadLiteral<'db> { /// The arguments to `dataclass_transformer`, if this function was annotated /// with `@dataclass_transformer(...)`. pub(crate) dataclass_transformer_params: Option>, + + /// Whether this overload or implementation has an explicit return annotation. + pub(crate) has_explicit_return_annotation: bool, } // The Salsa heap is tracked separately. @@ -293,6 +296,7 @@ impl<'db> OverloadLiteral<'db> { self.decorators(db), self.deprecated(db), Some(params), + self.has_explicit_return_annotation(db), ) } @@ -544,9 +548,8 @@ impl<'db> OverloadLiteral<'db> { // or there is but it isn't using the PEP-484 convention, // then `self`/`cls` are only implicitly positional-only if // it is a protocol class. - let class_type = binding_type(db, class_definition); - class_type - .to_class_type(db) + original_class_type(db, class_definition) + .map(|class_literal| class_literal.default_specialization(db)) .is_some_and(|class| class.is_protocol(db)) } @@ -593,12 +596,7 @@ impl<'db> OverloadLiteral<'db> { let class_scope = index.scope(class_scope_id.file_scope_id(db)); let class_node = class_scope.node().as_class()?; let class_def = index.expect_single_definition(class_node); - let Type::ClassLiteral(class_literal) = infer_definition_types(db, class_def) - .declaration_type(class_def) - .inner_type() - else { - return None; - }; + let class_literal = original_class_type(db, class_def)?; let class_is_generic = class_literal.generic_context(db).is_some(); let class_is_fallback = class_literal .known(db) @@ -1221,6 +1219,22 @@ impl<'db> FunctionType<'db> { self.literal(db).has_trivial_body(db) } + /// Returns `true` if any overload or implementation has an explicit return annotation. + /// + /// This distinguishes untyped decorators that infer an unknown return from decorators that + /// explicitly promise a replacement type: + /// ```python + /// def identity(cls): + /// return cls + /// + /// def replace(cls) -> object: + /// return object() + /// ``` + pub(crate) fn has_explicit_return_annotation(self, db: &'db dyn Db) -> bool { + self.iter_overloads_and_implementation(db) + .any(|overload| overload.has_explicit_return_annotation(db)) + } + /// Returns all of the overload signatures and the implementation definition, if any, of this /// function. The overload signatures will be in source order. pub(crate) fn overloads_and_implementation( diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index bdff2dc25f10d..2fb755ed9ba89 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -14,6 +14,7 @@ use crate::types::constraints::{ ConstraintSet, ConstraintSetBuilder, IteratorConstraintsExtension, OwnedConstraintSet, Solutions, }; +use crate::types::infer::original_class_type; use crate::types::relation::{ DisjointnessChecker, HasRelationToVisitor, IsDisjointVisitor, TypeRelation, TypeRelationChecker, }; @@ -346,9 +347,7 @@ impl<'db> GenericContext<'db> { match node { NodeWithScopeKind::Class(class) => { let definition = index.expect_single_definition(class); - binding_type(db, definition) - .as_class_literal()? - .generic_context(db) + original_class_type(db, definition)?.generic_context(db) } NodeWithScopeKind::Function(function) => { let definition = index.expect_single_definition(function); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6d773bf3c15d9..c9ff1684d4004 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -56,7 +56,6 @@ use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ ClassLiteral, KnownClass, StaticClassLiteral, Type, TypeAndQualifiers, TypeQualifiers, - declaration_type, }; use builder::TypeInferenceBuilder; pub(super) use comparisons::UnsupportedComparisonError; @@ -664,13 +663,37 @@ pub(crate) fn nearest_enclosing_class<'db>( .find_map(|(_, ancestor_scope)| { let class = ancestor_scope.node().as_class()?; let definition = semantic.expect_single_definition(class); - declaration_type(db, definition) - .inner_type() - .as_class_literal() - .and_then(ClassLiteral::as_static) + original_class_type(db, definition).and_then(ClassLiteral::as_static) }) } +/// Return the original class literal for a class definition. +/// +/// For decorated classes, this is the class object before applying decorators. The public +/// binding may be replaced by a class decorator's return type, but class-body inference still +/// needs the original class object for implicit `self`/`cls`, `Self`, and dataclass logic. +/// +/// For example, the public binding for `C` may be the `int` returned by `replace`, but the class +/// body and nested definitions still need the original class object: +/// ```python +/// def replace(cls: type[object]) -> int: +/// return 1 +/// +/// @replace +/// class C: +/// def method(self) -> None: ... +/// ``` +pub(crate) fn original_class_type<'db>( + db: &'db dyn Db, + definition: Definition<'db>, +) -> Option> { + let inference = infer_definition_types(db, definition); + inference + .undecorated_type() + .unwrap_or_else(|| inference.binding_type(definition)) + .as_class_literal() +} + /// Returns the type of the nearest enclosing function for the given scope. /// /// This function walks up the ancestor scopes starting from the given scope, @@ -876,7 +899,7 @@ struct DefinitionInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, - /// For function definitions, the undecorated type of the function. + /// For decorated function or class definitions, the type before applying decorators. undecorated_type: Option>, /// Type qualifiers (`Required`, `NotRequired`, etc.) for annotation expressions. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 233d126d5a285..b13eecaedddac 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -79,6 +79,7 @@ use crate::types::infer::builder::typed_dict::TypedDictConstructorForm; use crate::types::infer::{ StatementInference, StatementInferenceInner, StatementInferenceInnerExtra, TypeExpressionFlags, infer_statement_types, nearest_enclosing_class, nearest_enclosing_function, + original_class_type, }; use crate::types::narrow::NarrowingEvaluatorExtension; use crate::types::newtype::NewType; @@ -315,7 +316,7 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// is a stub file but we're still in a non-deferred region. deferred_state: DeferredExpressionState, - /// For function definitions, the undecorated type of the function. + /// For decorated function or class definitions, the type before applying decorators. undecorated_type: Option>, /// The fallback type for missing expressions/bindings/declarations or recursive type inference. @@ -539,11 +540,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let enclosing_scope = index.scope(scope.file_scope_id(db)); let class_node = enclosing_scope.node().as_class()?; let class_definition = index.expect_single_definition(class_node); - let class_literal = infer_definition_types(db, class_definition) - .declaration_type(class_definition) - .inner_type() - .as_class_literal()? - .as_static()?; + let class_literal = original_class_type(db, class_definition)?.as_static()?; class_literal .dataclass_params(db) @@ -821,6 +818,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); } DefinitionKind::Class(class_node) => { + let original_ty = match self.region { + InferenceRegion::Definition(current) if current == definition => { + self.undecorated_type + } + _ => original_class_type(self.db(), definition).map(Type::ClassLiteral), + }; + let ty = original_ty.unwrap_or(ty); post_inference::static_class::check_static_class_definitions( &self.context, ty, @@ -1504,7 +1508,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn class_context_of_current_method(&self) -> Option> { let current_scope_id = self.scope().file_scope_id(self.db()); let class_definition = self.index.class_definition_of_method(current_scope_id)?; - binding_type(self.db(), class_definition).to_class_type(self.db()) + original_class_type(self.db(), class_definition) + .map(|class_literal| class_literal.default_specialization(self.db())) } /// If the current scope is a (non-lambda) function, return that function's AST node. diff --git a/crates/ty_python_semantic/src/types/infer/builder/class.rs b/crates/ty_python_semantic/src/types/infer/builder/class.rs index cbf783b6297f1..4970a3f163824 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/class.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/class.rs @@ -1,16 +1,18 @@ +use crate::place::Place; use crate::types::{ - CallArguments, DataclassParams, KnownClass, KnownInstanceType, SpecialFormType, - StaticClassLiteral, Type, TypeContext, + CallArguments, DataclassParams, KnownClass, KnownInstanceType, MemberLookupPolicy, + SpecialFormType, StaticClassLiteral, SubclassOfType, Type, TypeContext, call::CallError, function::KnownFunction, infer::{ TypeInferenceBuilder, builder::{DeclaredAndInferredType, DeferredExpressionState}, + original_class_type, }, - infer_definition_types, signatures::ParameterForm, special_form::TypeQualifier, }; +use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, helpers::any_over_expr}; use ty_module_resolver::{KnownModule, file_to_module}; use ty_python_core::{definition::Definition, scope::NodeWithScopeRef}; @@ -76,13 +78,75 @@ impl<'db> TypeInferenceBuilder<'db, '_> { let mut decorator_types_and_nodes: Vec<(Type<'db>, &ast::Decorator)> = Vec::with_capacity(decorator_list.len()); + for decorator in decorator_list { + let decorator_ty = self.infer_decorator(decorator); + decorator_types_and_nodes.push((decorator_ty, decorator)); + } + + let body_scope = self + .index + .node_scope(NodeWithScopeRef::Class(class_node)) + .to_scope_id(db, self.file()); + + let maybe_known_class = KnownClass::try_from_file_and_name(db, self.file(), name); + + let known_module = || file_to_module(db, self.file()).and_then(|module| module.known(db)); + let in_typing_module = || { + matches!( + known_module(), + Some(KnownModule::Typing | KnownModule::TypingExtensions) + ) + }; + + let mut decorators_to_apply = Vec::with_capacity(decorator_types_and_nodes.len()); + let mut metadata_applies_to_original_class = true; let mut deprecated = None; let mut type_check_only = false; let mut dataclass_params = None; let mut dataclass_transformer_params = None; let mut total_ordering = false; - for decorator in decorator_list { - let decorator_ty = self.infer_decorator(decorator); + let infer_original_class_ty = |deprecated, + type_check_only, + dataclass_params, + dataclass_transformer_params, + total_ordering| { + match (maybe_known_class, &*name.id) { + (None, "NamedTuple") if in_typing_module() => { + Type::SpecialForm(SpecialFormType::NamedTuple) + } + (None, "Any") if in_typing_module() => Type::SpecialForm(SpecialFormType::Any), + (None, "InitVar") if known_module() == Some(KnownModule::Dataclasses) => { + Type::SpecialForm(SpecialFormType::TypeQualifier(TypeQualifier::InitVar)) + } + _ => Type::from(StaticClassLiteral::new( + db, + name.id.clone(), + body_scope, + maybe_known_class, + deprecated, + type_check_only, + dataclass_params, + dataclass_transformer_params, + total_ordering, + )), + } + }; + let decorator_call_ty = |decorator: &ast::Decorator| match &decorator.expression { + ast::Expr::Call(call) => Some(self.expression_type(&call.func)), + _ => None, + }; + + // In the first pass, collect metadata decorators that shape the original class object. + // Once an inner decorator replaces the public binding, outer decorators are ordinary + // runtime applications only: they cannot retroactively add metadata to the original class. + // For ordinary decorators that still apply to the original class, precompute the call so + // the second pass can reuse it if no inner decorator has changed the binding. + for &(decorator_ty, decorator) in decorator_types_and_nodes.iter().rev() { + if !metadata_applies_to_original_class { + decorators_to_apply.push((decorator_ty, decorator, None)); + continue; + } + if decorator_ty .as_function_literal() .is_some_and(|function| function.is_known(db, KnownFunction::Dataclass)) @@ -104,6 +168,16 @@ impl<'db> TypeInferenceBuilder<'db, '_> { continue; } + if decorator_ty.is_unknown() + && let ast::Expr::Call(call) = &decorator.expression + && self + .expression_type(&call.func) + .as_function_literal() + .is_some_and(|function| function.is_known(db, KnownFunction::Dataclass)) + { + continue; + } + if let Type::KnownInstance(KnownInstanceType::Deprecated(deprecated_inst)) = decorator_ty { @@ -137,7 +211,13 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // We do not yet detect or flag `@dataclass_transform` applied to more than one // overload, or an overload and the implementation both. Nevertheless, this is not // allowed. We do not try to treat the offenders intelligently -- just use the - // params of the last seen usage of `@dataclass_transform` + // params of the last seen usage of `@dataclass_transform`. + // + // In class-decorator position, dataclass-transform metadata shapes the + // original class object. We keep it metadata-only here because the call path + // uses synthetic dataclass-transform return types to model decorator factories; + // treating this as an ordinary replacement-returning class decorator would + // conflate those two cases. let transformer_params = f .iter_overloads_and_implementation(db) .rev() @@ -156,54 +236,102 @@ impl<'db> TypeInferenceBuilder<'db, '_> { continue; } - decorator_types_and_nodes.push((decorator_ty, decorator)); - } - - let body_scope = self - .index - .node_scope(NodeWithScopeRef::Class(class_node)) - .to_scope_id(db, self.file()); - - let maybe_known_class = KnownClass::try_from_file_and_name(db, self.file(), name); - - let known_module = || file_to_module(db, self.file()).and_then(|module| module.known(db)); - let in_typing_module = || { - matches!( - known_module(), - Some(KnownModule::Typing | KnownModule::TypingExtensions) - ) - }; - - let inferred_ty = match (maybe_known_class, &*name.id) { - (None, "NamedTuple") if in_typing_module() => { - Type::SpecialForm(SpecialFormType::NamedTuple) - } - (None, "Any") if in_typing_module() => Type::SpecialForm(SpecialFormType::Any), - (None, "InitVar") if known_module() == Some(KnownModule::Dataclasses) => { - Type::SpecialForm(SpecialFormType::TypeQualifier(TypeQualifier::InitVar)) - } - _ => Type::from(StaticClassLiteral::new( - db, - name.id.clone(), - body_scope, - maybe_known_class, + let original_class_ty = infer_original_class_ty( deprecated, type_check_only, dataclass_params, dataclass_transformer_params, total_ordering, - )), - }; - - // Validate decorator calls (but don't use return types yet). - for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { - if let Err(CallError(_, bindings)) = - decorator_ty.try_call(db, &CallArguments::positional([inferred_ty])) - { - bindings.report_diagnostics(&self.context, (*decorator_node).into()); + ); + let decorator_result = apply_class_decorator(db, decorator_ty, original_class_ty); + let decorated_ty = match &decorator_result { + Ok(return_ty) => *return_ty, + Err(error) => error.return_type(db), + }; + if is_unknown_decorator_result(db, decorated_ty) { + if !preserve_binding_for_unknown_result( + db, + decorator_ty, + decorator_call_ty(decorator), + ) { + metadata_applies_to_original_class = false; + } + } else if !type_retains_original_class(db, original_class_ty, decorated_ty) { + metadata_applies_to_original_class = false; } + + decorators_to_apply.push(( + decorator_ty, + decorator, + Some((original_class_ty, decorator_result)), + )); + } + + let mut inferred_ty = infer_original_class_ty( + deprecated, + type_check_only, + dataclass_params, + dataclass_transformer_params, + total_ordering, + ); + + let original_class_ty = inferred_ty; + let mut undecorated_ty = None; + + // In the second pass, apply class decorators from inner to outer and use their return types + // to update the public binding. `original_class_ty` remains the class object whose body and + // metadata were inferred above. + for (decorator_ty, decorator_node, precomputed_result) in decorators_to_apply { + let decorator_result = match precomputed_result { + // The metadata pass already called this decorator with the same input. If an inner + // decorator changed the binding, apply this decorator to the new public binding. + Some((precomputed_input_ty, decorator_result)) + if precomputed_input_ty == inferred_ty => + { + decorator_result + } + _ => apply_class_decorator(db, decorator_ty, inferred_ty), + }; + let decorated_ty = match decorator_result { + Ok(return_ty) => return_ty, + Err(CallError(_, bindings)) => { + bindings.report_diagnostics(&self.context, decorator_node.into()); + bindings.return_type(db) + } + }; + let decorated_ty = match decorated_ty { + Type::DataclassDecorator(_) | Type::DataclassTransformer(_) => Type::unknown(), + decorated_ty => decorated_ty, + }; + // If a class decorator application loses all precision, preserve the original class + // binding for decorators known to preserve unknown results. + let decorated_ty_is_unknown = is_unknown_decorator_result(db, decorated_ty); + let should_preserve_binding = decorated_ty_is_unknown + && preserve_binding_for_unknown_result( + db, + decorator_ty, + decorator_call_ty(decorator_node), + ); + inferred_ty = if should_preserve_binding { + inferred_ty + } else if class_decorator_preserves_class_binding(db, original_class_ty, decorated_ty) { + merge_class_preserving_decorator_result( + db, + original_class_ty, + inferred_ty, + decorated_ty, + ) + } else { + // Only record an undecorated type once a decorator actually replaces the public + // binding. If all decorators preserve the class, there is no alternate class type + // to expose. + undecorated_ty.get_or_insert(inferred_ty); + decorated_ty + }; } + self.undecorated_type = undecorated_ty; + self.add_declaration_with_binding( class_node.into(), definition, @@ -270,9 +398,8 @@ impl<'db> TypeInferenceBuilder<'db, '_> { if let Some(arguments) = class.arguments.as_deref() && let Some(extra_items_keyword) = arguments.find_keyword("extra_items") { - let class_type = infer_definition_types(self.db(), definition).binding_type(definition); - if let Type::ClassLiteral(class_literal) = class_type - && class_literal.is_typed_dict(self.db()) + if original_class_type(self.db(), definition) + .is_some_and(|class_literal| class_literal.is_typed_dict(self.db())) { self.infer_extra_items_kwarg(&extra_items_keyword.value); } else if self.in_stub() { @@ -287,3 +414,265 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } } } + +fn apply_class_decorator<'db>( + db: &'db dyn crate::Db, + decorator_ty: Type<'db>, + decorated_ty: Type<'db>, +) -> Result, CallError<'db>> { + let call_arguments = CallArguments::positional([decorated_ty]); + decorator_ty + .try_call(db, &call_arguments) + .map(|bindings| bindings.return_type(db)) +} + +/// Return true if a decorator result still binds the name to the original class. +/// +/// For example, an identity decorator keeps the public name bound to the same class: +/// ```python +/// def identity[T](cls: type[T]) -> type[T]: +/// return cls +/// +/// @identity +/// class C: ... +/// ``` +/// +/// This also accepts metaclass-shaped results such as `type[C]`, because those still describe the +/// original class object even if the decorator call produced a `SubclassOf` type internally. +fn class_decorator_preserves_class_binding<'db>( + db: &'db dyn crate::Db, + original_class: Type<'db>, + decorated_class: Type<'db>, +) -> bool { + let Type::ClassLiteral(original_literal) = original_class else { + return false; + }; + + match decorated_class { + Type::ClassLiteral(decorated_literal) => { + let decorated_definition = decorated_literal.definition(db); + decorated_literal == original_literal + || decorated_definition.is_some() + && decorated_definition == original_literal.definition(db) + } + Type::SubclassOf(subclass_of) => subclass_of + .subclass_of() + .into_class(db) + .is_some_and(|class| class == original_literal.default_specialization(db)), + Type::Divergent(_) => true, + Type::Union(union) => union + .elements(db) + .iter() + .all(|element| class_decorator_preserves_class_binding(db, original_class, *element)), + Type::TypeAlias(alias) => { + class_decorator_preserves_class_binding(db, original_class, alias.value_type(db)) + } + _ => SubclassOfType::try_from_type(db, original_class).is_some_and(|original_meta_type| { + decorated_class.is_equivalent_to(db, original_meta_type) + }), + } +} + +/// Return true if a type still contains the original class object, even if it also carries extra +/// intersection members. +fn type_retains_original_class<'db>( + db: &'db dyn crate::Db, + original_class: Type<'db>, + decorated_class: Type<'db>, +) -> bool { + match decorated_class { + Type::Intersection(intersection) => intersection + .positive(db) + .iter() + .any(|element| type_retains_original_class(db, original_class, *element)), + Type::Union(union) => union + .elements(db) + .iter() + .all(|element| type_retains_original_class(db, original_class, *element)), + Type::TypeAlias(alias) => { + type_retains_original_class(db, original_class, alias.value_type(db)) + } + _ => class_decorator_preserves_class_binding(db, original_class, decorated_class), + } +} + +/// Return true if an unknown class-decorator result should leave the current class type in place. +/// +/// This handles both direct decorators and decorator factories: +/// ```python +/// def decorator(cls): +/// return cls +/// +/// def decorator_factory(): +/// return decorator +/// +/// @decorator_factory() +/// class C: ... +/// ``` +/// +/// The factory case needs the type of the call target, because the type of +/// `@decorator_factory()` is the returned decorator, while the expression type of +/// `decorator_factory` carries the static information that tells us whether an unknown result can +/// be preserved. +fn preserve_binding_for_unknown_result<'db>( + db: &'db dyn crate::Db, + decorator_ty: Type<'db>, + decorator_call_ty: Option>, +) -> bool { + ClassDecoratorUnknownResultPolicy::from_decorator(db, decorator_ty) + == ClassDecoratorUnknownResultPolicy::PreserveBinding + || decorator_call_ty.is_some_and(|ty| { + ClassDecoratorUnknownResultPolicy::from_decorator(db, ty) + == ClassDecoratorUnknownResultPolicy::PreserveBinding + }) +} + +/// Return true if applying a class decorator produced no useful replacement type. +/// +/// Besides plain `Unknown`, class decorators can produce unknown class-object types such as +/// `type[Any]`. Those are represented as a `SubclassOf` dynamic type, but they should trigger the +/// same preservation fallback as an unknown result: +/// ```python +/// from typing import Any +/// +/// def decorator(cls) -> type[Any]: ... +/// +/// @decorator +/// class C: ... +/// ``` +fn is_unknown_decorator_result<'db>(db: &'db dyn crate::Db, ty: Type<'db>) -> bool { + if ty.is_unknown() { + return true; + } + + let Type::SubclassOf(subclass_of) = ty.resolve_type_alias(db) else { + return false; + }; + + subclass_of + .subclass_of() + .into_dynamic() + .is_some_and(|dynamic| Type::Dynamic(dynamic).is_unknown()) +} + +/// Policy for class decorators whose application result is unknown. +/// +/// This is only consulted after applying the decorator produced no useful replacement type. If the +/// decorator itself statically suggests an unannotated identity-preserving shape, we keep the +/// current class binding; if it explicitly promises a replacement type, or if the decorator is +/// unknown, we let the unknown result replace the binding. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum ClassDecoratorUnknownResultPolicy { + /// Preserve the current class binding when the decorator result is unknown. + PreserveBinding, + /// Use the unknown decorator result as the public binding. + ReplaceBinding, +} + +impl ClassDecoratorUnknownResultPolicy { + /// Infer the unknown-result policy from the decorator's own type. + /// + /// Unannotated function and method decorators are treated as class-preserving when their + /// application result is unknown. Explicit return annotations are trusted as replacement + /// intent. + fn from_decorator<'db>(db: &'db dyn crate::Db, decorator_ty: Type<'db>) -> Self { + if decorator_ty.is_unknown() { + return Self::ReplaceBinding; + } + + Self::known_from_decorator(db, decorator_ty).unwrap_or(Self::ReplaceBinding) + } + + /// Return the known preservation policy for a class decorator, if one can be read statically. + /// + /// For unknown decorator results, unannotated functions are treated as likely + /// identity-preserving: + /// ```python + /// def decorator(cls): + /// return cls + /// ``` + /// + /// Explicit return annotations are trusted instead: + /// ```python + /// def decorator(cls) -> object: + /// return object() + /// ``` + /// + /// Callable instances and protocols delegate the decision to their `__call__` member, because + /// the decorator value itself is not the function that receives the class. + fn known_from_decorator<'db>(db: &'db dyn crate::Db, decorator_ty: Type<'db>) -> Option { + match decorator_ty { + Type::FunctionLiteral(function) => { + Some(if function.has_explicit_return_annotation(db) { + Self::ReplaceBinding + } else { + Self::PreserveBinding + }) + } + Type::BoundMethod(method) => { + Some(if method.function(db).has_explicit_return_annotation(db) { + Self::ReplaceBinding + } else { + Self::PreserveBinding + }) + } + Type::NominalInstance(_) | Type::ProtocolInstance(_) => { + let call_symbol = decorator_ty + .member_lookup_with_policy( + db, + Name::new_static("__call__"), + MemberLookupPolicy::NO_INSTANCE_FALLBACK, + ) + .place; + + if let Place::Defined(place) = call_symbol + && place.is_definitely_defined() + { + Some(Self::known_from_decorator(db, place.ty).unwrap_or(Self::ReplaceBinding)) + } else { + Some(Self::ReplaceBinding) + } + } + Type::Union(union) => Some( + if union.elements(db).iter().all(|element| { + Self::known_from_decorator(db, *element) == Some(Self::PreserveBinding) + }) { + Self::PreserveBinding + } else { + Self::ReplaceBinding + }, + ), + Type::TypeAlias(alias) => Some( + Self::known_from_decorator(db, alias.value_type(db)) + .unwrap_or(Self::ReplaceBinding), + ), + // TODO: We preserve the class binding for every `Callable` decorator today. Figure out + // which of these cases should instead let an unknown decorator result replace it. + Type::Callable(_) => Some(Self::PreserveBinding), + _ => None, + } + } +} + +/// Merge a class-preserving decorator result into the public binding. +/// +/// If earlier decorators already exposed extra members through an intersection, keep those +/// members instead of collapsing back to the undecorated class when a later decorator simply +/// returns the original class object again. +fn merge_class_preserving_decorator_result<'db>( + db: &'db dyn crate::Db, + original_class: Type<'db>, + current_binding: Type<'db>, + decorated_binding: Type<'db>, +) -> Type<'db> { + if current_binding == original_class + || type_retains_original_class(db, original_class, current_binding) + { + current_binding + } else { + decorated_binding + .as_class_literal() + .map(Type::ClassLiteral) + .unwrap_or(original_class) + } +} diff --git a/crates/ty_python_semantic/src/types/infer/builder/function.rs b/crates/ty_python_semantic/src/types/infer/builder/function.rs index 3ea5622c22475..6947f41eb015d 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/function.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/function.rs @@ -23,6 +23,7 @@ use crate::{ validate_paramspec_components, }, function_known_decorators, infer_statement_types, nearest_enclosing_function, + original_class_type, }, infer_definition_types, infer_scope_types, signatures::ReturnCallableTypeVarScope, @@ -414,6 +415,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { function_decorators, deprecated, dataclass_transformer_params, + function.returns.is_some(), ); let function_literal = FunctionLiteral { last_definition: overload_literal, @@ -990,10 +992,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } let class_definition = self.index.expect_single_definition(class); - let class_literal = infer_definition_types(db, class_definition) - .declaration_type(class_definition) - .inner_type() - .as_class_literal()?; + let class_literal = original_class_type(db, class_definition)?; let typing_self = typing_self(db, self.scope(), Some(method_definition), class_literal); if is_classmethod || function_name == "__new__" { diff --git a/crates/ty_python_semantic/src/types/infer/builder/post_inference/overloaded_function.rs b/crates/ty_python_semantic/src/types/infer/builder/post_inference/overloaded_function.rs index f48c2f37ee672..e100b4fbb186f 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/post_inference/overloaded_function.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/post_inference/overloaded_function.rs @@ -9,10 +9,11 @@ use crate::{ Db, place::{DefinedPlace, Definedness, Place, place_from_bindings}, types::{ - KnownClass, Type, binding_type, + KnownClass, Type, context::InferContext, diagnostic::INVALID_OVERLOAD, function::{FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral}, + infer::original_class_type, signatures::{ParameterConsistency, ReturnTypeConsistency}, }, }; @@ -134,13 +135,12 @@ pub(crate) fn check_overloaded_function<'db>( ) }) { implementation_required = false; - } else if let NodeWithScopeKind::Class(class_node_ref) = scope { - let class = binding_type( + } else if let NodeWithScopeKind::Class(class_node_ref) = scope + && let Some(class) = original_class_type( db, index.expect_single_definition(class_node_ref.node(context.module())), ) - .expect_class_literal(); - + { if class.is_protocol(db) || (Type::ClassLiteral(class) .is_subtype_of(db, KnownClass::ABCMeta.to_instance(db))