Skip to content

updating dataclasses.replace and register_jax_tree#10

Closed
renaissancenerd wants to merge 1 commit into
ymahlau:mainfrom
renaissancenerd:method-private-attributes
Closed

updating dataclasses.replace and register_jax_tree#10
renaissancenerd wants to merge 1 commit into
ymahlau:mainfrom
renaissancenerd:method-private-attributes

Conversation

@renaissancenerd

@renaissancenerd renaissancenerd commented Apr 16, 2026

Copy link
Copy Markdown

Summary by CodeRabbit

  • Bug Fixes

    • Fixed functional update operations to correctly preserve and update all dataclass fields, including non-constructor fields that were previously lost.
    • Improved JAX integration to properly handle all dataclass field types during tracing and serialization.
  • Chores

    • Optimized visualization string construction.

@coderabbitai

coderabbitai Bot commented Apr 16, 2026

Copy link
Copy Markdown
📝 Walkthrough

Walkthrough

This PR extends functional update capabilities for frozen dataclasses by introducing an internal _dataclass_replace helper that preserves init=False fields, updates JAX pytree registration to classify fields by both static and init attributes, and makes a minor formatting adjustment in visualization code.

Changes

Cohort / File(s) Summary
Dataclass functional updates
src/drinx/base.py, src/drinx/transform.py
Added _dataclass_replace helper to support functional updates for both init=True and init=False dataclass fields. Updated DataClass.aset and DataClass.updated_copy to use the new helper. Modified JAX pytree registration to classify fields by both jax_static and init attributes, separating them into static-init and private auxiliary buckets. Updated flatten/unflatten logic to reconstruct objects and restore private fields post-construction.
Visualization formatting
src/drinx/visualize.py
Replaced map-based string join with generator expression in visualize_leaf, updating the corresponding type-checking annotation.

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant DC as DataClass.aset/<br/>updated_copy
    participant Helper as _dataclass_replace
    participant Ctor as Constructor
    participant Attr as object.__setattr__

    User->>DC: Call with changes
    DC->>Helper: Pass obj and changes dict
    Helper->>Helper: Validate field names
    Helper->>Helper: Overlay changes onto<br/>current field values
    Helper->>Ctor: Construct with init=True<br/>fields only
    Ctor->>Ctor: Create new instance
    Ctor-->>Helper: Return new instance
    Helper->>Attr: Apply init=False<br/>fields to instance
    Attr->>Attr: Set private fields<br/>post-construction
    Attr-->>Helper: Complete
    Helper-->>DC: Return updated copy
    DC-->>User: Return result
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 Hops through frozen dataclass fields,
init=False now safely shields,
JAX pytrees classify with care,
Static, dynamic, private—all there!
Functional updates dance anew!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly references the two main areas of change: dataclasses.replace functionality and register_jax_tree logic, which align with the core modifications across base.py and transform.py.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/drinx/base.py`:
- Around line 45-68: _dataclass_replace currently allows changing private
(underscore-prefixed) fields and also drops existing non-init fields that aren't
in changes; fix it by forbidding updates to private fields and by restoring all
non-init fields from the original object. Specifically: in _dataclass_replace,
compute private_names = {f.name for f in all_fields if f.name.startswith("_")}
and if private_names & set(changes): raise TypeError (reject updates to private
fields), keep known_names logic but validate against public names only; build
current_values as you do, then after creating new_obj, stamp back every non-init
field from current_values (use all non-init f.name), and finally overwrite those
non-init public fields with values from changes if present (i.e., ensure
non-init values are preserved unless an allowed public non-init field was
explicitly changed). Use the symbols _dataclass_replace, known_names,
current_values, non_init_overrides (or replace it with the new logic) to locate
and update the code.

In `@src/drinx/transform.py`:
- Around line 29-63: The flatten/unflatten currently lumps all init=False fields
into aux; change this so private_fields are split by jax_static: compute
private_dynamic = [f.name for f in all_fields if not f.init and not
f.metadata.get("jax_static")] and private_static = [f.name for f in all_fields
if not f.init and f.metadata.get("jax_static")]; then extend dynamic_fields
(used in keyed_leaves) with private_dynamic so non-static private fields become
leaves in flatten_with_keys, and only include static_private and
static_init_fields in aux (e.g., aux = (tuple(getattr(obj, f) for f in
static_init_fields), tuple(getattr(obj, f) for f in private_static))); finally
update unflatten to accept aux = (static_init_values, private_static_values),
build init_kwargs by zipping static_init_fields and dynamic_fields (which now
includes private_dynamic), construct obj = cls_(**init_kwargs), and then
object.__setattr__ only for names in private_static to restore static private
fields.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a4c6c7b9-ed55-439e-9c70-96dc1171c4ba

📥 Commits

Reviewing files that changed from the base of the PR and between fa3bc32 and 2170142.

📒 Files selected for processing (3)
  • src/drinx/base.py
  • src/drinx/transform.py
  • src/drinx/visualize.py

Comment thread src/drinx/base.py
Comment on lines +45 to +68
all_fields = dataclasses.fields(obj)
known_names = {f.name for f in all_fields}
unknown = set(changes) - known_names
if unknown:
raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")

# Collect current values for every field, then overlay changes
current_values: dict[str, Any] = {
f.name: object.__getattribute__(obj, f.name) for f in all_fields
}
current_values.update(changes)

init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
non_init_overrides = {
f.name: current_values[f.name]
for f in all_fields
if not f.init and f.name in changes
}

new_obj = type(obj)(**init_kwargs)

# Stamp non-init fields that were explicitly changed
for name, value in non_init_overrides.items():
object.__setattr__(new_obj, name, value)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

_dataclass_replace currently both exposes and drops private fields.

known_names lets init=False names through, so updated_copy(_cache=...) / aset("_cache", ...) now bypass the existing guard that tests/test_base.py:1302-1311 and tests/test_base.py:1400-1410 assert. At the same time, only f.name in changes gets stamped back, so updated_copy(x=...) still resets any existing private state instead of preserving it.

Suggested fix
     all_fields = dataclasses.fields(obj)
-    known_names = {f.name for f in all_fields}
-    unknown = set(changes) - known_names
+    fields_by_name = {f.name: f for f in all_fields}
+    unknown = set(changes) - fields_by_name.keys()
     if unknown:
         raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")
+
+    non_init_updates = {name for name in changes if not fields_by_name[name].init}
+    if non_init_updates:
+        raise TypeError(
+            f"_dataclass_replace() cannot update init=False fields: {non_init_updates!r}"
+        )
 
     # Collect current values for every field, then overlay changes
     current_values: dict[str, Any] = {
         f.name: object.__getattribute__(obj, f.name) for f in all_fields
     }
     current_values.update(changes)
 
     init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
-    non_init_overrides = {
-        f.name: current_values[f.name]
-        for f in all_fields
-        if not f.init and f.name in changes
-    }
+    non_init_values = {f.name: current_values[f.name] for f in all_fields if not f.init}
 
     new_obj = type(obj)(**init_kwargs)
 
-    # Stamp non-init fields that were explicitly changed
-    for name, value in non_init_overrides.items():
+    # Preserve current non-init fields after construction
+    for name, value in non_init_values.items():
         object.__setattr__(new_obj, name, value)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
all_fields = dataclasses.fields(obj)
known_names = {f.name for f in all_fields}
unknown = set(changes) - known_names
if unknown:
raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")
# Collect current values for every field, then overlay changes
current_values: dict[str, Any] = {
f.name: object.__getattribute__(obj, f.name) for f in all_fields
}
current_values.update(changes)
init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
non_init_overrides = {
f.name: current_values[f.name]
for f in all_fields
if not f.init and f.name in changes
}
new_obj = type(obj)(**init_kwargs)
# Stamp non-init fields that were explicitly changed
for name, value in non_init_overrides.items():
object.__setattr__(new_obj, name, value)
all_fields = dataclasses.fields(obj)
fields_by_name = {f.name: f for f in all_fields}
unknown = set(changes) - fields_by_name.keys()
if unknown:
raise TypeError(f"_dataclass_replace() got unexpected field names: {unknown!r}")
non_init_updates = {name for name in changes if not fields_by_name[name].init}
if non_init_updates:
raise TypeError(
f"_dataclass_replace() cannot update init=False fields: {non_init_updates!r}"
)
# Collect current values for every field, then overlay changes
current_values: dict[str, Any] = {
f.name: object.__getattribute__(obj, f.name) for f in all_fields
}
current_values.update(changes)
init_kwargs = {f.name: current_values[f.name] for f in all_fields if f.init}
non_init_values = {f.name: current_values[f.name] for f in all_fields if not f.init}
new_obj = type(obj)(**init_kwargs)
# Preserve current non-init fields after construction
for name, value in non_init_values.items():
object.__setattr__(new_obj, name, value)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/drinx/base.py` around lines 45 - 68, _dataclass_replace currently allows
changing private (underscore-prefixed) fields and also drops existing non-init
fields that aren't in changes; fix it by forbidding updates to private fields
and by restoring all non-init fields from the original object. Specifically: in
_dataclass_replace, compute private_names = {f.name for f in all_fields if
f.name.startswith("_")} and if private_names & set(changes): raise TypeError
(reject updates to private fields), keep known_names logic but validate against
public names only; build current_values as you do, then after creating new_obj,
stamp back every non-init field from current_values (use all non-init f.name),
and finally overwrite those non-init public fields with values from changes if
present (i.e., ensure non-init values are preserved unless an allowed public
non-init field was explicitly changed). Use the symbols _dataclass_replace,
known_names, current_values, non_init_overrides (or replace it with the new
logic) to locate and update the code.

