Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,7 @@ precedence:

```py
from enum import Enum, EnumMeta, IntEnum, auto
from ty_extensions import enum_members
from typing import Literal

class WithNewAndGenerateNextValue(Enum):
Expand All @@ -1190,6 +1191,27 @@ reveal_type(WithNewAndGenerateNextValue.A.value) # revealed: Any
def _instance_new(a: WithNewAndGenerateNextValue):
reveal_type(a.value) # revealed: Any

class WithNewAndLiteralGenerateNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal["x"]:
return "x"

def __new__(cls, value: str) -> "WithNewAndLiteralGenerateNextValue":
obj = object.__new__(cls)
obj._value_ = object()
return obj

A = auto()
B = auto()

# `__new__` can rewrite duplicate generated values to distinct values, so `B` is not an alias of `A`.
# revealed: tuple[Literal["A"], Literal["B"]]
reveal_type(enum_members(WithNewAndLiteralGenerateNextValue))
reveal_type(WithNewAndLiteralGenerateNextValue.A) # revealed: Literal[WithNewAndLiteralGenerateNextValue.A]
reveal_type(WithNewAndLiteralGenerateNextValue.B) # revealed: Literal[WithNewAndLiteralGenerateNextValue.B]
reveal_type(WithNewAndLiteralGenerateNextValue.A.value) # revealed: Any
reveal_type(WithNewAndLiteralGenerateNextValue.B.value) # revealed: Any

class WithInitAndGenerateNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
Expand All @@ -1205,6 +1227,20 @@ reveal_type(WithInitAndGenerateNextValue.A.value) # revealed: Any
def _instance_init(a: WithInitAndGenerateNextValue):
reveal_type(a.value) # revealed: Any

class WithInitAndLiteralGenerateNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal["x"]:
return "x"

def __init__(self, value: str) -> None: ...

A = auto()
B = auto()

# `__init__` runs after duplicate generated values are resolved to aliases.
# revealed: tuple[Literal["A"]]
reveal_type(enum_members(WithInitAndLiteralGenerateNextValue))

class ChoicesType(EnumMeta):
def __new__(metacls, classname, bases, classdict, **kwds): ...

Expand All @@ -1222,6 +1258,18 @@ reveal_type(MyModelChoices.A.value) # revealed: Any

def _instance_metaclass(a: MyModelChoices):
reveal_type(a.value) # revealed: Any

class IntEnumDuplicateAutoAliases(IntEnum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal[42]:
return 42

A = auto()
B = auto()

# The stdlib `IntEnum.__new__` preserves duplicate generated values as aliases.
# revealed: tuple[Literal["A"]]
reveal_type(enum_members(IntEnumDuplicateAutoAliases))
```

For non-`auto()` members in a mixed enum, `_generate_next_value_` does not apply at all, and the
Expand Down
45 changes: 36 additions & 9 deletions crates/ty_python_semantic/src/types/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,26 @@ fn try_register_alias<'db>(
/// with a custom `_generate_next_value_`, aliasing is based on the generated
/// value instead of the pre-generator placeholder used while collecting
/// members.
///
/// Returns `None` for `auto()` members when `__new__` or a custom metaclass can
/// rewrite `_value_` before alias registration, because neither the generated
/// value nor the placeholder is reliable alias evidence in that case.
fn alias_detection_value<'db>(
db: &'db dyn Db,
value_ty: Type<'db>,
is_auto: bool,
generate_next_value_function: Option<FunctionType<'db>>,
) -> Type<'db> {
if is_auto && let Some(func_ty) = generate_next_value_function {
func_ty.signature(db).overload_return_type_or_unknown(db)
user_defined_new_function: Option<FunctionType<'db>>,
custom_enum_metaclass_new: bool,
) -> Option<Type<'db>> {
if !is_auto {
Some(value_ty)
} else if user_defined_new_function.is_some() || custom_enum_metaclass_new {
None
} else if let Some(func_ty) = generate_next_value_function {
Some(func_ty.signature(db).overload_return_type_or_unknown(db))
} else {
value_ty
Some(value_ty)
}
}

Expand Down Expand Up @@ -301,6 +311,13 @@ pub(crate) fn enum_metadata<'db>(
let mut prev_value_was_non_literal_int = false;
let mut prev_bool_literal = None;
let ignored_names = enum_ignored_names(db, scope_id);

// Look up custom construction hooks, falling back to parent enum classes.
let init_function = custom_init(db, scope_id).or_else(|| inherited_init(db, class));
let user_defined_new_function =
custom_new(db, scope_id).or_else(|| inherited_user_defined_new(db, class));
let new_function = user_defined_new_function.or_else(|| inherited_new(db, class));
let custom_enum_metaclass_new = custom_enum_metaclass_new(db, class);
let generate_next_value_function = custom_generate_next_value(db, scope_id)
.or_else(|| inherited_generate_next_value(db, class));

Expand Down Expand Up @@ -444,8 +461,12 @@ pub(crate) fn enum_metadata<'db>(
value_ty,
auto_members.contains(name),
generate_next_value_function,
user_defined_new_function,
custom_enum_metaclass_new,
);
if try_register_alias(alias_value_ty, name, &mut enum_values, &mut aliases) {
if let Some(alias_value_ty) = alias_value_ty
&& try_register_alias(alias_value_ty, name, &mut enum_values, &mut aliases)
{
return None;
}

Expand Down Expand Up @@ -488,10 +509,6 @@ pub(crate) fn enum_metadata<'db>(
return None;
}

// Look up custom construction hooks, falling back to parent enum classes.
let init_function = custom_init(db, scope_id).or_else(|| inherited_init(db, class));
let new_function = custom_new(db, scope_id).or_else(|| inherited_new(db, class));
let custom_enum_metaclass_new = custom_enum_metaclass_new(db, class);
let custom_value_annotation = custom_value_annotation(db, scope_id);
let value_annotation = custom_value_annotation.or_else(|| {
if custom_enum_metaclass_new {
Expand Down Expand Up @@ -599,6 +616,16 @@ fn inherited_new<'db>(
iter_parent_enum_classes(db, class).find_map(|base| custom_new(db, base.body_scope(db)))
}

/// Looks up an inherited `__new__` from user-defined parent enum classes in the MRO.
fn inherited_user_defined_new<'db>(
db: &'db dyn Db,
class: StaticClassLiteral<'db>,
) -> Option<FunctionType<'db>> {
iter_parent_enum_classes(db, class)
.filter(|base| base.known(db).is_none())
.find_map(|base| custom_new(db, base.body_scope(db)))
}

/// Looks up an inherited `_generate_next_value_` from parent enum classes in the MRO.
fn inherited_generate_next_value<'db>(
db: &'db dyn Db,
Expand Down
Loading