Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,182 @@ class WithInit(Enum):
reveal_type(WithInit.MERCURY.value) # revealed: Any
```

When `_generate_next_value_` is overridden, its return type is used for `auto()` value types, unless
overridden by an explicit `_value_` annotation or a custom construction hook:

```py
from enum import StrEnum, IntEnum, auto
from typing import Literal

class CustomNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values): ...

A = auto()
B = auto()

reveal_type(CustomNextValue.A.value) # revealed: Unknown

class CustomNextValueNonAuto(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal[3]:
return 3

A = 1
B = 2

reveal_type(CustomNextValueNonAuto.A.value) # revealed: Literal[1]

class CustomNextValueStr(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
return ""

A = auto()
B = auto()

# Should not be `Literal['A']`
# revealed: str
reveal_type(CustomNextValueStr.A.value)

class CustomNextValuePrecedence(Enum):
_value_: str

@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal["a"]:
return "a"

A = auto()
B = auto()

# `_value_` annotation takes precedence over `_generate_next_value_`'s return type
# revealed: str
reveal_type(CustomNextValuePrecedence.A.value)

def foo(a: CustomNextValuePrecedence):
# revealed: str
reveal_type(a.value)

class CustomNextValueInt(IntEnum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal[42]:
return 42

A = auto()
B = auto()

# `IntEnum` inherits `_value_: int`, which takes precedence over `_generate_next_value_`
# revealed: int
reveal_type(CustomNextValueInt.A.value)
```

When an enum defines both `_generate_next_value_` and a construction hook (`__new__`, `__init__`, or
a custom enum metaclass `__new__`), the hook can rewrite `_value_` to a different type than the
value returned by `_generate_next_value_`. The hook-based `Any` fallback should therefore take
precedence:

```py
from enum import Enum, EnumMeta, IntEnum, auto
from typing import Literal

class WithNewAndGenerateNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
return ""

def __new__(cls, value: str) -> "WithNewAndGenerateNextValue":
obj = object.__new__(cls)
obj._value_ = len(value)
return obj

A = auto()
B = auto()

# `__new__` rewrites `_value_` to an `int`, so we can't trust `_generate_next_value_`'s return type
reveal_type(WithNewAndGenerateNextValue.A.value) # revealed: Any

def _instance_new(a: WithNewAndGenerateNextValue):
reveal_type(a.value) # revealed: Any

class WithInitAndGenerateNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
return ""

def __init__(self, value: str) -> None: ...

A = auto()
B = auto()

reveal_type(WithInitAndGenerateNextValue.A.value) # revealed: Any

def _instance_init(a: WithInitAndGenerateNextValue):
reveal_type(a.value) # revealed: Any

class ChoicesType(EnumMeta):
def __new__(metacls, classname, bases, classdict, **kwds): ...

class IntegerChoices(IntEnum, metaclass=ChoicesType):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> Literal[42]:
return 42

class MyModelChoices(IntegerChoices):
A = auto()
B = auto()

# The metaclass `__new__` can rewrite member values before they reach `_value_`
reveal_type(MyModelChoices.A.value) # revealed: Any

def _instance_metaclass(a: MyModelChoices):
reveal_type(a.value) # revealed: Any
```

For non-`auto()` members in a mixed enum, `_generate_next_value_` does not apply at all, and the
inferred value type should be used (subject to the same hook-based `Any` fallback):

```py
from enum import Enum, auto
from ty_extensions import enum_members
from typing import Literal

class MixedAutoAndLiteral(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
return ""

A = auto()
B = 99

reveal_type(MixedAutoAndLiteral.A.value) # revealed: str
reveal_type(MixedAutoAndLiteral.B.value) # revealed: Literal[99]

def _mixed_instance(x: MixedAutoAndLiteral):
# Union of all member value types, not just `_generate_next_value_`'s return type
reveal_type(x.value) # revealed: str | Literal[99]

class InheritedCustomNextValue(Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values) -> str:
return ""

class InheritedCustomNextValueChild(InheritedCustomNextValue):
A = auto()
B = 1
C = 1

# `A` uses the inherited `_generate_next_value_`, so `B` is not an alias of `A`.
# revealed: tuple[Literal["A"], Literal["B"]]
reveal_type(enum_members(InheritedCustomNextValueChild))
reveal_type(InheritedCustomNextValueChild.A.value) # revealed: str
reveal_type(InheritedCustomNextValueChild.B) # revealed: Literal[InheritedCustomNextValueChild.B]
reveal_type(InheritedCustomNextValueChild.B.value) # revealed: Literal[1]
reveal_type(InheritedCustomNextValueChild.C) # revealed: Literal[InheritedCustomNextValueChild.B]

def _inherited_mixed_instance(x: InheritedCustomNextValueChild):
reveal_type(x.value) # revealed: str | Literal[1]
```

### `member` and `nonmember`

```toml
Expand Down
6 changes: 4 additions & 2 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3443,8 +3443,10 @@ impl<'db> Type<'db> {
.and_then(|metadata| match name_str {
"name" if is_enum_subclass => metadata.name_type(db, enum_literal.name(db)),
"_name_" => metadata.name_type(db, enum_literal.name(db)),
"value" if is_enum_subclass => metadata.value_type(enum_literal.name(db)),
"_value_" => metadata.value_type(enum_literal.name(db)),
"value" if is_enum_subclass => {
metadata.value_type(db, enum_literal.name(db))
}
"_value_" => metadata.value_type(db, enum_literal.name(db)),
_ => None,
})
.map_or_else(|| Place::Undefined, Place::bound)
Expand Down
85 changes: 77 additions & 8 deletions crates/ty_python_semantic/src/types/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ pub(crate) struct EnumMetadata<'db> {
/// independently.
pub(crate) new_function: Option<FunctionType<'db>>,

/// The custom `_generate_next_value_` function, if defined on this enum.
///
/// When present, defines the value returned by calls to `auto()`
pub(crate) generate_next_value_function: Option<FunctionType<'db>>,

/// Whether the enum metaclass may transform member values before they are
/// passed to enum construction hooks.
pub(crate) custom_enum_metaclass_new: bool,
Expand All @@ -55,16 +60,17 @@ impl<'db> EnumMetadata<'db> {
value_annotation: None,
init_function: None,
new_function: None,
generate_next_value_function: None,
custom_enum_metaclass_new: false,
}
}

/// Returns the type of `.value`/`._value_` for a given enum member.
///
/// Priority: explicit `_value_` annotation, then custom construction hooks
/// or metaclass value transformation → `Any`, then the inferred member
/// value type.
pub(crate) fn value_type(&self, member_name: &Name) -> Option<Type<'db>> {
/// or metaclass value transformation → `Any`, then `_generate_next_value_`
/// return type for `auto()` members, then the inferred member value type.
pub(crate) fn value_type(&self, db: &'db dyn Db, member_name: &Name) -> Option<Type<'db>> {
if !self.members.contains_key(member_name) {
return None;
}
Expand All @@ -75,6 +81,10 @@ impl<'db> EnumMetadata<'db> {
|| self.custom_enum_metaclass_new
{
Some(Type::Dynamic(DynamicType::Any))
} else if let Some(func_ty) = self.generate_next_value_function
&& self.auto_members.contains(member_name)
{
Some(func_ty.signature(db).overload_return_type_or_unknown(db))
} else {
self.members.get(member_name).copied()
}
Expand All @@ -95,7 +105,8 @@ impl<'db> EnumMetadata<'db> {
/// If there is an explicit `_value_` annotation, returns that.
/// If there is a custom `__init__` or `__new__` or a custom enum
/// metaclass may transform member values, returns `Any`.
/// Otherwise, returns the union of all member value types.
/// Otherwise, returns the union of each member's `value_type`, which
/// applies `_generate_next_value_`'s return type to `auto()` members.
pub(crate) fn instance_value_type(&self, db: &'db dyn Db) -> Option<Type<'db>> {
if self.members.is_empty() {
return None;
Expand All @@ -110,8 +121,8 @@ impl<'db> EnumMetadata<'db> {
} else {
let union = self
.members
.values()
.copied()
.keys()
.filter_map(|name| self.value_type(db, name))
.fold(UnionBuilder::new(db), UnionBuilder::add)
.build();
Some(union)
Expand Down Expand Up @@ -201,6 +212,25 @@ fn try_register_alias<'db>(
false
}

/// Returns the value to use when checking whether an enum member is an alias.
///
/// For ordinary members, this is the inferred value type. For `auto()` members
/// with a custom `_generate_next_value_`, aliasing is based on the generated
/// value instead of the pre-generator placeholder used while collecting
/// members.
fn alias_detection_value<'db>(
db: &'db dyn Db,
value_ty: Type<'db>,
is_auto: bool,
generate_next_value_function: Option<FunctionType<'db>>,
) -> Type<'db> {
if is_auto && let Some(func_ty) = generate_next_value_function {
func_ty.signature(db).overload_return_type_or_unknown(db)
} else {
value_ty
}
}

/// List all members of an enum.
#[salsa::tracked(returns(as_ref), cycle_initial=|_, _, _| Some(EnumMetadata::empty()), heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn enum_metadata<'db>(
Expand Down Expand Up @@ -243,6 +273,7 @@ pub(crate) fn enum_metadata<'db>(
value_annotation: None,
init_function: None,
new_function: None,
generate_next_value_function: None,
custom_enum_metaclass_new: false,
});
}
Expand Down Expand Up @@ -270,6 +301,8 @@ pub(crate) fn enum_metadata<'db>(
let mut prev_value_was_non_literal_int = false;
let mut prev_bool_literal = None;
let ignored_names = enum_ignored_names(db, scope_id);
let generate_next_value_function = custom_generate_next_value(db, scope_id)
.or_else(|| inherited_generate_next_value(db, class));

let mut aliases = FxHashMap::default();

Expand Down Expand Up @@ -406,7 +439,13 @@ pub(crate) fn enum_metadata<'db>(
}
};

if try_register_alias(value_ty, name, &mut enum_values, &mut aliases) {
let alias_value_ty = alias_detection_value(
db,
value_ty,
auto_members.contains(name),
generate_next_value_function,
);
if try_register_alias(alias_value_ty, name, &mut enum_values, &mut aliases) {
return None;
}

Expand All @@ -428,7 +467,7 @@ pub(crate) fn enum_metadata<'db>(
return None;
}

//Ttrack whether this member's value is a non-literal `int`, so a
// Track whether this member's value is a non-literal `int`, so a
// following `auto()` knows to widen its result to `int`.
prev_value_was_non_literal_int = value_ty.as_int_like_literal().is_none()
&& value_ty.is_assignable_to(db, KnownClass::Int.to_instance(db));
Expand Down Expand Up @@ -469,6 +508,7 @@ pub(crate) fn enum_metadata<'db>(
value_annotation,
init_function,
new_function,
generate_next_value_function,
custom_enum_metaclass_new,
})
}
Expand Down Expand Up @@ -559,6 +599,15 @@ fn inherited_new<'db>(
iter_parent_enum_classes(db, class).find_map(|base| custom_new(db, base.body_scope(db)))
}

