From 2997fcfa4ec115d649e157e102e21c29d64b6b5e Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Thu, 11 Jun 2026 12:28:17 -0700 Subject: [PATCH 1/2] Add logits correctness integration tests for DPO against HF TRL This commit introduces a JAX DPO correctness integration test that validates JAX DPO training step metrics (loss, margin, chosen/rejected logprobs) against stored golden outputs. It also includes the CPU-only JAX/PyTorch parallel validation script used to verify parity against Hugging Face TRL and generate the remote-canonical golden assets. --- .gitignore | 1 + .../trainers/post_train/dpo/train_dpo.py | 12 +- .../golden_logits/golden_dpo_correctness.json | 18 ++ .../dpo/dpo_2_column_dataset.json | 90 +++++++ .../dpo/dpo_3_column_dataset.json | 112 ++++++++ .../logits_generation/dpo_pytorch_helpers.py | 232 +++++++++++++++++ ..._golden_data_and_compare_pytorch_logits.py | 204 +++++++++++++++ .../integration/dpo_correctness_base.py | 239 ++++++++++++++++++ .../unit/dpo_trainer_correctness_test.py | 116 +++++++++ 9 files changed, 1020 insertions(+), 4 deletions(-) create mode 100644 tests/assets/golden_logits/golden_dpo_correctness.json create mode 100644 tests/assets/local_datasets/dpo/dpo_2_column_dataset.json create mode 100644 tests/assets/local_datasets/dpo/dpo_3_column_dataset.json create mode 100644 tests/assets/logits_generation/dpo_pytorch_helpers.py create mode 100644 tests/assets/logits_generation/generate_dpo_golden_data_and_compare_pytorch_logits.py create mode 100644 tests/post_training/integration/dpo_correctness_base.py create mode 100644 tests/post_training/unit/dpo_trainer_correctness_test.py diff --git a/.gitignore b/.gitignore index b67cba160d..a5d7706c65 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ gha-creds-*.json # vscode workspace maxtext.code-workspace +maxtext_output/ diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py index 0fb9d49655..73f3a37202 100644 
--- a/src/maxtext/trainers/post_train/dpo/train_dpo.py +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -108,7 +108,7 @@ def get_tunix_config(mt_config: MaxTextConfig) -> DPOTrainingConfig: ) -def setup_trainer_state(mt_config, goodput_recorder=None): +def setup_trainer_state(mt_config, goodput_recorder=None, test_only_training_hooks_class=None): """Set up prerequisites for training loop.""" tunix_config = get_tunix_config(mt_config) @@ -143,7 +143,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): ref_model = nnx.clone(model) if mt_config.dpo.algo == "dpo" else None with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION): - training_hooks = hooks.DPOTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) + if test_only_training_hooks_class is None: + test_only_training_hooks_class = hooks.DPOTrainingHooks + + training_hooks = test_only_training_hooks_class(mt_config, mesh, learning_rate_schedule, goodput_recorder) data_hooks = hooks.DPODataHooks(mt_config, mesh, goodput_recorder) # Provide rules context so logical axes (e.g. 'norm') are translated to mesh axes during maybe_restore @@ -164,14 +167,15 @@ def train_model(mt_config: MaxTextConfig, trainer, mesh): return trainer -def train(mt_config, goodput_recorder=None): +def train(mt_config, goodput_recorder=None, test_only_training_hooks_class=None): """Main method for DPO training. Args: mt_config: MaxText config. goodput_recorder: An optional GoodputRecorder to record performance metrics. + test_only_training_hooks_class: An optional DPOTrainingHooks subclass to override hooks. 
""" - trainer, mesh = setup_trainer_state(mt_config, goodput_recorder) + trainer, mesh = setup_trainer_state(mt_config, goodput_recorder, test_only_training_hooks_class) _job_completed_gracefully = False try: trainer = train_model(mt_config, trainer, mesh) diff --git a/tests/assets/golden_logits/golden_dpo_correctness.json b/tests/assets/golden_logits/golden_dpo_correctness.json new file mode 100644 index 0000000000..5552ebde26 --- /dev/null +++ b/tests/assets/golden_logits/golden_dpo_correctness.json @@ -0,0 +1,18 @@ +{ + "explicit_prompt_len_3_column": { + "loss_step_1": 0.6931471824645996, + "margin_step_1": 0.0, + "loss": 0.6573728919029236, + "margin": 0.0728759765625, + "chosen_logps": -854.4208984375, + "rejected_logps": -530.0099487304688 + }, + "default_prompt_len_2_column": { + "loss_step_1": 0.6931471824645996, + "margin_step_1": 0.0, + "loss": 0.6954630613327026, + "margin": -0.004626465495675802, + "chosen_logps": -875.7848510742188, + "rejected_logps": -502.052978515625 + } +} \ No newline at end of file diff --git a/tests/assets/local_datasets/dpo/dpo_2_column_dataset.json b/tests/assets/local_datasets/dpo/dpo_2_column_dataset.json new file mode 100644 index 0000000000..ae284c9be1 --- /dev/null +++ b/tests/assets/local_datasets/dpo/dpo_2_column_dataset.json @@ -0,0 +1,90 @@ +[ + { + "chosen": "\n\nHuman: How does gradient descent work?\n\nAssistant: Updating parameters in the opposite direction of the gradient is how gradient descent works.", + "rejected": "\n\nHuman: How does gradient descent work?\n\nAssistant: Climbing up hills to find local maxima is how gradient descent works." + }, + { + "chosen": "\n\nHuman: What is a neural network?\n\nAssistant: Interconnected nodes that learn representations from data form a neural network.", + "rejected": "\n\nHuman: What is a neural network?\n\nAssistant: Storing tabular data in a designed schema is the purpose of a neural network." 
+ }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." 
+ }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." 
+ }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." 
+ }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." + }, + { + "chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization." 
+ } +] \ No newline at end of file diff --git a/tests/assets/local_datasets/dpo/dpo_3_column_dataset.json b/tests/assets/local_datasets/dpo/dpo_3_column_dataset.json new file mode 100644 index 0000000000..6c5c236cad --- /dev/null +++ b/tests/assets/local_datasets/dpo/dpo_3_column_dataset.json @@ -0,0 +1,112 @@ +[ + { + "prompt": "How does gradient descent work?", + "chosen": "Updating parameters in the opposite direction of the gradient is how gradient descent works.", + "rejected": "Climbing up hills to find local maxima is how gradient descent works." + }, + { + "prompt": "What is a neural network?", + "chosen": "Interconnected nodes that learn representations from data form a neural network.", + "rejected": "Storing tabular data in a designed schema is the purpose of a neural network." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." 
+ }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." 
+ }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." 
+ }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + }, + { + "prompt": "What is preference optimization?", + "chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.", + "rejected": "Database operations to choose preferred options are called preference optimization." + } +] diff --git a/tests/assets/logits_generation/dpo_pytorch_helpers.py b/tests/assets/logits_generation/dpo_pytorch_helpers.py new file mode 100644 index 0000000000..cb22accd95 --- /dev/null +++ b/tests/assets/logits_generation/dpo_pytorch_helpers.py @@ -0,0 +1,232 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch reference functions and weight synchronization helpers for DPO/ORPO integration tests. + +Note: The integration tests validate parity by comparing JAX and PyTorch/TRL on an +identical, miniaturized 2-layer Qwen2 model architecture. +- The JAX model configuration is defined in `tests/post_training/integration/dpo_correctness_base.py` + (via `_build_jax_config` which overrides model dimensions to a tiny 2-layer shape). 
+- The PyTorch model configuration is defined here in `create_pytorch_config` and is + designed to be structurally identical to the JAX model to allow direct parameter + synchronization and logit comparison. +""" + +import tempfile +import numpy as np +from flax import nnx +import torch +from transformers import Qwen2Config +from trl import DPOConfig, DPOTrainer +from datasets import Dataset + + +def sync_jax_to_pytorch(jax_model, torch_model): + """Synchronizes JAX model parameters directly to a PyTorch model.""" + hidden_size = torch_model.config.hidden_size + torch_state_dict = torch_model.state_dict() + jax_flat = dict(nnx.state(jax_model).flat_state()) + + def sync_param(torch_key, jax_key, reshape=None, transpose=False): + val = np.array(jax_flat[jax_key][...]) + if reshape: + val = val.reshape(reshape) + if transpose: + val = val.T + torch_state_dict[torch_key].copy_(torch.from_numpy(val)) + + # 1. Token embedding + sync_param("model.embed_tokens.weight", ("base", "token_embedder", "embedding")) + + # 2. Final layer norm + sync_param("model.norm.weight", ("base", "decoder", "decoder_norm", "scale")) + + # 3. 
Causal layers (2 layers) + num_layers = 2 + for i in range(num_layers): + # Input and post-attention layer norms + sync_param( + f"model.layers.{i}.input_layernorm.weight", + ("base", "decoder", f"layers_{i}", "pre_self_attention_layer_norm", "scale"), + ) + sync_param( + f"model.layers.{i}.post_attention_layernorm.weight", + ("base", "decoder", f"layers_{i}", "post_self_attention_layer_norm", "scale"), + ) + + # Attention projection weights (transposed from JAX to match PyTorch) + sync_param( + f"model.layers.{i}.self_attn.q_proj.weight", + ("base", "decoder", f"layers_{i}", "self_attention", "query", "kernel"), + reshape=(hidden_size, hidden_size), + transpose=True, + ) + sync_param( + f"model.layers.{i}.self_attn.k_proj.weight", + ("base", "decoder", f"layers_{i}", "self_attention", "key", "kernel"), + reshape=(hidden_size, hidden_size), + transpose=True, + ) + sync_param( + f"model.layers.{i}.self_attn.v_proj.weight", + ("base", "decoder", f"layers_{i}", "self_attention", "value", "kernel"), + reshape=(hidden_size, hidden_size), + transpose=True, + ) + sync_param( + f"model.layers.{i}.self_attn.o_proj.weight", + ("base", "decoder", f"layers_{i}", "self_attention", "out", "kernel"), + reshape=(hidden_size, hidden_size), + transpose=True, + ) + + # Attention biases + sync_param( + f"model.layers.{i}.self_attn.q_proj.bias", + ("base", "decoder", f"layers_{i}", "self_attention", "query", "bias"), + reshape=(hidden_size,), + ) + sync_param( + f"model.layers.{i}.self_attn.k_proj.bias", + ("base", "decoder", f"layers_{i}", "self_attention", "key", "bias"), + reshape=(hidden_size,), + ) + sync_param( + f"model.layers.{i}.self_attn.v_proj.bias", + ("base", "decoder", f"layers_{i}", "self_attention", "value", "bias"), + reshape=(hidden_size,), + ) + + # MLP projection weights (wi_0, wi_1, wo) + sync_param( + f"model.layers.{i}.mlp.gate_proj.weight", + ("base", "decoder", f"layers_{i}", "mlp", "wi_0", "kernel"), + transpose=True, + ) + sync_param( + 
f"model.layers.{i}.mlp.up_proj.weight", + ("base", "decoder", f"layers_{i}", "mlp", "wi_1", "kernel"), + transpose=True, + ) + sync_param( + f"model.layers.{i}.mlp.down_proj.weight", + ("base", "decoder", f"layers_{i}", "mlp", "wo", "kernel"), + transpose=True, + ) + + # 4. LM Head weight (logits_via_embedding=True matches embeddings) + sync_param("lm_head.weight", ("base", "token_embedder", "embedding")) + + torch_model.load_state_dict(torch_state_dict) + + +def get_pytorch_reference( + policy_model, + ref_model, + tokenizer, + prompt_str, + chosen_str, + rejected_str, + beta=0.1, + tokenize_together=False, +): + # pylint: disable=too-many-positional-arguments + """Computes reference chosen/rejected logps and loss in PyTorch using TRL trainers.""" + policy_model.eval() + if ref_model is not None: + ref_model.eval() + + # Set up the tokenizer based on JAX's tokenize_together formatting. + # tokenize_together=True (2-column format) has no EOS in the middle of prompt. + # tokenize_together=False (3-column format) has EOS in the middle. 
+ tokenizer.add_eos_token = False + if not tokenize_together: + prompt_str = prompt_str + tokenizer.eos_token + + # Build Dataset + dataset = Dataset.from_list( + [ + { + "prompt": prompt_str, + "chosen": chosen_str, + "rejected": rejected_str, + } + ] + ) + + with tempfile.TemporaryDirectory() as temp_dir: + training_args = DPOConfig( + output_dir=temp_dir, + beta=beta, + max_length=256, + use_cpu=True, + remove_unused_columns=False, + ) + trainer = DPOTrainer( + model=policy_model, + ref_model=ref_model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + dataloader = trainer.get_train_dataloader() + batch = next(iter(dataloader)) + + with torch.no_grad(): + loss = trainer.compute_loss(policy_model, batch) + + # Extract logps and ref_logps from "eval" key in trainer._metrics + # pylint: disable=protected-access + metrics = trainer._metrics["eval"] + chosen_logps = metrics["logps/chosen"][0] + rejected_logps = metrics["logps/rejected"][0] + + # Reconstruct ref_chosen_logps and ref_rejected_logps: + ref_chosen_logps = chosen_logps - (metrics["rewards/chosen"][0] / beta) + ref_rejected_logps = rejected_logps - (metrics["rewards/rejected"][0] / beta) + + # Margin: + margin = metrics["rewards/margins"][0] / beta + + return { + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "ref_chosen_logps": ref_chosen_logps, + "ref_rejected_logps": ref_rejected_logps, + "loss": loss.item(), + "margin": margin, + } + + +def create_pytorch_config(max_target_length: int) -> Qwen2Config: + """Helper to create a symmetrical PyTorch tiny Qwen2 model configuration. + + This configuration must be kept structurally identical to the JAX model configuration + defined in `tests/post_training/integration/dpo_correctness_base.py` (specifically + the tiny architecture overrides in `_build_jax_config`: base_emb_dim=64, + base_num_decoder_layers=2, etc.). 
+ """ + return Qwen2Config( + vocab_size=151936, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + max_position_embeddings=max_target_length, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + use_cache=False, + ) diff --git a/tests/assets/logits_generation/generate_dpo_golden_data_and_compare_pytorch_logits.py b/tests/assets/logits_generation/generate_dpo_golden_data_and_compare_pytorch_logits.py new file mode 100644 index 0000000000..76a60d9acb --- /dev/null +++ b/tests/assets/logits_generation/generate_dpo_golden_data_and_compare_pytorch_logits.py @@ -0,0 +1,204 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to execute JAX and HuggingFace TRL DPO in parallel, verify DPO correctness parity, +and output golden JAX metrics to a JSON file. + +The parity is verified by running JAX and PyTorch/TRL on an identical, miniaturized 2-layer Qwen2 model. +- The JAX model is configured in `tests/post_training/integration/dpo_correctness_base.py`. +- The PyTorch model is configured in `tests/assets/logits_generation/dpo_pytorch_helpers.py`. + +The flow of this test is: +* Load the same Qwen2 model (Qwen/Qwen2.5-1.5B-Instruct) in both JAX and PyTorch/TRL. +* Run the MaxText DPO training loop for a few steps (train_steps=2 in this implementation). 
+ This is needed to make sure that the policy model diverges from the reference model and that we can calculate + the DPO loss and margin. +* Use custom training hooks to intercept the model parameters after 2 training steps and copy them into the PyTorch model. +* Compare the model parameters between JAX and PyTorch/TRL. +* Assert that the model parameters are identical. +* Calculate the DPO loss and margin for the last batch. +* Assert that the DPO loss and margin between JAX and PyTorch/TRL are identical. +* Save the metrics in the golden data JSON file. + +Note: Both JAX and PyTorch/TRL are executed on CPU to ensure maximum reproducibility and +eliminate GPU/TPU floating point differences. + +How to run: + 1. Install required PyTorch and Hugging Face dependencies if they are not already present in your + virtual environment: + $ uv pip install torch transformers datasets trl + 2. Run the script (forcing JAX on CPU to ensure exact parity with PyTorch): + $ JAX_PLATFORMS=cpu python3 -m tests.assets.logits_generation.generate_dpo_golden_data_and_compare_pytorch_logits +""" + +import json +import os +import tempfile +import jax +from transformers import AutoTokenizer, Qwen2ForCausalLM + +from tests.post_training.integration.dpo_correctness_base import ( + DPOCorrectnessTestBase, + run_jax_training, + InterceptingTrainingHooks, +) +from tests.assets.logits_generation.dpo_pytorch_helpers import ( + create_pytorch_config, + get_pytorch_reference, + sync_jax_to_pytorch, +) + + +class PyTorchSyncTrainingHooks(InterceptingTrainingHooks): + """Training hooks subclass that synchronizes JAX model parameters to PyTorch.""" + + torch_policy_model = None + torch_ref_model = None + + def on_train_step_start(self, train_ctx): + super().on_train_step_start(train_ctx) + if train_ctx.train_steps == 0: + # Grab the reference model on step 0 to add extra validation that the reference model is not changing during + # training. 
+ if train_ctx.ref_model is not None: + sync_jax_to_pytorch(train_ctx.ref_model, PyTorchSyncTrainingHooks.torch_ref_model) + else: + sync_jax_to_pytorch(train_ctx.model, PyTorchSyncTrainingHooks.torch_ref_model) + + elif train_ctx.train_steps == 2: + # Copy the trained policy model weights after 2 steps of training. By this time we expect the policy model to have + # diverged from the reference model. + sync_jax_to_pytorch(train_ctx.model, PyTorchSyncTrainingHooks.torch_policy_model) + + +def run_parity_and_generate_golden(): + """Runs DPO scenarios to verify parity and outputs golden JAX metrics.""" + if jax.default_backend() != "cpu": + raise RuntimeError( + "DPO golden data generation must run on CPU (please run with environment variable JAX_PLATFORMS=cpu), " + f"but JAX default backend is '{jax.default_backend()}'." + ) + + # Setup JAX CPU options to align with base test environment setup + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + + # Instantiate a dummy base test class to invoke config builders + base_test = DPOCorrectnessTestBase() + DPOCorrectnessTestBase.setUpClass() + + model_id = DPOCorrectnessTestBase.MODEL_ID + max_target_length = 256 + beta = 0.1 + init_weights_seed = 0 + + scenarios = { + "explicit_prompt_len_3_column": (144, "dpo_3_column_dataset.json", ["prompt", "chosen", "rejected"]), + "default_prompt_len_2_column": (None, "dpo_2_column_dataset.json", ["chosen", "rejected"]), + } + + # ============================================================================ + # DPO GENERATION + # ============================================================================ + dpo_results = {} + print("\n>>> Running DPO Parity & Golden Generation...") + for name, (max_prompt_len, dataset_filename, data_columns) in 
scenarios.items(): + print(f"\n--- Scenario: {name} ---") + InterceptingTrainingHooks.captured_metrics = [] + + # Initialize Pytorch structures + torch_config = create_pytorch_config(max_target_length) + torch_policy_model = Qwen2ForCausalLM(torch_config) + torch_ref_model = Qwen2ForCausalLM(torch_config) + + PyTorchSyncTrainingHooks.torch_policy_model = torch_policy_model + PyTorchSyncTrainingHooks.torch_ref_model = torch_ref_model + + with tempfile.TemporaryDirectory() as temp_dir: + config = base_test.build_tiny_qwen2_jax_config( + max_target_length=max_target_length, + temp_dir=temp_dir, + init_weights_seed=init_weights_seed, + dataset_filename=dataset_filename, + data_columns=data_columns, + max_prompt_len=max_prompt_len, + extra_args=["run_name=dpo_correctness_gen"], + ) + jax_ref = run_jax_training(config, test_only_training_hooks_class=PyTorchSyncTrainingHooks) + + if len(data_columns) == 2: + prompt_str = f"\n\nHuman: {base_test.COMMON_PROMPT}\n\nAssistant:" + # Add a space prefix to chosen/rejected to avoid BPE prefix mismatch. 
+ chosen_str = " " + base_test.COMMON_CHOSEN + rejected_str = " " + base_test.COMMON_REJECTED + else: + prompt_str = base_test.COMMON_PROMPT + chosen_str = base_test.COMMON_CHOSEN + rejected_str = base_test.COMMON_REJECTED + + py_ref = get_pytorch_reference( + policy_model=torch_policy_model, + ref_model=torch_ref_model, + tokenizer=AutoTokenizer.from_pretrained(model_id), + prompt_str=prompt_str, + chosen_str=chosen_str, + rejected_str=rejected_str, + beta=beta, + tokenize_together=(len(data_columns) == 2), + ) + + # Parity verification before writing + chosen_diff = abs(jax_ref["chosen_logps"] - py_ref["chosen_logps"]) + rejected_diff = abs(jax_ref["rejected_logps"] - py_ref["rejected_logps"]) + loss_diff = abs(jax_ref["loss"] - py_ref["loss"]) + + print( + f"JAX Chosen: {jax_ref['chosen_logps']:.6f} | " + f"PyTorch Chosen: {py_ref['chosen_logps']:.6f} " + f"(diff: {chosen_diff:.6f})" + ) + print( + f"JAX Rejected: {jax_ref['rejected_logps']:.6f} | " + f"PyTorch Rejected: {py_ref['rejected_logps']:.6f} " + f"(diff: {rejected_diff:.6f})" + ) + print(f"JAX Loss: {jax_ref['loss']:.6f} | PyTorch Loss: {py_ref['loss']:.6f} (diff: {loss_diff:.6f})") + + assert chosen_diff < DPOCorrectnessTestBase.LOG_PROBS_TOLERANCE, f"Chosen logps diff {chosen_diff} exceeds tolerance!" + assert ( + rejected_diff < DPOCorrectnessTestBase.LOG_PROBS_TOLERANCE + ), f"Rejected logps diff {rejected_diff} exceeds tolerance!" + assert loss_diff < DPOCorrectnessTestBase.DPO_LOSS_TOLERANCE, f"Loss diff {loss_diff} exceeds tolerance!" 
+ + dpo_results[name] = jax_ref + + # Write DPO Golden Logits + dpo_output_path = "tests/assets/golden_logits/golden_dpo_correctness.json" + with open(dpo_output_path, "w", encoding="utf-8") as f: + json.dump(dpo_results, f, indent=2) + print(f"\nWrote DPO golden metrics to: {dpo_output_path}") + + # Cleanup hooks + PyTorchSyncTrainingHooks.torch_policy_model = None + PyTorchSyncTrainingHooks.torch_ref_model = None + + # Clean up class environment setup + DPOCorrectnessTestBase.tearDownClass() + + +if __name__ == "__main__": + run_parity_and_generate_golden() diff --git a/tests/post_training/integration/dpo_correctness_base.py b/tests/post_training/integration/dpo_correctness_base.py new file mode 100644 index 0000000000..26cbe10e86 --- /dev/null +++ b/tests/post_training/integration/dpo_correctness_base.py @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities, weight synchronization logic, and base class for DPO +integration correctness tests. +""" + +import os + +# Force JAX/XLA to only create 1 virtual CPU device to prevent batch size scaling +# issues on large multi-core CI runners (which would cause the 20-example test +# dataset to be dropped completely due to drop_remainder=True). 
+os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=1" + +from typing import Any +from absl.testing import parameterized +import jax + +# MaxText / Tunix imports +from maxtext.configs import pyconfig +from maxtext.trainers.post_train.dpo import train_dpo +from maxtext.trainers.post_train.dpo import hooks as dpo_hooks + + +# ============================================================================== +# 1. LOW-LEVEL STATE INTERCEPTION & WEIGHT SYNCHRONIZATION +# ============================================================================== +class InterceptingTrainingHooks(dpo_hooks.DPOTrainingHooks): + """Custom training hooks class to intercept loss and rewards margin during real trainer step execution.""" + + captured_metrics = [] + + def on_train_step_end(self, train_ctx, train_step, train_loss, step_time=0.0): + super().on_train_step_end(train_ctx, train_step, train_loss, step_time) + + prefix = train_ctx.metrics_prefix + accuracy = float(train_ctx.metrics_logger.get_metric_history(prefix, "rewards/accuracy", "train")[-1]) + margin = float(train_ctx.metrics_logger.get_metric_history(prefix, "rewards/margin", "train")[-1]) + chosen_logps = float(train_ctx.metrics_logger.get_metric_history(prefix, "log_probs/chosen", "train")[-1]) + rejected_logps = float(train_ctx.metrics_logger.get_metric_history(prefix, "log_probs/rejected", "train")[-1]) + + InterceptingTrainingHooks.captured_metrics.append( + { + "loss": float(train_loss), + "accuracy": accuracy, + "margin": margin, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + } + ) + + +# ============================================================================== +# 2. 
HIGH-LEVEL MODEL EXECUTION RUNNERS +# ============================================================================== +def run_jax_training(config, test_only_training_hooks_class=InterceptingTrainingHooks): + """Runs JAX DPO training and returns a flat dict of captured step metrics.""" + print("Executing JAX DPO train_dpo.train()...") + train_dpo.train( + config, test_only_training_hooks_class=test_only_training_hooks_class + ) # Perform 3 steps of DPO training. + + captured = test_only_training_hooks_class.captured_metrics + assert len(captured) == 3, f"Expected 3 steps metrics, got {len(captured)}" + + step_1 = captured[0] # We will print the stats for this step, but won't use them for comparison. + step_3 = captured[2] # These are the values we want to compare against PyTorch. + + return { + "loss_step_1": step_1["loss"], + "margin_step_1": step_1["margin"], + "loss": step_3["loss"], + "margin": step_3["margin"], + "chosen_logps": step_3["chosen_logps"], + "rejected_logps": step_3["rejected_logps"], + } + + +# ============================================================================== +# 3. SHARED BASE CLASS DEFINITION +# ============================================================================== +class DPOCorrectnessTestBase(parameterized.TestCase): + """Shared base class establishing environment setup and configuration helpers for DPO parity tests.""" + + MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" + + # These prompt and response strings must match the content of the pre-generated local + # dataset JSON files (e.g., dpo_2_column_dataset.json and dpo_3_column_dataset.json) + # loaded in the integration tests. If these strings are modified, the dataset files + # must be regenerated using the parity generator script. + COMMON_PROMPT = "What is preference optimization?" + COMMON_CHOSEN = "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization." 
+ COMMON_REJECTED = "Database operations to choose preferred options are called preference optimization." + + # Constants for sanity checking & parity. + # These tolerances and platform constraints were validated by running extensive + # sensitivity and hardware studies: + # + # 1. JAX vs PyTorch CPU Parity (Same Host): + # - Loss Difference: stable across seeds, ranging from 0.001 to 0.057. + # - Log Probs Difference: max observed difference ~0.60. + # + # 2. JAX CPU Cross-Platform and Version Drift: + # - Running JAX CPU on different host environments (e.g., local workstation vs remote GCE VM) + # or upgrading JAX versions (0.9.2 -> 0.10.0) introduces minor float32 compiler drift. + # - Max observed cross-platform logprobs noise: ~1.02. + # - Max observed cross-platform loss noise: ~0.113. + # + # 3. JAX CPU vs. TPU Backend Divergence (Enforcing CPU constraint): + # - Running JAX on TPU introduces massive compiler and float32 rounding shifts relative to CPU: + # - Log Probs Shift: ~42.68 (TPU v4) and ~14.83 (TPU v5p) on the exact same inputs. + # - Loss Shift: ~0.039 (TPU v4) and ~0.219 (TPU v5p). + # - Because DPO loss is based on log-ratios, the hardware shift can cancel out on some chips + # (like TPU v4), but diverges significantly on others (e.g., TPU v5p loss shift of 0.219 + # exceeds safe tolerances). + # - Therefore, the test strictly enforces CPU execution to ensure cross-platform reproducibility. + # + # 4. Sensitivity to Semantic Mutations vs. Tolerances: + # - A 1-character dataset mutation ("responses" -> "response") shifts logprobs by 4.73 to 10.84. + # - A minor 1-word mutation ("Aligning" -> "Training") shifts logprobs by 18.02 to 42.86. + # - Since the CPU version/platform noise (~1.02) is far below the smallest semantic mutation (4.73), + # the tolerances are calibrated to maximize robustness to compiler drift while remaining highly + # sensitive to real regressions. 
+ LOG_PROBS_TOLERANCE = 3.0 + DPO_LOSS_TOLERANCE = 0.20 + + @classmethod + def setUpClass(cls): + """Verify that JAX exposes exactly one local device and configure the JAX PRNG defaults.""" + # Assert that the JAX CPU device count flag was successfully respected. + # If JAX was imported/initialized elsewhere before our os.environ["XLA_FLAGS"] + # statement, this assertion will fail, preventing silent test suite failures in CI. + assert jax.local_device_count() == 1, ( + f"Expected exactly 1 local JAX device (CPU), but got {jax.local_device_count()}. " + "This indicates that JAX was initialized before the XLA_FLAGS environment variable " + "could be set in dpo_correctness_base.py." + ) + + # Set JAX to CPU/TPU PRNG and SPMD defaults + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + + def setUp(self): + super().setUp() + if jax.default_backend() != "cpu": + raise RuntimeError( + "DPO correctness tests must run on CPU (please run with environment variable JAX_PLATFORMS=cpu), " + f"but JAX default backend is '{jax.default_backend()}'."
+ ) + InterceptingTrainingHooks.captured_metrics = [] + + def tearDown(self): + super().tearDown() + InterceptingTrainingHooks.captured_metrics = [] + + # ---------------------------------------------------------------------------- + # Private Helpers for test configuration and input generation + # ---------------------------------------------------------------------------- + + def build_tiny_qwen2_jax_config( + self, + max_target_length: int, + temp_dir: str, + init_weights_seed: int, + dataset_filename: str, + data_columns: list[str], + max_prompt_len: int | None = None, + extra_args: list[str] | None = None, + ) -> Any: + """Helper to build a tiny Qwen2 MaxText JAX config object for DPO correctness testing.""" + dataset_path = os.path.abspath(f"tests/assets/local_datasets/dpo/{dataset_filename}") + + # Hermetically resolve the tokenizer from the local pre-packaged assets + assets_root = os.environ.get("MAXTEXT_ASSETS_ROOT", "src/maxtext/assets") + tokenizer_path = os.path.join(assets_root, "tokenizers", "qwen3-tokenizer") + + argv = [ + "src/maxtext/configs/base.yml", + "model_name=qwen2.5-1.5b", + f"tokenizer_path={tokenizer_path}", + "scan_layers=False", + "attention=dot_product", + "per_device_batch_size=1", + f"max_target_length={max_target_length}", + "skip_jax_distributed_system=True", + "enable_nnx=True", + "pure_nnx=True", + "pure_nnx_decoder=False", + "remat_policy=full", + "log_config=0", + # Tiny architecture specifications. + # This JAX model configuration must be kept structurally identical to the PyTorch + # model configuration created in `create_pytorch_config` inside + # `tests/assets/logits_generation/dpo_pytorch_helpers.py` to allow direct parameter + # synchronization and logit comparison. 
+ "base_emb_dim=64", + "head_dim=32", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "override_model_config=True", + # Native input pipeline dataset specifications + "use_dpo=True", + "packing=False", + "dataset_type=hf", + "hf_path=json", + f"hf_train_files={dataset_path}", + "tokenize_train_data=True", + f"train_data_columns={data_columns}", + f"eval_data_columns={data_columns}", + "enable_data_shuffling=False", + "steps=3", + f"base_output_directory={temp_dir}", + # Set rope_interleave=False to match Hugging Face Qwen2's concatenated [x_i, x_(i+d/2)] RoPE layout, + # rather than MaxText's default adjacent [x_(2i), x_(2i+1)] interleaved RoPE layout. + "rope_interleave=False", + # Explicitly hard-code init_weights_seed to ensure model initialization is reproducible and self-contained + f"init_weights_seed={init_weights_seed}", + ] + if max_prompt_len is not None: + argv.append(f"dpo.max_prompt_length={max_prompt_len}") + if extra_args: + argv.extend(extra_args) + return pyconfig.initialize_pydantic(argv) diff --git a/tests/post_training/unit/dpo_trainer_correctness_test.py b/tests/post_training/unit/dpo_trainer_correctness_test.py new file mode 100644 index 0000000000..8263e9b7b6 --- /dev/null +++ b/tests/post_training/unit/dpo_trainer_correctness_test.py @@ -0,0 +1,116 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration test validating DPO training step metrics against stored golden outputs. +This runs JAX-only training and does not require PyTorch at test execution time. +The test runs on CPU to ensure maximum reproducibility and eliminate GPU/TPU floating point differences. + +How to regenerate the golden data: + If the model implementation or training logic changes and you need to regenerate + the golden logits, please follow the instructions in the parity generation script: + tests/assets/logits_generation/generate_dpo_golden_data_and_compare_pytorch_logits.py +""" + +import json +import tempfile +import unittest +import pytest +from absl.testing import parameterized + +# Import shared base class and helper functions from shared correctness base +from tests.post_training.integration.dpo_correctness_base import ( + DPOCorrectnessTestBase, + run_jax_training, +) + +# Force this test to run only on CPU to avoid GPU/TPU floating point differences. +# It will be dynamically bypassed on active accelerator hardware testbeds. 
+pytestmark = [pytest.mark.post_training, pytest.mark.cpu_only] + + +class DPOTRLCorrectnessTest(DPOCorrectnessTestBase): + + @parameterized.named_parameters( + ( + "explicit_prompt_len_3_column", + "explicit_prompt_len_3_column", + 144, + "dpo_3_column_dataset.json", + ["prompt", "chosen", "rejected"], + ), + ( + "default_prompt_len_2_column", + "default_prompt_len_2_column", + None, + "dpo_2_column_dataset.json", + ["chosen", "rejected"], + ), + ) + def test_maxtext_dpo_correctness(self, name, max_prompt_len, dataset_filename, data_columns): + max_target_length = 256 + init_weights_seed = 0 + + # Load golden JAX metrics + golden_path = "tests/assets/golden_logits/golden_dpo_correctness.json" + with open(golden_path, "r", encoding="utf-8") as f: + golden_metrics = json.load(f) + + self.assertIn(name, golden_metrics, msg=f"Scenario {name} not found in golden metrics!") + golden = golden_metrics[name] + + # Configure JAX MaxText Config + with tempfile.TemporaryDirectory() as temp_dir: + config = self.build_tiny_qwen2_jax_config( + max_target_length=max_target_length, + temp_dir=temp_dir, + init_weights_seed=init_weights_seed, + dataset_filename=dataset_filename, + data_columns=data_columns, + max_prompt_len=max_prompt_len, + extra_args=["run_name=dpo_correctness_test"], + ) + + # Run JAX DPO Native Training Loop and get flat metrics + jax_metrics = run_jax_training(config) + + print(f"\n=== Parity Check against Golden Assets for scenario: {name} ===") + for key in ["loss_step_1", "margin_step_1", "loss", "margin", "chosen_logps", "rejected_logps"]: + print(f"Metric: {key:15s} | JAX: {jax_metrics[key]:.6f} | Golden: {golden[key]:.6f}") + + # Verify JAX policy and reference models did mutate and diverge after training steps + self.assertNotEqual( + jax_metrics["margin"], 0.0, msg="JAX policy model did not mutate and diverge after training steps!" + ) + + # Assert parity between JAX run and golden reference within safe thresholds (Option B). 
+ # The tolerances absorb float32 drift between CPU hosts and JAX versions (e.g. a local workstation + # vs. a remote VM); the test itself only runs on CPU, because setUp rejects any other JAX backend. + for key in ["loss_step_1", "margin_step_1", "loss", "margin"]: + self.assertAlmostEqual( + jax_metrics[key], + golden[key], + delta=self.DPO_LOSS_TOLERANCE, + msg=f"Metric {key} diverges from golden: JAX {jax_metrics[key]:.6f} vs Golden {golden[key]:.6f}", + ) + for key in ["chosen_logps", "rejected_logps"]: + self.assertAlmostEqual( + jax_metrics[key], + golden[key], + delta=self.LOG_PROBS_TOLERANCE, + msg=f"Metric {key} diverges from golden: JAX {jax_metrics[key]:.6f} vs Golden {golden[key]:.6f}", + ) + + +if __name__ == "__main__": + unittest.main() From 70dddd6283efd2a9a25229df90b2098eec6080e8 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Fri, 12 Jun 2026 15:43:42 -0700 Subject: [PATCH 2/2] Temp: Disable XLA_FLAGS forced CPU device count and isolate CPU test suite for fast assertion debugging in CI --- .github/workflows/build_and_test_maxtext.yml | 12 +++++++----- .../integration/dpo_correctness_base.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index d1e5248f52..1c33ec4992 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -129,7 +129,7 @@ jobs: maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: false uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -145,7 +145,7 @@ jobs: tpu-tests: name: ${{ matrix.flavor || 'TPU' }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: false uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -160,7 +160,7 @@ jobs: gpu-tests: name: ${{ matrix.flavor || 'GPU' }} tests needs:
[build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: false strategy: fail-fast: false matrix: @@ -180,16 +180,17 @@ jobs: strategy: fail-fast: false matrix: - flavor: [cpu-unit, cpu-post-training-unit] + flavor: [cpu-post-training-unit] with: flavor: ${{ matrix.flavor }} base_image: maxtext-unit-test-tpu:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + pytest_extra_args: 'tests/post_training/unit/dpo_trainer_correctness_test.py' maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: false uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -208,6 +209,7 @@ jobs: maxtext_tpu_pathways_integration_tests: needs: build_and_upload_maxtext_package + if: false uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false diff --git a/tests/post_training/integration/dpo_correctness_base.py b/tests/post_training/integration/dpo_correctness_base.py index 26cbe10e86..f61a3d4b3b 100644 --- a/tests/post_training/integration/dpo_correctness_base.py +++ b/tests/post_training/integration/dpo_correctness_base.py @@ -21,7 +21,7 @@ # Force JAX/XLA to only create 1 virtual CPU device to prevent batch size scaling # issues on large multi-core CI runners (which would cause the 20-example test # dataset to be dropped completely due to drop_remainder=True). -os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=1" +# os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=1" from typing import Any from absl.testing import parameterized