diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 921726b536..a76a3fc7b1 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1376,6 +1376,12 @@ class Distillation(BaseModel): "The other parameters will be frozen if this attribute is non empty)", ) + # --- Experimental features ---- + blockwise_distill: bool = Field( + False, + description="Enables layer-wise parallel distillaion mode.", + ) + class TrainingLoop(BaseModel): """Configuration for the main training loop, evaluation, and reproducibility.""" diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 93c54e25a6..1609e59d35 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -1120,6 +1120,11 @@ def __call__( else: input_axis_names = self.decode_input_axis_names + # ---- For Distillation pipeline only ---- + # Sow the attention module inputs for the teacher forward pass + if self.config.blockwise_distill: + self.sow(nnx.Intermediate, "attention_inputs", inputs_q) + inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names) inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names) qkv_sharding = create_sharding(self.mesh, input_axis_names) diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 46363dbf70..cfdc2cc2b9 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -794,6 +794,7 @@ def __call__( kv_caches: list[jax.Array] | None = None, attention_metadata=None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + injected_attention_inputs: jax.Array | None = None, ): cfg = self.config mesh = self.mesh @@ -1047,8 +1048,14 @@ def __call__( kv_caches[i] = returned_kv_cache[i] else: # Fallback to old behavior if kv_caches is None (not vLLM RPA) - current_broadcast_args.append(None) - current_in_axes_tuple.append(nn.broadcast) + if injected_attention_inputs is not None: + # ---- For Distillation 
pipeline only ---- + # Append injected_attention_inputs as the 10th positional argument + current_broadcast_args.extend([None, None, None, None, injected_attention_inputs]) + current_in_axes_tuple.extend([nn.broadcast] * 4 + [0]) + else: + current_broadcast_args.append(None) # previous_chunk decoder's call() argument + current_in_axes_tuple.append(nn.broadcast) y, _ = self.scan_decoder_layers( cfg, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 0c3e0cca7c..8f1a674da2 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -148,6 +148,7 @@ def __call__( slot: None | int = None, kv_cache=None, attention_metadata=None, + injected_attention_inputs: jax.Array | None = None, ): cfg = self.config @@ -166,10 +167,14 @@ def __call__( lnx = self.pre_self_attention_layer_norm(inputs, out_sharding=lnx_sharding) lnx = self._maybe_shard_with_logical(lnx, self.activation_axis_names) + # Override attention module inputs if teacher activations are injected + attn_q = injected_attention_inputs if injected_attention_inputs is not None else lnx + attn_kv = injected_attention_inputs if injected_attention_inputs is not None else lnx + # Self-attention block attention_lnx, kv_cache = self.self_attention( - lnx, - lnx, + attn_q, + attn_kv, decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 2af0d560da..da52970281 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -453,6 +453,7 @@ def __call__( decoder_target_mask: jax.Array | None = None, kv_caches: list[jax.Array] | None = None, attention_metadata: dict[str, Any] | None = None, + injected_attention_inputs: jax.Array | None = None, ): """Applies the Zero-1 FSDP wrapped Transformer model. 
@@ -549,6 +550,7 @@ def __call__( kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, + injected_attention_inputs=injected_attention_inputs, ) # pytype: disable=wrong-keyword-args else: logits, hidden_state, kv_caches = self.decoder( @@ -565,6 +567,7 @@ def __call__( attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, mutable=mutable_collections, + injected_attention_inputs=injected_attention_inputs, ) # pytype: disable=wrong-keyword-args # Materialize hidden state when vocab tiling is enabled diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index f063cdb23a..068535c9ff 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -55,6 +55,9 @@ class DistillationForwardOutput: #: top-k indices for sparse offline distillation top_k_indices: jax.Array | None = None + #: Experimental: field to carry teacher attention inputs + attention_inputs: jax.Array | None = None + @flax.struct.dataclass(frozen=True) class MaxTextTrainingInput(peft_trainer.TrainingInput): @@ -340,6 +343,7 @@ def __init__( beta_end: float | None = None, beta_schedule: Literal["constant", "linear", "cosine"] = "constant", max_steps: int = 1, + blockwise_distill: bool = False, ): """Initializes the Combined distillation strategy. 
@@ -377,6 +381,7 @@ def __init__( self.alpha = alpha self.beta_feature = beta_feature self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None + self.blockwise_distill = blockwise_distill # Schedule parameters self.alpha_end = alpha_end @@ -560,7 +565,7 @@ def compute_loss( feature_loss = beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced, mask) - total_loss = base_logit_loss + feature_loss + total_loss = feature_loss if self.blockwise_distill else base_logit_loss + feature_loss moe_lb_loss = jnp.array(0.0) if student_output.moe_lb_loss is not None: diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 15197d1af4..f88e419863 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -150,11 +150,17 @@ def model_forward_fn( enable_dropout=config.enable_dropout, decoder_target_tokens=kwargs.get("decoder_target_tokens", None), decoder_target_mask=kwargs.get("decoder_target_mask", None), + # Pass injected teacher's attention inputs down to the model + injected_attention_inputs=kwargs.get("injected_attention_inputs", None), ) out_projection_activations = None if config.distill_beta > 0.0: out_projection_activations = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True) + attention_inputs = None + if config.blockwise_distill: + attention_inputs = maxtext_utils.get_intermediate_value(model, "attention_inputs", clear=True) + moe_lb_loss = None if config.num_experts > 1 and config.load_balance_loss_weight > 0.0: intermediate_outputs = nnx.pop(model, nnx.Intermediate) @@ -163,7 +169,10 @@ def model_forward_fn( moe_lb_loss = jnp.mean(jnp.concatenate(total_moe_lb_losses)) retval = distillation_utils.DistillationForwardOutput( - logits=logits, out_projection_activations=out_projection_activations, moe_lb_loss=moe_lb_loss 
+ logits=logits, + out_projection_activations=out_projection_activations, + moe_lb_loss=moe_lb_loss, + attention_inputs=attention_inputs, ) return retval @@ -314,16 +323,23 @@ def _train_step(self, model, optimizer, inputs): def loss_wrapper_pure(diff_params, rest): local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True) - student_output = self.strategy.student_forward_fn( - model=local_student, - input_tokens=batch["input_tokens"], - positions=batch["positions"], - attention_mask=batch.get("attention_mask"), - decoder_segment_ids=batch.get("decoder_segment_ids"), - decoder_target_tokens=batch.get("targets", None), - decoder_target_mask=batch.get("targets_segmentation", None), - cache=None, - ) + student_kwargs = { + "model": local_student, + "input_tokens": batch["input_tokens"], + "positions": batch["positions"], + "attention_mask": batch.get("attention_mask"), + "decoder_segment_ids": batch.get("decoder_segment_ids"), + "decoder_target_tokens": batch.get("targets", None), + "decoder_target_mask": batch.get("targets_segmentation", None), + "cache": None, + } + if ( + isinstance(teacher_output, distillation_utils.DistillationForwardOutput) + and teacher_output.attention_inputs is not None + ): + student_kwargs["injected_attention_inputs"] = teacher_output.attention_inputs + + student_output = self.strategy.student_forward_fn(**student_kwargs) labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) # Capture updated non-param state (e.g. RNG counters) from local_student. 
@@ -613,6 +629,7 @@ def build_training_components( beta_end=student_config.distill_beta_end, beta_schedule=student_config.distill_beta_schedule, max_steps=student_config.steps, + blockwise_distill=teacher_config.blockwise_distill, ) # Prepare optimizer diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 446ea5d0ba..2455d17bcb 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1302,6 +1302,12 @@ def get_intermediate_value(model, nested_key, default=None, clear=False): intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1] if clear: del model.decoder.layers["self_attention"][nested_key] + # TODO: unify with above AND make it compatible with non-scan mode + case "attention_inputs": # for re-architected distillation + if nested_key in model.decoder.layers["self_attention"]: + intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1] + if clear: + del model.decoder.layers["self_attention"][nested_key] case _: # Default case to handle any unknown nested keys raise ValueError(f"Incorrect nested_key: {nested_key}") diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index d30a6b135b..41b6a8c2b4 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -175,6 +175,7 @@ def test_train_step_skips_teacher_forward_when_output_present( trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + mock_tree_map.side_effect = lambda f, x: x # 2. 
Setup Batch WITH teacher_output mock_batch = { @@ -238,6 +239,7 @@ def test_train_step_calls_teacher_forward_when_output_missing( trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + mock_tree_map.side_effect = lambda f, x: x # 2. Setup Batch WITHOUT teacher_output mock_batch = { @@ -326,6 +328,7 @@ def test_train_step_passes_targets_segmentation( trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + mock_tree_map.side_effect = lambda f, x: x # 2. Setup Batch WITH targets_segmentation mock_targets_segmentation = jnp.array([[1, 1, 0]]) @@ -388,6 +391,73 @@ def test_train_step_passes_targets_segmentation( cache=None, ) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_passes_injected_attention_inputs_under_blockwise_distill( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm + ): + """Verifies that teacher's attention_inputs are passed to student_forward_fn as injected_attention_inputs.""" + # 1. 
Initialize Trainer + # pylint: disable=no-value-for-parameter + trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) + trainer.strategy = mock.Mock() + trainer.wrt_filter = lambda path, x: True # type: ignore + mock_tree_map.side_effect = lambda f, x: x + + # 2. Setup Batch WITH teacher_output containing attention_inputs + mock_attention_inputs = jnp.ones((1, 2, 8)) + fake_teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.zeros((1, 2, 4)), + out_projection_activations=None, + attention_inputs=mock_attention_inputs, + ) + mock_batch = { + "input_tokens": mock.Mock(), + "positions": mock.Mock(), + "attention_mask": mock.Mock(), + "decoder_segment_ids": mock.Mock(), + "targets": mock.Mock(), + "teacher_output": fake_teacher_output, + } + trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) + + # 3. Setup Models & Inputs + teacher_model, student_model = mock.Mock(), mock.Mock() + model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) + optimizer, inputs = mock.Mock(), mock.Mock() + + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) + mock_value_and_grad.return_value = mock_grad_fn + mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) + + # 6. Execute outer function & trigger inner loss_wrapper_pure + trainer._train_step(model_bundle, optimizer, inputs) + loss_wrapper = mock_value_and_grad.call_args[0][0] + loss_wrapper(mock_diff_params, mock_rest) + + # 7. 
Assertions + trainer.strategy.student_forward_fn.assert_called_once_with( + model=mock.ANY, # local_student from nnx.merge, not the original student_model + input_tokens=mock_batch["input_tokens"], + positions=mock_batch["positions"], + attention_mask=mock_batch["attention_mask"], + decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch["targets"], + decoder_target_mask=None, + cache=None, + injected_attention_inputs=mock_attention_inputs, + ) + def test_optimizer_factory(self): """Verifies the optimizer factory injects hyperparams and handles configs.""" # Mock config @@ -575,6 +645,85 @@ def _mean(pair): self.assertTrue(19.0 < _mean(metrics["distill/kl_div_at_T"]) < 21.0) self.assertTrue(_mean(metrics["distill/teacher_loss"]) == 0.0) + def test_monitored_strategy_blockwise_distill(self): + """Verifies the strategy ignores base_logit_loss when blockwise_distill is enabled.""" + strategy = distillation_utils.CombinedDistillationStrategy( + student_forward_fn=lambda m, **k: None, + teacher_forward_fn=lambda m, **k: None, + vocab_size=4, + temperature=1.0, + alpha=0.5, + beta_feature=1.5, + feature_loss_type="cosine", + layer_indices=None, + blockwise_distill=True, + ) + student_output = distillation_utils.DistillationForwardOutput( + logits=jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10, + out_projection_activations=jnp.ones((32, 1, 1, 8)), + ) + teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.array([[[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]) * 10, + out_projection_activations=jnp.ones((32, 1, 1, 8)) * 1.5, + ) + labels_indices = jnp.array([[0, 1]]) + labels = jax.nn.one_hot(labels_indices, 4) + total_loss, _ = strategy.compute_loss(student_output, teacher_output, labels) + self.assertAlmostEqual(float(total_loss), 0.0, places=5) + + def test_build_training_components_blockwise_distill(self): + """Verifies build_training_components gets blockwise_distill from teacher_config.""" + 
student_config = mock.Mock() + student_config.tokenizer_path = "" + student_config.tokenizer_type = "huggingface" + student_config.add_bos = True + student_config.add_eos = True + student_config.hf_access_token = None + student_config.distill_temperature = 1.0 + student_config.distill_alpha = 0.5 + student_config.distill_beta = 0.0 + student_config.distill_layer_indices = None + student_config.distill_feature_loss_type = "cosine" + student_config.vocab_size = 4 + student_config.distill_alpha_end = None + student_config.distill_alpha_schedule = "constant" + student_config.distill_temperature_end = None + student_config.distill_temperature_schedule = "constant" + student_config.distill_beta_end = None + student_config.distill_beta_schedule = "constant" + student_config.steps = 100 + student_config.checkpoint_period = 10 + student_config.max_num_checkpoints_to_keep = 1 + student_config.async_checkpointing = False + student_config.profiler = "none" + student_config.tensorboard_dir = "" + student_config.log_period = 10 + student_config.gradient_accumulation_steps = 1 + student_config.data_sharding = ["data"] + student_config.learning_rate = 1e-4 + student_config.opt_type = "adamw" + student_config.adam_b1 = 0.9 + student_config.adam_b2 = 0.99 + student_config.adam_eps = 1e-8 + student_config.adam_eps_root = 0.0 + student_config.adam_weight_decay = 0.0 + student_config.mu_dtype = "float32" + student_config.gradient_clipping_threshold = 1.0 + student_config.warmup_steps_fraction = 0.1 + student_config.learning_rate_final_fraction = 0.1 + student_config.blockwise_distill = False + + teacher_config = mock.Mock() + teacher_config.blockwise_distill = True + + with mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer") as mock_build: + mock_tok = mock.Mock() + mock_tok.pad_id = 0 + mock_build.return_value = mock_tok + + strategy, _, _ = train_distill.build_training_components(student_config, teacher_config) + 
self.assertTrue(strategy.blockwise_distill) + def test_setup_pipeline_grain_enabled(self): """Covers setup_checkpoint_manager_and_restore when Grain IS detected.""" mock_trainer = mock.Mock() @@ -1136,6 +1285,7 @@ def test_main_offline_mode_skips_teacher_loading( mock_teacher_cfg.per_device_batch_size = 1 mock_teacher_cfg.max_target_length = 16 mock_teacher_cfg.gradient_accumulation_steps = 1 + mock_teacher_cfg.blockwise_distill = False mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg] # 2. Model Loading @@ -1246,6 +1396,7 @@ def test_main_online_mode_loads_teacher( mock_teacher_cfg.per_device_batch_size = 1 mock_teacher_cfg.max_target_length = 16 mock_teacher_cfg.gradient_accumulation_steps = 1 + mock_teacher_cfg.blockwise_distill = False mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg] mock_student_model = mock.Mock()