/// Looks up an inherited `_generate_next_value_` from parent enum classes in the MRO.
fn inherited_generate_next_value<'db>(
db: &'db dyn Db,
class: StaticClassLiteral<'db>,
) -> Option<FunctionType<'db>> {
iter_parent_enum_classes(db, class)
.find_map(|base| custom_generate_next_value(db, base.body_scope(db)))
}

/// Returns the custom `__init__` function type if one is defined on the enum.
fn custom_init<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Option<FunctionType<'db>> {
let init_symbol_id = place_table(db, scope).symbol_id("__init__")?;
Expand Down Expand Up @@ -591,6 +640,26 @@ fn custom_new<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Option<FunctionType<
}
}

/// Returns the custom `_generate_next_value_` function type if one is defined on the enum.
fn custom_generate_next_value<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
) -> Option<FunctionType<'db>> {
let symbol_id_opt = place_table(db, scope).symbol_id("_generate_next_value_");
let new_symbol_id = symbol_id_opt?;
let new_type = place_from_declarations(
db,
use_def_map(db, scope).end_of_scope_symbol_declarations(new_symbol_id),
)
.ignore_conflicting_declarations()
.ignore_possibly_undefined();
let new_type = new_type?;
match new_type {
Type::FunctionLiteral(f) => Some(f),
_ => None,
}
}

pub(crate) fn enum_member_literals<'a, 'db: 'a>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
Expand Down
Loading