Skip to content

Temp testing automation, please ignore#145

Open
RerRayne wants to merge 55 commits intotest_638428288from
master
Open

Temp testing automation, please ignore#145
RerRayne wants to merge 55 commits intotest_638428288from
master

Conversation

@RerRayne
Copy link
Collaborator

@RerRayne RerRayne commented Jul 5, 2024

No description provided.

CLRSDev and others added 20 commits May 29, 2024 17:10
…tion_matrix_to_predecessor_pointers` methods to probing.

PiperOrigin-RevId: 638460408
…s specifications, and add a `track_max_steps` flag to the sampler. This flag will enable or disable length tracking for padding.

PiperOrigin-RevId: 638815148
… methods to convert CLRS sample into text.

PiperOrigin-RevId: 638980228
…sv with accuracy metric for Gemma 2B, Gemma 2B + RPE and Gemini Flash.

PiperOrigin-RevId: 639058968
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 641687779
PiperOrigin-RevId: 643364870
Co-authored-by: Olga Kozlova <grenlaykk@gmail.com>
PiperOrigin-RevId: 647327862
… instead of exactly equal sampling.

Added a use_hints field so data with and without hints can be identified per sample.
Adding huggingface generators for clrs text
PiperOrigin-RevId: 649677897
@google-cla
Copy link

google-cla bot commented Jul 5, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

RerRayne and others added 28 commits July 17, 2024 15:59
PiperOrigin-RevId: 653667715
…/setup-python@v4`.

Also, remove the check for consistency between the package version and release tag, as this is no longer necessary with `actions/checkout@v4`.

PiperOrigin-RevId: 653679763
PiperOrigin-RevId: 653690381
This eradicates the protobuf version conflict.

PiperOrigin-RevId: 671801866
PiperOrigin-RevId: 696161473
Introduce a truncation option in CLRS samplers, enabling users to specify the number of significant digits in float values for improved control.

PiperOrigin-RevId: 717477558
…o prevent generating samples that admit multiple valid paths.

PiperOrigin-RevId: 717531004
PiperOrigin-RevId: 720584481
The numpy cross product is vectorized, resulting in an extra dimension for the length 1 case.

PiperOrigin-RevId: 738834349
PiperOrigin-RevId: 753034846
The `train_lengths` variable was being modified in place for string matching algorithms, affecting subsequent algorithms in the loop. This change introduces `current_algo_train_lengths` to hold algorithm-specific length configurations, preventing unintended side effects. Additionally, unused return values from `make_multi_sampler` are no longer assigned.

Was reported in #169 by [@abhitopia
](https://github.com/abhitopia)

PiperOrigin-RevId: 819789548
This change prepares for the new `jax.pmap` by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with `jax.pmap` is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.

The change is under the `jax_pmap_shmap_merge` configuration flag. If `True`, the new `jax.pmap` implementation based on `jax.jit(jax.shard_map)` is used and requires the new explicit shard access. If `False`, the old `jax.pmap` implementation is used and there is a special case in how `x[0]` works.

Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays

PiperOrigin-RevId: 850021421
CL/853389459 added `optax.tree.cast_like` which uses `jax.tree.map`
internally. This fails when gradients contain None values because
JAX no longer considers None as a tree prefix of non-None values.

The fix replaces None gradients with actual gradient values before
calling optax, then masks the results back to None afterwards.

PUBLIC: Fix JAX tree mismatch error in filter_null_grads.
PiperOrigin-RevId: 854216449
Increase tolerance in `test_hint_loss` from `rtol=1e-4` to `rtol=1e-3`
to fix flaky test failures.

PiperOrigin-RevId: 864401872
…merge

With jax_pmap_shmap_merge=True, the code in nets.py's NetChunked.__call__
was calling addressable_shards on inputs and hints without checking if they
are JAX arrays first. When numpy arrays are passed (e.g., during init), this
causes AttributeError.

This fix adds a robust _get_first helper that:
1. Checks if input is a JAX Array before accessing addressable_shards
2. For numpy arrays, uses simple x[0] indexing
3. For JAX arrays, handles 0-d arrays, SingleDeviceSharding, and properly
   extracts data from replicated shards

PiperOrigin-RevId: 864890502
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.