Comment thread src/drinx/transform.py
Comment on lines +29 to +63
all_fields = fields(cls_)
# Traced leaves: init=True and not jax_static
dynamic_fields = [
f.name for f in all_fields if not f.metadata.get("jax_static") and f.init
]
# Aux bucket 1: explicitly static (jax_static=True) and init=True
static_init_fields = [
f.name for f in all_fields if f.metadata.get("jax_static") and f.init
]
# Aux bucket 2: private (init=False), regardless of jax_static — must ride in aux
private_fields_names = [f.name for f in all_fields if not f.init]

def flatten_with_keys(obj):
keyed_leaves = [
(jax.tree_util.GetAttrKey(f), getattr(obj, f)) for f in dynamic_fields
]
aux = tuple(getattr(obj, f) for f in static_fields)
aux = (
tuple(getattr(obj, f) for f in static_init_fields),
tuple(getattr(obj, f) for f in private_fields_names),
)
return keyed_leaves, aux

def unflatten(aux, leaves):
kwargs = {**dict(zip(static_fields, aux)), **dict(zip(dynamic_fields, leaves))}
return cls_(**kwargs)
static_init_values, private_values = aux
init_kwargs = {
**dict(zip(static_init_fields, static_init_values)),
**dict(zip(dynamic_fields, leaves)),
}
obj = cls_(**init_kwargs)
# Stamp private fields back after construction. object.__setattr__ bypasses
# DataClass.__setattr__'s frozen guard, which is safe here because we are
# restoring the exact values the object was flattened from.
for name, value in zip(private_fields_names, private_values):
object.__setattr__(obj, name, value)
return obj

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In JAX custom pytree registration (register_pytree_with_keys), is aux_datatreated as static tree metadata, and should array-valued data that must participate injit/vmap/tree_mapbe returned as children/leaves rather than inaux_data?

