diff --git a/.gitignore b/.gitignore index e732069..2430ba6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,6 @@ outputs/ *.pyc *.npz docs/build +docs/jupyter_execute docs/source/api scripts_dev/ \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index ce79f43..d6abe36 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,4 +15,6 @@ sphinx-autodoc-typehints>=3.2.0 myst_nb>=1.3.0 sphinx-copybutton>=0.5.2 -jax>=0.9.0 \ No newline at end of file +jax>=0.9.0 +jupyterlab +ipywidgets \ No newline at end of file diff --git a/docs/source/examples/basic_usage.ipynb b/docs/source/examples/basic_usage.ipynb new file mode 100644 index 0000000..cba7a66 --- /dev/null +++ b/docs/source/examples/basic_usage.ipynb @@ -0,0 +1,509 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic Usage\n", + "\n", + "**Drinx** (Dataclass Registry in JAX) makes Python dataclasses work as JAX pytree nodes, so they pass through `jit`, `grad`, `vmap`, and other JAX transforms seamlessly.\n", + "\n", + "Two usage styles are available:\n", + "- **Decorator**: `@drinx.dataclass`\n", + "- **Inheritance**: `class Foo(drinx.DataClass)`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import drinx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Decorator style\n", + "\n", + "`@drinx.dataclass` wraps `dataclasses.dataclass` and registers the class as a JAX pytree. All fields are dynamic (traced) by default." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Params(weights=Array([1., 1., 1.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))\n" + ] + } + ], + "source": [ + "@drinx.dataclass\n", + "class Params:\n", + " weights: jax.Array\n", + " bias: jax.Array\n", + "\n", + "params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))\n", + "print(params)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Params(weights=Array([2., 2., 2.], dtype=float32), bias=Array([0., 0., 0.], dtype=float32))\n", + "\n" + ] + } + ], + "source": [ + "# jax.tree_util.tree_map works out of the box\n", + "doubled = jax.tree_util.tree_map(lambda x: x * 2, params)\n", + "print(doubled)\n", + "print(type(doubled)) # still a Params instance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Inheritance style\n", + "\n", + "Subclassing `drinx.DataClass` applies the transform automatically — no decorator needed." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model(weights=Array([[1., 1., 1., 1.],\n", + " [1., 1., 1., 1.],\n", + " [1., 1., 1., 1.],\n", + " [1., 1., 1., 1.]], dtype=float32), bias=Array([0., 0., 0., 0.], dtype=float32))\n" + ] + } + ], + "source": [ + "class Model(drinx.DataClass):\n", + " weights: jax.Array\n", + " bias: jax.Array\n", + "\n", + "model = Model(weights=jnp.ones((4, 4)), bias=jnp.zeros((4,)))\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Static fields\n", + "\n", + "Mark a field as **static** to exclude it from JAX tracing. Static values are treated as compile-time constants by `jit` — changing a static field triggers recompilation." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128,)\n" + ] + } + ], + "source": [ + "@drinx.dataclass\n", + "class LinearLayer:\n", + " weights: jax.Array\n", + " # hidden_size is a compile-time constant — not traced by JAX\n", + " hidden_size: int = drinx.static_field(default=128)\n", + "\n", + "layer = LinearLayer(weights=jnp.ones((128, 32)))\n", + "\n", + "@jax.jit\n", + "def forward(layer, x):\n", + " return layer.weights[:layer.hidden_size] @ x\n", + "\n", + "x = jnp.ones((32,))\n", + "result = forward(layer, x)\n", + "print(result.shape) # (128,)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note:** Changing a static field causes `jit` to recompile for the new value." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(64,)\n" + ] + } + ], + "source": [ + "# Static field can vary per-instance at Python level\n", + "small_layer = LinearLayer(weights=jnp.ones((64, 32)), hidden_size=64)\n", + "result_small = forward(small_layer, x)\n", + "print(result_small.shape) # (64,) — recompiled for hidden_size=64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. JAX transforms\n", + "\n", + "### 4a. `jax.grad`\n", + "\n", + "Gradients have the same pytree structure as the input." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[2. 4. 6.]\n" + ] + } + ], + "source": [ + "class State(drinx.DataClass):\n", + " x: jax.Array\n", + " step_size: float = drinx.static_field(default=0.1)\n", + "\n", + "def loss(state):\n", + " return jnp.sum(state.x ** 2)\n", + "\n", + "state = State(x=jnp.array([1.0, 2.0, 3.0]))\n", + "grads = jax.grad(loss)(state)\n", + "\n", + "print(type(grads)) # State — same type as input\n", + "print(grads.x) # [2. 4. 6.] (gradient of sum(x^2) = 2x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4b. `jax.vmap`\n", + "\n", + "Batch over dynamic fields by stacking arrays along a new leading dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[2. 4.]\n", + " [6. 8.]]\n" + ] + } + ], + "source": [ + "@jax.vmap\n", + "def scale(state):\n", + " return state.x * 2\n", + "\n", + "# Each row is one element of the batch\n", + "batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))\n", + "result = scale(batched)\n", + "print(result) # [[2. 4.] [6. 8.]]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4c. `jax.lax.scan`\n", + "\n", + "Dataclasses work as the `carry` in `jax.lax.scan`, enabling stateful loops without Python-level iteration." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "history: [0.9 0.79999995 0.6999999 0.5999999 0.4999999 ]\n", + "final x: 0.4999999\n" + ] + } + ], + "source": [ + "class ScanState(drinx.DataClass):\n", + " x: jax.Array\n", + " step_size: float = drinx.static_field(default=0.1)\n", + "\n", + "def step(carry, _):\n", + " new_x = carry.x - carry.step_size # gradient descent step\n", + " return ScanState(x=new_x, step_size=carry.step_size), new_x\n", + "\n", + "init = ScanState(x=jnp.array(1.0))\n", + "final, history = jax.lax.scan(step, init, None, length=5)\n", + "\n", + "print(\"history:\", history) # [0.9, 0.8, 0.7, 0.6, 0.5]\n", + "print(\"final x:\", final.x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Nested dataclasses\n", + "\n", + "Drinx dataclasses compose naturally — nest them to represent hierarchical model parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2. 3. 4.]\n" + ] + } + ], + "source": [ + "class Inner(drinx.DataClass):\n", + " w: jax.Array\n", + "\n", + "class Outer(drinx.DataClass):\n", + " inner: Inner\n", + " bias: jax.Array\n", + "\n", + "@jax.jit\n", + "def apply(outer, x):\n", + " return outer.inner.w @ x + outer.bias\n", + "\n", + "outer = Outer(inner=Inner(w=jnp.eye(3)), bias=jnp.ones((3,)))\n", + "x = jnp.array([1.0, 2.0, 3.0])\n", + "print(apply(outer, x)) # [2. 3. 4.]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Functional updates with `aset`\n", + "\n", + "Since drinx dataclasses are frozen, use `.aset(path, value)` to get an updated copy. Nested fields use `->` as a separator. Note that this function is only available when using the inheritance style to create the dataclass." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "original bias: [0. 0. 0.]\n", + "updated bias: [1. 1. 1.]\n", + "nested update: [99. 99. 99.]\n" + ] + } + ], + "source": [ + "outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((3,)))\n", + "\n", + "# Update a top-level field\n", + "outer2 = outer.aset(\"bias\", jnp.ones((3,)))\n", + "print(\"original bias:\", outer.bias) # [0. 0. 0.]\n", + "print(\"updated bias: \", outer2.bias) # [1. 1. 1.]\n", + "\n", + "# Update a nested field using '->' path syntax\n", + "outer3 = outer.aset(\"inner->w\", jnp.full((3,), 99.0))\n", + "print(\"nested update:\", outer3.inner.w) # [99. 99. 99.]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Fluent updates with `.at[].set()`\n", + "\n", + "An alternative to `aset` is the `.at[key].set(value)` API, which mirrors JAX array indexing and supports chaining." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Simple(a=Array(99., dtype=float32, weak_type=True), b=Array(2., dtype=float32, weak_type=True))\n", + "Simple(a=Array(10., dtype=float32, weak_type=True), b=Array(20., dtype=float32, weak_type=True))\n" + ] + } + ], + "source": [ + "class Simple(drinx.DataClass):\n", + " a: jax.Array\n", + " b: jax.Array\n", + "\n", + "tree = Simple(a=jnp.array(1.0), b=jnp.array(2.0))\n", + "\n", + "# Single field update\n", + "result = tree.at[\"a\"].set(jnp.array(99.0))\n", + "print(result) # Simple(a=99.0, b=2.0)\n", + "\n", + "# Chained updates\n", + "result2 = tree.at[\"a\"].set(jnp.array(10.0)).at[\"b\"].set(jnp.array(20.0))\n", + "print(result2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mask-based update: zero out all elements greater than 5\n", + "tree = Simple(a=jnp.array([1.0, 6.0, 3.0]), b=jnp.array([7.0, 2.0, 8.0]))\n", + "mask = jax.tree.map(lambda x: x > 5, tree)\n", + "result = tree.at[mask].set(0.0)\n", + "print(result.a) # [1. 0. 3.]\n", + "print(result.b) # [0. 2. 0.]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Combining `jit` + `grad` for a training step\n", + "\n", + "A minimal example of a JAX training loop using drinx to hold model parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final loss: 6.751089676981792e-05\n", + "Learned w: [ 1.9897425 -1.0014862 0.49940124 1.4957309 ]\n", + "True w: [ 2. -1. 0.5 1.5]\n" + ] + } + ], + "source": [ + "class LinearModel(drinx.DataClass):\n", + " w: jax.Array # weight matrix\n", + " b: jax.Array # bias vector\n", + " lr: float = drinx.static_field(default=0.01)\n", + "\n", + "def mse_loss(model, x, y):\n", + " pred = x @ model.w + model.b\n", + " return jnp.mean((pred - y) ** 2)\n", + "\n", + "@jax.jit\n", + "def train_step(model, x, y):\n", + " loss, grads = jax.value_and_grad(mse_loss)(model, x, y)\n", + " # Gradient descent: subtract lr * grad for each dynamic field\n", + " new_model = jax.tree_util.tree_map(lambda p, g: p - model.lr * g, model, grads)\n", + " return new_model, loss\n", + "\n", + "# Toy dataset: y = 2x\n", + "key = jax.random.PRNGKey(0)\n", + "x_data = jax.random.normal(key, (32, 4))\n", + "y_data = x_data @ jnp.array([2.0, -1.0, 0.5, 1.5]) + 0.1\n", + "\n", + "model = LinearModel(w=jnp.zeros((4,)), b=jnp.zeros(()))\n", + "\n", + "for i in range(500):\n", + " model, loss = train_step(model, x_data, y_data)\n", + "\n", + "print(\"Final loss:\", float(loss))\n", + "print(\"Learned w: \", model.w)\n", + "print(\"True w: \", jnp.array([2.0, -1.0, 0.5, 1.5]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "drinx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/index.rst b/docs/source/index.rst index 5a985a1..c2537bb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -157,4 +157,5 @@ Drinx dataclasses work with all JAX transforms out of the box: :hidden: self + examples/basic_usage.ipynb api \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cfa4bdf..79d8e76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ dev = [ "sphinx-autodoc-typehints>=3.2.0", "myst_nb>=1.3.0", "sphinx-copybutton>=0.5.2", + "jupyterlab", + "ipywidgets", ] [build-system] @@ -41,3 +43,6 @@ allow-direct-references = true [tool.ruff] fix = true exclude = ["docs", "scripts_dev"] + +[tool.ty.src] +exclude = ["docs/source/examples"]