Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,12 @@ class Distillation(BaseModel):
"The other parameters will be frozen if this attribute is non empty)",
)

# --- Experimental features ----
blockwise_distill: bool = Field(
False,
description="Enables layer-wise parallel distillaion mode.",
)


class TrainingLoop(BaseModel):
"""Configuration for the main training loop, evaluation, and reproducibility."""
Expand Down
5 changes: 5 additions & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,11 @@ def __call__(
else:
input_axis_names = self.decode_input_axis_names

# ---- For Distillation pipeline only ----
# Sow the attention module inputs for the teacher forward pass
if self.config.blockwise_distill:
self.sow(nnx.Intermediate, "attention_inputs", inputs_q)

inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names)
qkv_sharding = create_sharding(self.mesh, input_axis_names)
Expand Down
11 changes: 9 additions & 2 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ def __call__(
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
injected_attention_inputs: jax.Array | None = None,
):
cfg = self.config
mesh = self.mesh
Expand Down Expand Up @@ -1047,8 +1048,14 @@ def __call__(
kv_caches[i] = returned_kv_cache[i]
else:
# Fallback to old behavior if kv_caches is None (not vLLM RPA)
current_broadcast_args.append(None)
current_in_axes_tuple.append(nn.broadcast)
if injected_attention_inputs is not None:
# ---- For Distillation pipeline only ----
# Append injected_attention_inputs as the 10th positional argument
current_broadcast_args.extend([None, None, None, None, injected_attention_inputs])
current_in_axes_tuple.extend([nn.broadcast] * 4 + [0])
else:
current_broadcast_args.append(None) # previous_chunk decoder's call() argument
current_in_axes_tuple.append(nn.broadcast)

y, _ = self.scan_decoder_layers(
cfg,
Expand Down
9 changes: 7 additions & 2 deletions src/maxtext/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def __call__(
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
injected_attention_inputs: jax.Array | None = None,
):
cfg = self.config

Expand All @@ -166,10 +167,14 @@ def __call__(
lnx = self.pre_self_attention_layer_norm(inputs, out_sharding=lnx_sharding)
lnx = self._maybe_shard_with_logical(lnx, self.activation_axis_names)

# Override attention module inputs if teacher activations are injected
attn_q = injected_attention_inputs if injected_attention_inputs is not None else lnx
attn_kv = injected_attention_inputs if injected_attention_inputs is not None else lnx

# Self-attention block
attention_lnx, kv_cache = self.self_attention(
lnx,
lnx,
attn_q,
attn_kv,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def __call__(
decoder_target_mask: jax.Array | None = None,
kv_caches: list[jax.Array] | None = None,
attention_metadata: dict[str, Any] | None = None,
injected_attention_inputs: jax.Array | None = None,
):
"""Applies the Zero-1 FSDP wrapped Transformer model.

Expand Down Expand Up @@ -549,6 +550,7 @@ def __call__(
kv_caches=kv_caches,
attention_metadata=attention_metadata,
deepstack_visual_embeds=deepstack_visual_embeds,
injected_attention_inputs=injected_attention_inputs,
) # pytype: disable=wrong-keyword-args
else:
logits, hidden_state, kv_caches = self.decoder(
Expand All @@ -565,6 +567,7 @@ def __call__(
attention_metadata=attention_metadata,
deepstack_visual_embeds=deepstack_visual_embeds,
mutable=mutable_collections,
injected_attention_inputs=injected_attention_inputs,
) # pytype: disable=wrong-keyword-args

# Materialize hidden state when vocab tiling is enabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class DistillationForwardOutput:
#: top-k indices for sparse offline distillation
top_k_indices: jax.Array | None = None

#: Experimental: field to carry teacher attention inputs
attention_inputs: jax.Array | None = None


@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(peft_trainer.TrainingInput):
Expand Down Expand Up @@ -340,6 +343,7 @@ def __init__(
beta_end: float | None = None,
beta_schedule: Literal["constant", "linear", "cosine"] = "constant",
max_steps: int = 1,
blockwise_distill: bool = False,
):
"""Initializes the Combined distillation strategy.

Expand Down Expand Up @@ -377,6 +381,7 @@ def __init__(
self.alpha = alpha
self.beta_feature = beta_feature
self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None
self.blockwise_distill = blockwise_distill

# Schedule parameters
self.alpha_end = alpha_end
Expand Down Expand Up @@ -560,7 +565,7 @@ def compute_loss(

feature_loss = beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced, mask)

total_loss = base_logit_loss + feature_loss
total_loss = feature_loss if self.blockwise_distill else base_logit_loss + feature_loss

moe_lb_loss = jnp.array(0.0)
if student_output.moe_lb_loss is not None:
Expand Down
39 changes: 28 additions & 11 deletions src/maxtext/trainers/post_train/distillation/train_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,17 @@ def model_forward_fn(
enable_dropout=config.enable_dropout,
decoder_target_tokens=kwargs.get("decoder_target_tokens", None),
decoder_target_mask=kwargs.get("decoder_target_mask", None),
# Pass injected teacher's attention inputs down to the model
injected_attention_inputs=kwargs.get("injected_attention_inputs", None),
)
out_projection_activations = None
if config.distill_beta > 0.0:
out_projection_activations = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True)

attention_inputs = None
if config.blockwise_distill:
attention_inputs = maxtext_utils.get_intermediate_value(model, "attention_inputs", clear=True)

moe_lb_loss = None
if config.num_experts > 1 and config.load_balance_loss_weight > 0.0:
intermediate_outputs = nnx.pop(model, nnx.Intermediate)
Expand All @@ -163,7 +169,10 @@ def model_forward_fn(
moe_lb_loss = jnp.mean(jnp.concatenate(total_moe_lb_losses))

retval = distillation_utils.DistillationForwardOutput(
logits=logits, out_projection_activations=out_projection_activations, moe_lb_loss=moe_lb_loss
logits=logits,
out_projection_activations=out_projection_activations,
moe_lb_loss=moe_lb_loss,
attention_inputs=attention_inputs,
)
return retval

Expand Down Expand Up @@ -314,16 +323,23 @@ def _train_step(self, model, optimizer, inputs):

def loss_wrapper_pure(diff_params, rest):
local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
student_output = self.strategy.student_forward_fn(
model=local_student,
input_tokens=batch["input_tokens"],
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
decoder_segment_ids=batch.get("decoder_segment_ids"),
decoder_target_tokens=batch.get("targets", None),
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)
student_kwargs = {
"model": local_student,
"input_tokens": batch["input_tokens"],
"positions": batch["positions"],
"attention_mask": batch.get("attention_mask"),
"decoder_segment_ids": batch.get("decoder_segment_ids"),
"decoder_target_tokens": batch.get("targets", None),
"decoder_target_mask": batch.get("targets_segmentation", None),
"cache": None,
}
if (
isinstance(teacher_output, distillation_utils.DistillationForwardOutput)
and teacher_output.attention_inputs is not None
):
student_kwargs["injected_attention_inputs"] = teacher_output.attention_inputs

student_output = self.strategy.student_forward_fn(**student_kwargs)
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
# Capture updated non-param state (e.g. RNG counters) from local_student.
Expand Down Expand Up @@ -613,6 +629,7 @@ def build_training_components(
beta_end=student_config.distill_beta_end,
beta_schedule=student_config.distill_beta_schedule,
max_steps=student_config.steps,
blockwise_distill=teacher_config.blockwise_distill,
)

# Prepare optimizer
Expand Down
6 changes: 6 additions & 0 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,12 @@ def get_intermediate_value(model, nested_key, default=None, clear=False):
intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
if clear:
del model.decoder.layers["self_attention"][nested_key]
# TODO: unify with the case above AND make it compatible with non-scan mode
case "attention_inputs": # for re-architected distillation
if nested_key in model.decoder.layers["self_attention"]:
intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
if clear:
del model.decoder.layers["self_attention"][nested_key]
case _:
# Default case to handle any unknown nested keys
raise ValueError(f"Incorrect nested_key: {nested_key}")
Expand Down
Loading
Loading