12 changes: 11 additions & 1 deletion src/MaxText/layers/decoders.py
@@ -35,6 +35,7 @@
 from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
+from MaxText.layers import mhc
 from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
@@ -731,6 +732,11 @@ def __call__(
         audio_masks,
     )

+    mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
+      y = mhc_expand(y)
+
     policy = self.get_remat_policy()
     RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy)
     # scan does not support kwargs in layer call, passing broadcast_args as positional arg
@@ -927,7 +933,11 @@ def __call__(
     assert isinstance(y, jax.Array)

     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
-    hidden_state = y
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
+      hidden_state = mhc_reduce(y)
+    else:
+      hidden_state = y

     # When initializing with vLLM RPA attention, we need to run the output head to
     # initialize any parameters associated with it.
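For orientation, a minimal usage sketch (not part of this PR) of the expand/reduce pair returned by `mhc.get_functions`; the shapes follow the comments above, and the concrete sizes are arbitrary illustrative values:

import jax.numpy as jnp
from MaxText.layers import mhc

expand, reduce = mhc.get_functions(expansion_rate=4)

y = jnp.ones((2, 128, 1024), dtype=jnp.bfloat16)  # (batch, length, emb_dim)
y_streams = expand(y)                             # replicate across 4 parallel streams
assert y_streams.shape == (2, 128, 4, 1024)
assert y_streams.dtype == jnp.bfloat16            # .astype(x.dtype) preserves the activation dtype

y_out = reduce(y_streams)                         # sum the streams back into one residual
assert y_out.shape == y.shape

The decoder expands once before the scanned layer stack and reduces once after it, so individual layers only ever see the `(batch, length, streams, emb_dim)` layout when mHC is enabled.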
105 changes: 64 additions & 41 deletions src/MaxText/layers/deepseek.py
@@ -23,14 +23,15 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from MaxText.common_types import Config
-from MaxText.common_types import MODEL_MODE_PREFILL
+from MaxText.common_types import MODEL_MODE_PREFILL, HyperConnectionType
 from MaxText.layers import attention_mla
 from MaxText.layers import deepseek_batchsplit
 from MaxText.layers import initializers
 from MaxText.layers import linears
 from MaxText.layers import moe
 from MaxText.layers import nnx_wrappers
 from MaxText.layers import quantizations
+from MaxText.layers import mhc
 from MaxText.layers.linears import Dropout
 from MaxText.layers.normalizations import RMSNorm
 from MaxText.sharding import create_sharding
@@ -64,6 +65,7 @@ def __init__(
     self.mesh = mesh
     self.quant = quant
     self.rngs = rngs
+    self.is_mhc_enabled = config.mhc_expansion_rate > 1

     batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
     self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
@@ -122,6 +124,9 @@ def __init__(
     )

     self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
+    if self.is_mhc_enabled:
+      self.mhc_attention = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
+      self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)

   def mlp_op(self, x, deterministic, *args, **kwargs):
     """Executes the MLP operation. To be implemented by subclasses."""
@@ -172,31 +177,17 @@ def attention_op(

   @property
   def logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_embed",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_embed",
-    )
+    """Generate logical names for activations generally."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_embed"]
+    return axis_names

   @property
   def mlp_logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_mlp",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_mlp",
-    )
+    """Generate logical names for activations in MLP."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_mlp"]
+    return axis_names

   def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None):
     """postprocessing."""
@@ -231,18 +222,33 @@ def self_attention_with_norm_op(
       slot: None | int = None,
   ):
     """self-attention with normalization"""
-    lnx = self.pre_attention_norm_op(inputs)
-
-    attention_lnx = self.attention_op(
-        lnx,
-        decoder_segment_ids,
-        decoder_positions,
-        deterministic,
-        previous_chunk,
-        page_state,
-        slot,
-    )
-    intermediate_inputs = inputs + attention_lnx
+    if self.is_mhc_enabled:
+      intermediate_inputs, _ = self.mhc_attention(
+          self.pre_attention_norm_op,
+          self.self_attention,
+          x=inputs,
+          mhc_type=HyperConnectionType.ATTENTION,
+          decoder_segment_ids=decoder_segment_ids,
+          inputs_positions=decoder_positions,
+          deterministic=deterministic,
+          model_mode=self.model_mode,
+          out_sharding=self.out_sharding,
+          previous_chunk=previous_chunk,
+          page_state=page_state,
+          slot=slot,
+      )
+    else:
+      lnx = self.pre_attention_norm_op(inputs)
+      attention_lnx = self.attention_op(
+          lnx,
+          decoder_segment_ids,
+          decoder_positions,
+          deterministic,
+          previous_chunk,
+          page_state,
+          slot,
+      )
+      intermediate_inputs = inputs + attention_lnx
     # Normalization
     hidden_states = self.post_attention_norm_op(intermediate_inputs)
     return hidden_states, intermediate_inputs
@@ -308,9 +314,17 @@ def __call__(
         slot,
     )

-    mlp_lnx = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, _ = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.mlp,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_DENSE,
+          deterministic=deterministic,
+      )
+    else:
+      mlp_lnx = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)

     return self.post_process(layer_output, None, None, kv_cache)
@@ -394,9 +408,17 @@ def __call__(
         slot,
     )

-    mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, metadata = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.DeepSeekMoeBlock_0,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_MOE,
+      )
+      load_balance_loss = metadata["load_balance_loss"]
+      moe_bias_updates = metadata["moe_bias_updates"]
+    else:
+      mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)

     return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)
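Conceptually, each `is_mhc_enabled` branch above replaces the standard pre-norm residual with the hyper-connection wrapper, which takes over both the normalization and the residual bookkeeping. A schematic sketch of the two paths (hypothetical helper names, extra attention kwargs omitted):

# Standard pre-norm residual: a single stream with an identity skip connection.
def standard_block(x, norm_fn, branch_fn):
  return x + branch_fn(norm_fn(x))  # x: (batch, length, emb_dim)

# mHC path: x carries k parallel streams. The module computes how streams are
# mixed into the branch input (pre mapping), how the branch output is written
# back per stream (post mapping), and how streams mix with each other
# (residual mapping).
def mhc_block(x, norm_fn, branch_fn, mhc_module, mhc_type):
  # x: (batch, length, k, emb_dim); norm_fn is applied inside the module,
  # between the pre mapping and branch_fn.
  layer_output, metadata = mhc_module(norm_fn, branch_fn, x=x, mhc_type=mhc_type)
  return layer_output, metadata  # layer_output keeps the (batch, length, k, emb_dim) layout

This is also why the MoE variant threads `load_balance_loss` and `moe_bias_updates` back out through `metadata`: the branch output is no longer returned directly to the caller.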
52 changes: 35 additions & 17 deletions src/MaxText/layers/mhc.py
@@ -34,11 +34,11 @@ def get_functions(expansion_rate: int):

   def expand(x: Array):
     # (batch, length, dim) -> (batch, length, streams, dim)
-    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)
+    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype)

   def reduce(x: Array):
     # (batch, length, streams, dim) -> (batch, length, dim)
-    return jnp.sum(x, axis=2)
+    return jnp.sum(x, axis=2, dtype=x.dtype)

   return expand, reduce

@@ -93,7 +93,9 @@ def __init__(
     self.dim = dim
     self.rngs = rngs
     self.mesh = mesh
+    self.dtype = self.config.dtype
     self.weight_dtype = self.config.weight_dtype
+    self.matmul_precision = jax.lax.Precision(self.config.matmul_precision)

     # Norm layer
     self.mhc_norm = RMSNorm(
@@ -162,33 +164,42 @@ def __init__(
     )
     self.pre_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )
     self.post_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )

   def res_mapping(self, x: Array):
     """Helper function for residual mapping."""
+    # In MaxText, we match weight precision to activations before Matmul
+    res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
+    res_beta = jnp.asarray(self.res_beta[...], self.dtype)
+    res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
-    h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
+    h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
     b, s, _ = h_res.shape
     h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
-    intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
+    intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
     output = sinkhorn(intermediate, self.sinkhorn_iterations)
     return output

   def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
     """Helper function for both pre and post mappings."""
+    # In MaxText, we match weight precision to activations before Matmul
+    alpha = jnp.asarray(alpha, self.dtype)
+    beta = jnp.asarray(beta, self.dtype)
+    alpha_scale = jnp.asarray(alpha_scale, self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
-    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision)
+    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision)
     intermediate = alpha_scale * h + beta[None, None, :]
     output = scale * jax.nn.sigmoid(intermediate)
     return output

   def __call__(
       self,
+      norm_fn: Callable,
       branch_fn: Callable,
       x: Array,
       mhc_type: HyperConnectionType,
@@ -197,6 +208,7 @@ def __call__(
     """Applying manifold-constrained hyper connection based on callable function.

     Args:
+      norm_fn: The pre-normalization function to be applied.
       branch_fn: The function to be wrapped by the hyper-connection.
       x: Input tensor of shape `(batch..., dim)`.
       mhc_type: The variant of the connection to apply.
@@ -212,24 +224,30 @@
     norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

     # 2. Pre mapping
-    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
-    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision)
+    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale[...], self.pre_alpha[...], self.pre_beta[...], 1.0)
+    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision)

+    # 3. Pre-norm
+    layer_input = norm_fn(layer_input)
+
-    # 3. Attention or MLP
+    # 4. Attention or MLP
+    metadata = {}
     if mhc_type == HyperConnectionType.ATTENTION:
       layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_DENSE:
       layer_out = branch_fn(inputs=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_MOE:
-      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
+      layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs)
+      metadata["load_balance_loss"] = load_balance_loss
+      metadata["moe_bias_updates"] = moe_bias_updates
     else:
       raise ValueError(f"Unsupported type: {mhc_type}")

-    # 4. Post mapping
-    post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
-    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision)
+    # 5. Post mapping
+    post_mapping = self.mapping(norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0)
+    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.matmul_precision)

-    # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
+    # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
     res_mapping = self.res_mapping(norm_x)
-    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision)
-    return res_out + post_out
+    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision)
+    return res_out + post_out, metadata
19 changes: 17 additions & 2 deletions src/MaxText/train.py
@@ -197,8 +197,23 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
-    nested_key = ("intermediates", "decoder", "layers", "moe_lb_loss")
-    total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+    # Note: the key is affected by the model implementation
+    possible_keys = [
+        ("intermediates", "decoder", "layers", "moe_lb_loss"),
+        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
+    ]
+
+    total_moe_lb_loss = 0.0
+    found_loss = False
+    for nested_key in possible_keys:
+      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+      if total_moe_lb_loss != 0.0:
+        found_loss = True
+        break
+
+    if not found_loss:
+      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")
+
     moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
     loss += moe_lb_loss
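The loop above probes each candidate key path with `maxtext_utils.get_nested_value` and keeps the first non-zero hit. A minimal sketch of the assumed helper semantics (the actual implementation may differ):

def get_nested_value_sketch(tree: dict, nested_key: tuple, default):
  """Walks `tree` key by key, returning `default` at the first missing level."""
  node = tree
  for key in nested_key:
    if not isinstance(node, dict) or key not in node:
      return default
    node = node[key]
  return node

outputs = {"intermediates": {"decoder": {"moe_layers": {"moe_lb_loss": 0.01}}}}
assert get_nested_value_sketch(outputs, ("intermediates", "decoder", "layers", "moe_lb_loss"), 0.0) == 0.0
assert get_nested_value_sketch(outputs, ("intermediates", "decoder", "moe_layers", "moe_lb_loss"), 0.0) == 0.01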
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
@@ -1087,6 +1087,6 @@ force_q_layout: false

 ################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
 # The number of parallel streams in Hyper Connection.
-mhc_expansion_rate: 0
+mhc_expansion_rate: 1
 # The number of iterations for the Sinkhorn-Knopp algorithm.
 sinkhorn_iterations: 20
4 changes: 2 additions & 2 deletions src/maxtext/configs/types.py
@@ -248,7 +248,7 @@ class ProfilerType(str, Enum):
     "llama4-17b-16e",
     "llama4-17b-128e",
     "olmo3-7b",
-    'olmo3-7b-pt',
+    "olmo3-7b-pt",
     "olmo3-32b",
 ]

@@ -1082,7 +1082,7 @@ class TrainingLoop(BaseModel):
 class ManifoldConstrainedHyperConnections(BaseModel):
   """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""

-  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
+  mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
   sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
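Together with the `base.yml` change, typing the field as `PositiveInt` with a default of 1 makes a single stream the baseline: every `mhc_expansion_rate > 1` check in the layers is False, so the model follows the standard residual path, and the old default of 0 is now rejected at validation time. A small illustrative check, assuming plain Pydantic behavior:

from pydantic import BaseModel, Field, PositiveInt, ValidationError

class MHCConfigSketch(BaseModel):  # hypothetical stand-in for the MaxText config class
  mhc_expansion_rate: PositiveInt = Field(1)
  sinkhorn_iterations: PositiveInt = Field(20)

assert MHCConfigSketch().mhc_expansion_rate == 1           # default: mHC disabled (single stream)
assert MHCConfigSketch(mhc_expansion_rate=4).mhc_expansion_rate == 4

try:
  MHCConfigSketch(mhc_expansion_rate=0)                    # the old default of 0 no longer validates
except ValidationError:
  pass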