💡 Result:

Yes, in JAX custom pytree registration with register_pytree_with_keys (and similarly register_pytree_node), aux_data is treated as static tree metadata. It is stored in the treedef, must be hashable and immutable, and is used for structure reconstruction during unflatten. It does not participate in transformations like jit, vmap, or tree_map—those operate on the dynamic children/leaves. Array-valued data that must participate in jit (i.e., be traced and potentially compiled differently based on value), vmap (vectorized), or tree_map (mapped over) should be returned as children/leaves, not in aux_data. Including arrays in aux_data leads to unsupported, poorly-defined behavior, such as unexpected tracing under jit (e.g., jnp arrays may become Traced unexpectedly, while np arrays do not). Official JAX documentation examples explicitly state: "children must contain arrays & pytrees" and "aux_data must contain static, hashable data." StackOverflow answer citing JAX docs confirms: "aux_data in a pytree flattening must contain static elements... Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data."

Citations:


🏁 Script executed:

#!/bin/bash
# Find definitions of private_field and static_private_field
rg "def (private_field|static_private_field)" -A 5

Repository: ymahlau/drinx

Length of output: 668


🏁 Script executed:

#!/bin/bash
# Search for uses of jax_static metadata in the fields
rg "jax_static" -B 2 -A 2 | head -100

Repository: ymahlau/drinx

Length of output: 4940


🏁 Script executed:

#!/bin/bash
# Check the structure of fields and metadata in base.py
fd "base.py" | xargs cat -n | head -150

Repository: ymahlau/drinx

Length of output: 6448


Split private fields by jax_static in pytree flatten/unflatten logic.

private_field() and static_private_field() are distinct APIs: the former defaults to dynamic (non-static) and the latter is always static. The current implementation collapses all init=False fields into pytree aux regardless of jax_static, placing non-static array-valued private fields into JAX's static treedef metadata where they cannot participate in transformations like jit, vmap, or tree_map.

Keep init=False and not jax_static fields as pytree leaves (via keyed_leaves), and only put init=False and jax_static values into aux. This aligns with the API design and JAX's requirement that aux_data contain only static, hashable metadata—not arrays or mutable data.

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 54-54: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


[warning] 55-55: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


[warning] 61-61: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/drinx/transform.py` around lines 29 - 63, The flatten/unflatten currently
lumps all init=False fields into aux; change this so private_fields are split by
jax_static: compute private_dynamic = [f.name for f in all_fields if not f.init
and not f.metadata.get("jax_static")] and private_static = [f.name for f in
all_fields if not f.init and f.metadata.get("jax_static")]; then extend
dynamic_fields (used in keyed_leaves) with private_dynamic so non-static private
fields become leaves in flatten_with_keys, and only include static_private and
static_init_fields in aux (e.g., aux = (tuple(getattr(obj, f) for f in
static_init_fields), tuple(getattr(obj, f) for f in private_static))); finally
update unflatten to accept aux = (static_init_values, private_static_values),
build init_kwargs by zipping static_init_fields and dynamic_fields (which now
includes private_dynamic), construct obj = cls_(**init_kwargs), and then
object.__setattr__ only for names in private_static to restore static private
fields.

@ymahlau

ymahlau commented Apr 23, 2026

Copy link
Copy Markdown
Owner

resolved by #11

@ymahlau ymahlau closed this Apr 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants