[Draft] Engram integration #3125

Draft

RissyRan wants to merge 6 commits into main from engram_integration

Conversation


RissyRan (Collaborator) commented on Feb 12, 2026

==== Try 3: sum(vocab_sizes) trouble commit ====

Hit trouble with num_embeddings=sum(vocab_sizes) in the MultiHeadEmbedding module: at trace time the sum is a JitTracer rather than a concrete int, so it cannot be used as an embedding-table shape.

  File "/home/ranran_google_com/maxtext/src/MaxText/layers/engram.py", line 363, in __init__
    self.embedding = Embed(
                     ^^^^^^
  File "/home/ranran_google_com/venv-maxtext/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 400, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ranran_google_com/venv-maxtext/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File "/home/ranran_google_com/venv-maxtext/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
    self.__init__(*args, **kwargs)
  File "/home/ranran_google_com/maxtext/src/MaxText/layers/embeddings.py", line 130, in __init__
    embedding_init(
  File "/home/ranran_google_com/venv-maxtext/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 337, in init
    shape = core.canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>, 512).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
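
For reference, a minimal repro sketch of what I believe is happening (assumption: vocab_sizes reaches the embedding init as a traced array inside jit, rather than as static Python ints):

  import jax
  import jax.numpy as jnp

  init = jax.nn.initializers.normal(stddev=0.02)

  # Works: sum over static Python ints is a concrete int, valid as a shape.
  vocab_sizes = (1000, 2000, 4000)
  table = init(jax.random.key(0), (sum(vocab_sizes), 512))

  # Fails with the TypeError above: jnp.sum over a traced array yields
  # JitTracer<~int32[]>, which cannot be used as an array shape.
  @jax.jit
  def bad(key, vocab_sizes_arr):
    n = jnp.sum(vocab_sizes_arr)
    return init(key, (n, 512))

If that is right, the fix is to keep vocab_sizes as a static tuple of Python ints on the module, so sum(vocab_sizes) is concrete at construction time.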

==== Try 2: Not working - pass ngram_layer_map to deepseek layer commit ====

I moved generate_engram_map from data_loader to decoders.py, and passed layer_id and ngram_layer_map to the deepseek.py decoder layer. However, I was not able to pass layer_id directly, and hit the error below.

  File "/home/ranran_google_com/maxtext/src/MaxText/layers/deepseek.py", line 369, in __call__
    engram_output = self.engram_op(x, ngram_layer_map, layer_id)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ranran_google_com/maxtext/src/MaxText/layers/deepseek.py", line 328, in engram_op
    layer_id = core.concrete_or_error(
               ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'DynamicJaxprTracer' object is not callable

I noticed that for other models like llama4 and gpt-oss, we pass layer_id to a function to look up the attention_type. It seems that under JIT it is tricky to pass the layer index back into the decoder layer as a concrete value.
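
Two hedged sketches of what may be going wrong here and a way around it. First, concrete_or_error's first positional argument is a coercion callable (e.g. int or None); if the traced layer_id lands in that slot, JAX tries to call the tracer itself, which would match the "'DynamicJaxprTracer' object is not callable" message:

  from jax import core

  def resolve_layer_id(layer_id):
    # First argument must be the coercion callable, not the value.
    return core.concrete_or_error(
        int, layer_id, "layer_id must be known at trace time")

Second, even with the arguments in the right order, a traced layer_id would still fail to concretize, so the safer route is to bind it as a static Python int at construction time (DecoderLayer below is a hypothetical stand-in, not the real deepseek.py class):

  class DecoderLayer:
    def __init__(self, layer_id: int, ngram_layer_map: dict):
      self.layer_id = layer_id                  # plain Python int, never traced
      self.engram_cfg = ngram_layer_map.get(layer_id)

    def __call__(self, x):
      if self.engram_cfg is not None:           # Python-level branch, jit-safe
        pass  # ... apply the engram op using self.engram_cfg ...
      return x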

==== Try 1 - Not working - Integrate Engram with DeepSeek custom model commit ====

I noticed an issue when putting NgramHashMapping into data_loader.py: it is not easy to initialize self.engram = engram.Engram inside the deepseek.py file.

The tricky part is that, in the current implementation, Engram needs engram_vocab_sizes at initialization inside the DeepSeekGenericLayer NNX module, and that value is data dependent, coming from each data batch:

      # vocab_sizes is read from the per-batch ngram_map here, so it is a
      # dynamic value -- but the NNX module needs it at construction time.
      engram_vocab_sizes = ngram_map[layer_id]["vocab_sizes"]
      self.engram_input_ids = ngram_map[layer_id]["input_ids"]
      self.engram = engram.Engram(
        config=self.config,
        mesh=mesh,
        vocab_sizes=engram_vocab_sizes,
        engram_num_heads=self.config.engram_num_heads,
        engram_head_dim=self.config.engram_head_dim,
        engram_max_ngram_size=self.config.engram_max_ngram_size,
        engram_kernel_size=self.config.engram_kernel_size,
        mhc_expansion_rate=self.config.mhc_expansion_rate,
        rngs=rngs,
      )

I think all dynamic data inputs should be passed via the __call__ method instead of the __init__ method. If so, we cannot easily initialize the Engram module this way.
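
A minimal sketch of that split under flax NNX (names like head_dim and the Engram body here are illustrative, not the real module): static configuration goes to __init__ so parameter shapes are concrete, and per-batch tensors go to __call__, where tracers are fine.

  from flax import nnx

  class Engram(nnx.Module):
    def __init__(self, vocab_sizes: tuple[int, ...], head_dim: int, *, rngs: nnx.Rngs):
      # vocab_sizes must be static Python ints so the table shape is concrete
      self.embed = nnx.Embed(
          num_embeddings=sum(vocab_sizes), features=head_dim, rngs=rngs)

    def __call__(self, x, engram_input_ids):
      # dynamic, per-batch data arrives here
      return x + self.embed(engram_input_ids)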

Alternatively, we could put NgramHashMapping only in models.py or decoders.py.


RissyRan force-pushed the engram_integration branch 3 times, most recently from 0bfd32f to 6ae6275 on February 13, 2026 at 00:14