- Changes
  - Upgrade Jax from 0.4.37 to 0.4.38.
  - Removes all QRM (queued resource manager) codepaths from `axlearn.cloud.gcp`.
  - Introduces `named_runner_configs`. See `axlearn gcp launch --help` for details.
  - Upgrade Grain from 0.2.3 to 0.2.7. This removes `input_grain.trim_and_pack_dataset`.
- Changes
  - Upgrade Jax from 0.4.34 to 0.4.37.
- Changes
  - Upgrade Jax from 0.4.33 to 0.4.34.
  - Updates the `input_base.Input` API to support configuring input partitioning behavior.
  - The config fields `batch_axis_names` and `seq_axis_names` in `causal_lm.Model` are now deprecated. Please use `input_base.Input.input_partitioner` instead.
  - Updates the `causal_lm.Model` API to support configuring metrics without subclassing. This requires a golden config change.
- Changes
  - Upgrade Jax from 0.4.30 to 0.4.33.
- Changes
  - Upgrade Python to 3.10.
  - Fall back to the Triton backend in GPU flash attention when QKV is in fp32 or a bias is used.
- Changes
  - Upgrade Jax from 0.4.28 to 0.4.30.
- Changes
  - Add changelog.