Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,8 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
enable_nnx: True
pure_nnx_decoder: True

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ class HardwareAndMesh(BaseModel):
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")


class LayoutAndSharding(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,14 +533,14 @@ def __init__(
elif self.is_qwen3_next:
self.query_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
)
self.key_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
Expand Down
20 changes: 5 additions & 15 deletions src/maxtext/layers/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
from maxtext.layers.nnx_decoders import NNXDecoderLayer
from maxtext.utils.globals import EPS
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import DecoderLayer
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.layers.linears import DenseGeneral
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
config: Config,
mesh: Mesh,
layer_number: int,
transformer_layer_module: Type[DecoderLayer],
transformer_layer_module: Type[NNXDecoderLayer],
*,
rngs: nnx.Rngs,
):
Expand Down Expand Up @@ -108,22 +108,12 @@ def __init__(
rngs=rngs,
)
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
mtp_transformer_layer = transformer_layer_module(
self.transformer_layer = transformer_layer_module(
config=cfg,
mesh=mesh,
model_mode=MODEL_MODE_TRAIN,
name=f"mtp_{k}_transformer_layer",
)
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)

# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
self.transformer_layer.lazy_init(
inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
decoder_segment_ids=None,
decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
deterministic=True,
model_mode=MODEL_MODE_TRAIN,
rngs=rngs,
)

@property
Expand Down Expand Up @@ -212,7 +202,7 @@ def __init__(
self,
config: Config,
mesh: Mesh,
transformer_layer_module: Type[DecoderLayer],
transformer_layer_module: Type[NNXDecoderLayer],
decoder: nnx.Module,
rngs: nnx.Rngs,
):
Expand Down
Loading
Loading