diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py
index d82bc065ca..24debf12cf 100644
--- a/src/MaxText/layers/decoders.py
+++ b/src/MaxText/layers/decoders.py
@@ -35,6 +35,7 @@
 from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
+from MaxText.layers import mhc
 from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
@@ -731,6 +732,11 @@ def __call__(
           audio_masks,
       )
 
+    mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
+      y = mhc_expand(y)
+
     policy = self.get_remat_policy()
     RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy)
     # scan does not support kwargs in layer call, passing broadcast_args as positional arg
@@ -927,7 +933,11 @@ def __call__(
     assert isinstance(y, jax.Array)
 
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
-    hidden_state = y
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
+      hidden_state = mhc_reduce(y)
+    else:
+      hidden_state = y
 
     # When initializing with vLLM RPA attention, we need to run the output head to
     # initialize any parameters associated with it.
diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py
index cb473e445e..9aa4c08331 100644
--- a/src/MaxText/layers/deepseek.py
+++ b/src/MaxText/layers/deepseek.py
@@ -23,7 +23,7 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from MaxText.common_types import Config
-from MaxText.common_types import MODEL_MODE_PREFILL
+from MaxText.common_types import MODEL_MODE_PREFILL, HyperConnectionType
 from MaxText.layers import attention_mla
 from MaxText.layers import deepseek_batchsplit
 from MaxText.layers import initializers
@@ -31,6 +31,7 @@
 from MaxText.layers import moe
 from MaxText.layers import nnx_wrappers
 from MaxText.layers import quantizations
+from MaxText.layers import mhc
 from MaxText.layers.linears import Dropout
 from MaxText.layers.normalizations import RMSNorm
 from MaxText.sharding import create_sharding
@@ -64,6 +65,7 @@ def __init__(
     self.mesh = mesh
     self.quant = quant
     self.rngs = rngs
+    self.is_mhc_enabled = config.mhc_expansion_rate > 1
 
     batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
     self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
@@ -122,6 +124,9 @@ def __init__(
     )
     self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
 
+    if self.is_mhc_enabled:
+      self.mhc_attention = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
+      self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
 
   def mlp_op(self, x, deterministic, *args, **kwargs):
     """Executes the MLP operation. To be implemented by subclasses."""
@@ -172,31 +177,17 @@ def attention_op(
 
   @property
   def logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_embed",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_embed",
-    )
+    """Generate logical names for activations generally."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_embed"]
+    return axis_names
 
   @property
   def mlp_logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_mlp",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_mlp",
-    )
+    """Generate logical names for activations in MLP."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_mlp"]
+    return axis_names
 
   def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None):
     """postprocessing."""
@@ -231,18 +222,33 @@ def self_attention_with_norm_op(
       slot: None | int = None,
   ):
     """self-attention with normalization"""
-    lnx = self.pre_attention_norm_op(inputs)
-
-    attention_lnx = self.attention_op(
-        lnx,
-        decoder_segment_ids,
-        decoder_positions,
-        deterministic,
-        previous_chunk,
-        page_state,
-        slot,
-    )
-    intermediate_inputs = inputs + attention_lnx
+    if self.is_mhc_enabled:
+      intermediate_inputs, _ = self.mhc_attention(
+          self.pre_attention_norm_op,
+          self.self_attention,
+          x=inputs,
+          mhc_type=HyperConnectionType.ATTENTION,
+          decoder_segment_ids=decoder_segment_ids,
+          inputs_positions=decoder_positions,
+          deterministic=deterministic,
+          model_mode=self.model_mode,
+          out_sharding=self.out_sharding,
+          previous_chunk=previous_chunk,
+          page_state=page_state,
+          slot=slot,
+      )
+    else:
+      lnx = self.pre_attention_norm_op(inputs)
+      attention_lnx = self.attention_op(
+          lnx,
+          decoder_segment_ids,
+          decoder_positions,
+          deterministic,
+          previous_chunk,
+          page_state,
+          slot,
+      )
+      intermediate_inputs = inputs + attention_lnx
     # Normalization
     hidden_states = self.post_attention_norm_op(intermediate_inputs)
     return hidden_states, intermediate_inputs
@@ -308,9 +314,17 @@ def __call__(
         slot,
     )
 
-    mlp_lnx = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, _ = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.mlp,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_DENSE,
+          deterministic=deterministic,
+      )
+    else:
+      mlp_lnx = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
 
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)
     return self.post_process(layer_output, None, None, kv_cache)
@@ -394,9 +408,18 @@ def __call__(
         slot,
     )
 
-    mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, metadata = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.DeepSeekMoeBlock_0,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_MOE,
+      )
+      load_balance_loss = metadata["load_balance_loss"]
+      moe_bias_updates = metadata["moe_bias_updates"]
+    else:
+      mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
 
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)
     return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)
diff --git a/src/MaxText/layers/mhc.py b/src/MaxText/layers/mhc.py
index f1a2da1c8c..30c584db0d 100644
--- a/src/MaxText/layers/mhc.py
+++ b/src/MaxText/layers/mhc.py
@@ -34,11 +34,11 @@ def get_functions(expansion_rate: int):
 
   def expand(x: Array):
     # (batch, length, dim) -> (batch, length, streams, dim)
-    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)
+    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype)
 
   def reduce(x: Array):
     # (batch, length, streams, dim) -> (batch, length, dim)
-    return jnp.sum(x, axis=2)
+    return jnp.sum(x, axis=2, dtype=x.dtype)
 
   return expand, reduce
 
@@ -93,7 +93,9 @@ def __init__(
     self.dim = dim
     self.rngs = rngs
     self.mesh = mesh
+    self.dtype = self.config.dtype
     self.weight_dtype = self.config.weight_dtype
+    self.matmul_precision = jax.lax.Precision(self.config.matmul_precision)
 
     # Norm layer
     self.mhc_norm = RMSNorm(
@@ -162,33 +164,42 @@ def __init__(
     )
     self.pre_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
    )
     self.post_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )
 
   def res_mapping(self, x: Array):
     """Helper function for residual mapping."""
+    # In MaxText, we match weight precision to activations before Matmul
+    res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
+    res_beta = jnp.asarray(self.res_beta[...], self.dtype)
+    res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
-    h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
+    h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
     b, s, _ = h_res.shape
     h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
-    intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
+    intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
     output = sinkhorn(intermediate, self.sinkhorn_iterations)
     return output
 
   def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
     """Helper function for both pre and post mappings."""
+    # In MaxText, we match weight precision to activations before Matmul
+    alpha = jnp.asarray(alpha, self.dtype)
+    beta = jnp.asarray(beta, self.dtype)
+    alpha_scale = jnp.asarray(alpha_scale, self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
-    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision)
+    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision)
     intermediate = alpha_scale * h + beta[None, None, :]
     output = scale * jax.nn.sigmoid(intermediate)
     return output
 
   def __call__(
       self,
+      norm_fn: Callable,
       branch_fn: Callable,
       x: Array,
       mhc_type: HyperConnectionType,
@@ -197,6 +208,7 @@ def __call__(
     """Applying manifold-constrained hyper connection based on callable function.
 
     Args:
+      norm_fn: The pre-normalization function to be applied.
      branch_fn: The function to be wrapped by the hyper-connection.
       x: Input tensor of shape `(batch..., dim)`.
       mhc_type: The variant of the connection to apply.
@@ -212,24 +224,30 @@ def __call__(
     norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))
 
     # 2. Pre mapping
-    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
-    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision)
+    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale[...], self.pre_alpha[...], self.pre_beta[...], 1.0)
+    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision)
+
+    # 3. Pre-norm
+    layer_input = norm_fn(layer_input)
 
-    # 3. Attention or MLP
+    # 4. Attention or MLP
+    metadata = {}
     if mhc_type == HyperConnectionType.ATTENTION:
       layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_DENSE:
       layer_out = branch_fn(inputs=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_MOE:
-      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
+      layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs)
+      metadata["load_balance_loss"] = load_balance_loss
+      metadata["moe_bias_updates"] = moe_bias_updates
     else:
       raise ValueError(f"Unsupported type: {mhc_type}")
 
-    # 4. Post mapping
-    post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
-    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision)
+    # 5. Post mapping
+    post_mapping = self.mapping(norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0)
+    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.matmul_precision)
 
-    # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
+    # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
     res_mapping = self.res_mapping(norm_x)
-    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision)
-    return res_out + post_out
+    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision)
+    return res_out + post_out, metadata
diff --git a/src/MaxText/train.py b/src/MaxText/train.py
index c66472e685..29895d9178 100644
--- a/src/MaxText/train.py
+++ b/src/MaxText/train.py
@@ -197,8 +197,23 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
-    nested_key = ("intermediates", "decoder", "layers", "moe_lb_loss")
-    total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+    # Note: the key is affected by the model implementation
+    possible_keys = [
+        ("intermediates", "decoder", "layers", "moe_lb_loss"),
+        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
+    ]
+
+    total_moe_lb_loss = 0.0
+    found_loss = False
+    for nested_key in possible_keys:
+      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+      if total_moe_lb_loss != 0.0:
+        found_loss = True
+        break
+
+    if not found_loss:
+      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")
+
     moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
 
   loss += moe_lb_loss
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index b152f6d081..1ec58ef5a7 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -1087,6 +1087,6 @@ force_q_layout: false
 
 ################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
 # The number of parallel streams in Hyper Connection.
-mhc_expansion_rate: 0
+mhc_expansion_rate: 1
 # The number of iterations for the Sinkhorn-Knopp algorithm.
 sinkhorn_iterations: 20
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 6457b49041..0c28192925 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -248,7 +248,7 @@ class ProfilerType(str, Enum):
     "llama4-17b-16e",
     "llama4-17b-128e",
     "olmo3-7b",
-    'olmo3-7b-pt',
+    "olmo3-7b-pt",
     "olmo3-32b",
 ]
 
@@ -1082,7 +1082,7 @@ class TrainingLoop(BaseModel):
 class ManifoldConstrainedHyperConnections(BaseModel):
   """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
 
-  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
+  mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
   sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
 
 
diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py
index f1ada24dfa..ac21e249ca 100644
--- a/tests/unit/mhc_test.py
+++ b/tests/unit/mhc_test.py
@@ -30,6 +30,7 @@
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers import attention_mla, linears, mhc, moe
 from MaxText.layers.initializers import nd_dense_init
+from MaxText.layers.normalizations import RMSNorm
 from maxtext.utils import maxtext_utils
 
 
@@ -104,6 +105,8 @@ def setUp(self):
         num_experts=4,
         num_experts_per_tok=2,
         attention="dot_product",
+        routed_bias_update_rate=0.01,
+        load_balance_loss_weight=0.02,
     )
     devices_array = maxtext_utils.create_device_mesh(self.config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
@@ -119,6 +122,15 @@ def setUp(self):
         ),
     )
 
+    self.pre_norm = RMSNorm(
+        num_features=self.dim,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=self.config.normalization_layer_epsilon,
+        rngs=self.rngs,
+    )
+
   # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
   @pytest.mark.tpu_only
   def test_moe_layer_output_shape(self):
@@ -138,7 +150,11 @@ def test_moe_layer_output_shape(self):
     )
 
     b, s, k, d = self.x.shape
-    output = module(layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
+    output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
+    # metadata includes load_balance_loss & moe_bias_updates
+    self.assertEqual(len(metadata), 2)
+    for key, value in metadata.items():
+      self.assertIsNotNone(value, f"Key '{key}' has a value of None")
     self.assertEqual(output.shape, (b, s, k, d))
 
   def test_dense_layer_output_shape(self):
@@ -158,7 +174,8 @@ def test_dense_layer_output_shape(self):
     )
 
     b, s, k, d = self.x.shape
-    output = module(layer, x=self.x, mhc_type=HyperConnectionType.MLP_DENSE)
+    output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_DENSE)
+    self.assertDictEqual(metadata, {})
     self.assertEqual(output.shape, (b, s, k, d))
 
   def test_attention_layer_output_shape(self):
@@ -196,7 +213,8 @@ def test_attention_layer_output_shape(self):
     )
 
     b, s, k, d = self.x.shape
-    output = module(layer, x=self.x, mhc_type=HyperConnectionType.ATTENTION)
+    output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.ATTENTION)
+    self.assertDictEqual(metadata, {})
     self.assertEqual(output.shape, (b, s, k, d))
diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py
index e6c44533d6..58d23e9427 100644
--- a/tests/unit/train_compile_test.py
+++ b/tests/unit/train_compile_test.py
@@ -794,3 +794,24 @@ def test_olmo3_7b(self):
             "max_target_length=1024",
         )
     )
+
+  @pytest.mark.cpu_only
+  def test_mhc_integration(self):
+    """AOT test for Manifold-Constrained Hyper Connection implementation"""
+    compiled_trainstep_file = "/tmp/test_mhc_integration"
+    train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v5p-8",
+            "compile_topology_num_slices=1",
+            "model_name=deepseek-custom",
+            "per_device_batch_size=4",
+            "scan_layers=True",
+            "max_target_length=1024",
+            "mhc_expansion_rate=4",
+            "attention=flash",
+            "use_tokamax_splash=True",
+        )
+    )
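Reviewer note (not part of the patch): a minimal standalone sketch of the stream expand/reduce helpers that `decoders.py` now wraps around the scanned layers when `mhc_expansion_rate > 1`. The repeat/sum semantics mirror `mhc.get_functions` in the diff above; the batch/length/embedding sizes below are illustrative toy values, not defaults from the config.

```python
import jax.numpy as jnp


def get_functions(expansion_rate: int):
  """Returns (expand, reduce) helpers for the parallel residual streams."""

  def expand(x):
    # (batch, length, dim) -> (batch, length, streams, dim):
    # copy the hidden state into `expansion_rate` parallel residual streams.
    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype)

  def reduce(x):
    # (batch, length, streams, dim) -> (batch, length, dim):
    # merge the streams back into a single hidden state before the final norm / output head.
    return jnp.sum(x, axis=2, dtype=x.dtype)

  return expand, reduce


# Toy shapes for illustration only.
expand, reduce = get_functions(expansion_rate=4)
y = jnp.ones((2, 8, 16), dtype=jnp.bfloat16)  # (batch, length, emb_dim)
y_streams = expand(y)                          # (2, 8, 4, 16)
assert y_streams.shape == (2, 8, 4, 16)
assert reduce(y_streams).shape == y.shape
```

With the new default `mhc_expansion_rate: 1`, the `cfg.mhc_expansion_rate > 1` guards in `decoders.py` and `deepseek.py` skip this path entirely, so existing configs keep their previous behavior; the AOT test enables it via `mhc_expansion_rate=4` on `deepseek-custom`.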