updating dataclasses.replace and register_jax_tree #10
📝 Walkthrough
This PR extends functional update capabilities for frozen dataclasses by introducing an internal …
Sequence Diagram(s)
sequenceDiagram
participant User as User Code
participant DC as DataClass.aset/<br/>updated_copy
participant Helper as _dataclass_replace
participant Ctor as Constructor
participant Attr as object.__setattr__
User->>DC: Call with changes
DC->>Helper: Pass obj and changes dict
Helper->>Helper: Validate field names
Helper->>Helper: Overlay changes onto<br/>current field values
Helper->>Ctor: Construct with init=True<br/>fields only
Ctor->>Ctor: Create new instance
Ctor-->>Helper: Return new instance
Helper->>Attr: Apply init=False<br/>fields to instance
Attr->>Attr: Set private fields<br/>post-construction
Attr-->>Helper: Complete
Helper-->>DC: Return updated copy
DC-->>User: Return result
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ✅ 3 passed
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/drinx/base.py`:
- Around line 45-68: _dataclass_replace currently allows changing private
(underscore-prefixed) fields and also drops existing non-init fields that aren't
in changes; fix it by forbidding updates to private fields and by restoring all
non-init fields from the original object. Specifically: in _dataclass_replace,
compute private_names = {f.name for f in all_fields if f.name.startswith("_")}
and if private_names & set(changes): raise TypeError (reject updates to private
fields), keep known_names logic but validate against public names only; build
current_values as you do, then after creating new_obj, stamp back every non-init
field from current_values (use all non-init f.name), and finally overwrite those
non-init public fields with values from changes if present (i.e., ensure
non-init values are preserved unless an allowed public non-init field was
explicitly changed). Use the symbols _dataclass_replace, known_names,
current_values, non_init_overrides (or replace it with the new logic) to locate
and update the code.
In `@src/drinx/transform.py`:
- Around line 29-63: The flatten/unflatten currently lumps all init=False fields
into aux; change this so private fields are split by jax_static: compute
private_dynamic = [f.name for f in all_fields if not f.init and not
f.metadata.get("jax_static")] and private_static = [f.name for f in all_fields
if not f.init and f.metadata.get("jax_static")]; then emit private_dynamic
alongside dynamic_fields in keyed_leaves so non-static private fields become
leaves in flatten_with_keys, and only include static_init_fields and
private_static in aux (e.g., aux = (tuple(getattr(obj, f) for f in
static_init_fields), tuple(getattr(obj, f) for f in private_static))); finally
update unflatten to accept aux = (static_init_values, private_static_values),
build init_kwargs from static_init_fields and the init=True dynamic_fields only
(init=False fields are not constructor arguments), construct obj =
cls_(**init_kwargs), and then use object.__setattr__ to restore both the
private_dynamic values (from leaves) and the private_static values (from aux).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a4c6c7b9-ed55-439e-9c70-96dc1171c4ba
📒 Files selected for processing (3)
src/drinx/base.py
src/drinx/transform.py
src/drinx/visualize.py
    all_fields = dataclasses.fields(obj)
    known_names = {f.name for f in all_fields}
    unknown = set(changes) - known_names
    if unknown:
        raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")

    # Collect current values for every field, then overlay changes
    current_values: dict[str, Any] = {
        f.name: object.__getattribute__(obj, f.name) for f in all_fields
    }
    current_values.update(changes)

    init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
    non_init_overrides = {
        f.name: current_values[f.name]
        for f in all_fields
        if not f.init and f.name in changes
    }

    new_obj = type(obj)(**init_kwargs)

    # Stamp non-init fields that were explicitly changed
    for name, value in non_init_overrides.items():
        object.__setattr__(new_obj, name, value)
_dataclass_replace currently both exposes and drops private fields.
known_names lets init=False names through, so updated_copy(_cache=...) / aset("_cache", ...) now bypass the existing guard that tests/test_base.py:1302-1311 and tests/test_base.py:1400-1410 assert. At the same time, only f.name in changes gets stamped back, so updated_copy(x=...) still resets any existing private state instead of preserving it.
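Both symptoms can be reproduced in isolation. The sketch below is a minimal stand-in, not the drinx code: `Point` and `replace_changed_only` are hypothetical names, and the helper only mimics the reviewed logic (stamp back an init=False field only when it appears in `changes`).

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Point:
    # Hypothetical example class, not taken from drinx.
    x: int
    _cache: Any = dataclasses.field(default=None, init=False)


def replace_changed_only(obj, **changes):
    """Mimic the reviewed logic: only init=False fields named in `changes` are stamped back."""
    all_fields = dataclasses.fields(obj)
    current = {f.name: getattr(obj, f.name) for f in all_fields}
    current.update(changes)
    new_obj = type(obj)(**{f.name: current[f.name] for f in all_fields if f.init})
    for f in all_fields:
        if not f.init and f.name in changes:
            object.__setattr__(new_obj, f.name, current[f.name])
    return new_obj


p = Point(x=1)
object.__setattr__(p, "_cache", "expensive result")

q = replace_changed_only(p, x=2)            # existing private state is reset to the default
r = replace_changed_only(p, _cache="oops")  # private field is overwritten without complaint
```

Here `q._cache` falls back to the field default and `r._cache` takes the caller's value, which matches the two behaviours described above.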
Suggested fix
  all_fields = dataclasses.fields(obj)
- known_names = {f.name for f in all_fields}
- unknown = set(changes) - known_names
+ fields_by_name = {f.name: f for f in all_fields}
+ unknown = set(changes) - fields_by_name.keys()
  if unknown:
      raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")
+
+ non_init_updates = {name for name in changes if not fields_by_name[name].init}
+ if non_init_updates:
+     raise TypeError(
+         f"_dataclass_replace() cannot update init=False fields: {non_init_updates!r}"
+     )
  # Collect current values for every field, then overlay changes
  current_values: dict[str, Any] = {
      f.name: object.__getattribute__(obj, f.name) for f in all_fields
  }
  current_values.update(changes)
  init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
- non_init_overrides = {
-     f.name: current_values[f.name]
-     for f in all_fields
-     if not f.init and f.name in changes
- }
+ non_init_values = {f.name: current_values[f.name] for f in all_fields if not f.init}
  new_obj = type(obj)(**init_kwargs)
- # Stamp non-init fields that were explicitly changed
- for name, value in non_init_overrides.items():
+ # Preserve current non-init fields after construction
+ for name, value in non_init_values.items():
      object.__setattr__(new_obj, name, value)
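The suggested behaviour can be exercised on its own. The following is a hedged, self-contained sketch of that logic, assuming a hypothetical helper name `replace_preserving_non_init` and a hypothetical `Point` class; it is not the code that was merged.

```python
import dataclasses
from typing import Any


def replace_preserving_non_init(obj: Any, **changes: Any) -> Any:
    """Sketch of the suggested behaviour; the function name is hypothetical."""
    fields_by_name = {f.name: f for f in dataclasses.fields(obj)}

    unknown = set(changes) - set(fields_by_name)
    if unknown:
        raise TypeError(f"unexpected field names: {sorted(unknown)!r}")

    # Reject attempts to change init=False fields through the public update path.
    non_init_updates = {n for n in changes if not fields_by_name[n].init}
    if non_init_updates:
        raise TypeError(f"cannot update init=False fields: {sorted(non_init_updates)!r}")

    current = {name: getattr(obj, name) for name in fields_by_name}
    current.update(changes)

    new_obj = type(obj)(**{n: current[n] for n, f in fields_by_name.items() if f.init})
    # Carry every init=False value over from the original object.
    for name, f in fields_by_name.items():
        if not f.init:
            object.__setattr__(new_obj, name, current[name])
    return new_obj


@dataclasses.dataclass(frozen=True)
class Point:  # hypothetical example class
    x: int
    _cache: Any = dataclasses.field(default=None, init=False)


p = Point(x=1)
object.__setattr__(p, "_cache", "expensive result")
q = replace_preserving_non_init(p, x=2)
```

One trade-off worth checking against the real code: copying init=False values verbatim also overwrites anything `__post_init__` derived from the new init values, so a cache that depends on a changed field would be carried over stale.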
  all_fields = fields(cls_)
  # Traced leaves: init=True and not jax_static
  dynamic_fields = [
      f.name for f in all_fields if not f.metadata.get("jax_static") and f.init
  ]
  # Aux bucket 1: explicitly static (jax_static=True) and init=True
  static_init_fields = [
      f.name for f in all_fields if f.metadata.get("jax_static") and f.init
  ]
  # Aux bucket 2: private (init=False), regardless of jax_static — must ride in aux
  private_fields_names = [f.name for f in all_fields if not f.init]

  def flatten_with_keys(obj):
      keyed_leaves = [
          (jax.tree_util.GetAttrKey(f), getattr(obj, f)) for f in dynamic_fields
      ]
-     aux = tuple(getattr(obj, f) for f in static_fields)
+     aux = (
+         tuple(getattr(obj, f) for f in static_init_fields),
+         tuple(getattr(obj, f) for f in private_fields_names),
+     )
      return keyed_leaves, aux

  def unflatten(aux, leaves):
-     kwargs = {**dict(zip(static_fields, aux)), **dict(zip(dynamic_fields, leaves))}
-     return cls_(**kwargs)
+     static_init_values, private_values = aux
+     init_kwargs = {
+         **dict(zip(static_init_fields, static_init_values)),
+         **dict(zip(dynamic_fields, leaves)),
+     }
+     obj = cls_(**init_kwargs)
+     # Stamp private fields back after construction. object.__setattr__ bypasses
+     # DataClass.__setattr__'s frozen guard, which is safe here because we are
+     # restoring the exact values the object was flattened from.
+     for name, value in zip(private_fields_names, private_values):
+         object.__setattr__(obj, name, value)
+     return obj
🧩 Analysis chain
🌐 Web query:
In JAX custom pytree registration (register_pytree_with_keys), is aux_data treated as static tree metadata, and should array-valued data that must participate in jit/vmap/tree_map be returned as children/leaves rather than in aux_data?
💡 Result:
Yes, in JAX custom pytree registration with register_pytree_with_keys (and similarly register_pytree_node), aux_data is treated as static tree metadata. It is stored in the treedef, must be hashable and immutable, and is used for structure reconstruction during unflatten. It does not participate in transformations like jit, vmap, or tree_map—those operate on the dynamic children/leaves. Array-valued data that must participate in jit (i.e., be traced and potentially compiled differently based on value), vmap (vectorized), or tree_map (mapped over) should be returned as children/leaves, not in aux_data. Including arrays in aux_data leads to unsupported, poorly-defined behavior, such as unexpected tracing under jit (e.g., jnp arrays may become Traced unexpectedly, while np arrays do not). Official JAX documentation examples explicitly state: "children must contain arrays & pytrees" and "aux_data must contain static, hashable data." StackOverflow answer citing JAX docs confirms: "aux_data in a pytree flattening must contain static elements... Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data."
Citations:
- 1: https://stackoverflow.com/questions/78485445/why-is-custom-pytree-aux-data-traced-after-jax-jit-for-jnp-array-but-not-for
- 2: http://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.register_pytree_with_keys.html
- 3: https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.register_pytree_node.html
- 4: https://docs.jax.dev/en/latest/%5Fsources/custom%5Fpytrees.md
- 5: https://docs.jax.dev/en/latest/%5Fautosummary/jax.tree_util.register_dataclass.html
- 6: http://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.register_pytree_node.html
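The children/aux_data split described in the result can be seen with a small registration. This is an illustrative sketch only: `Scaled`, `flatten`, and `unflatten` are hypothetical names, and it uses `jax.tree_util.register_pytree_node` rather than the keyed variant to keep the example short.

```python
import jax
import jax.numpy as jnp


class Scaled:
    """Hypothetical container: `values` is array data, `label` is static metadata."""

    def __init__(self, values, label):
        self.values = values
        self.label = label


def flatten(obj):
    children = (obj.values,)  # arrays go here so transformations can see them
    aux_data = obj.label      # hashable, comparable metadata only
    return children, aux_data


def unflatten(aux_data, children):
    return Scaled(children[0], aux_data)


jax.tree_util.register_pytree_node(Scaled, flatten, unflatten)

s = Scaled(jnp.array([1.0, 2.0]), label="metres")
doubled = jax.tree_util.tree_map(lambda v: 2 * v, s)  # maps over children only
leaves = jax.tree_util.tree_leaves(s)                 # aux_data is not a leaf
```

`tree_map` touches `values` and passes `label` through unchanged, which is why array-valued private fields placed in aux would be invisible to such transformations.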
🏁 Script executed:
#!/bin/bash
# Find definitions of private_field and static_private_field
rg "def (private_field|static_private_field)" -A 5
Repository: ymahlau/drinx
Length of output: 668
🏁 Script executed:
#!/bin/bash
# Search for uses of jax_static metadata in the fields
rg "jax_static" -B 2 -A 2 | head -100
Repository: ymahlau/drinx
Length of output: 4940
🏁 Script executed:
#!/bin/bash
# Check the structure of fields and metadata in base.py
fd "base.py" | xargs cat -n | head -150
Repository: ymahlau/drinx
Length of output: 6448
Split private fields by jax_static in pytree flatten/unflatten logic.
private_field() and static_private_field() are distinct APIs: the former defaults to dynamic (non-static) and the latter is always static. The current implementation collapses all init=False fields into pytree aux regardless of jax_static, placing non-static array-valued private fields into JAX's static treedef metadata where they cannot participate in transformations like jit, vmap, or tree_map.
Keep init=False and not jax_static fields as pytree leaves (via keyed_leaves), and only put init=False and jax_static values into aux. This aligns with the API design and JAX's requirement that aux_data contain only static, hashable metadata—not arrays or mutable data.
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 54-54: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
[warning] 55-55: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
[warning] 61-61: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
resolved by #11