Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .github/workflows/build_and_test_maxtext.yml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ jobs:

maxtext_jupyter_notebooks:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_notebooks == 'true'
if: false
uses: ./.github/workflows/run_jupyter_notebooks.yml
strategy:
fail-fast: false
Expand All @@ -145,7 +145,7 @@ jobs:
tpu-tests:
name: ${{ matrix.flavor || 'TPU' }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: false
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
Expand All @@ -160,7 +160,7 @@ jobs:
gpu-tests:
name: ${{ matrix.flavor || 'GPU' }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: false
strategy:
fail-fast: false
matrix:
Expand All @@ -180,16 +180,17 @@ jobs:
strategy:
fail-fast: false
matrix:
flavor: [cpu-unit, cpu-post-training-unit]
flavor: [cpu-post-training-unit]
with:
flavor: ${{ matrix.flavor }}
base_image: maxtext-unit-test-tpu:py312
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
pytest_extra_args: 'tests/post_training/unit/dpo_trainer_correctness_test.py'

maxtext_tpu_pathways_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: false
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -208,6 +209,7 @@ jobs:

maxtext_tpu_pathways_integration_tests:
needs: build_and_upload_maxtext_package
if: false
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@ gha-creds-*.json

# vscode workspace
maxtext.code-workspace
maxtext_output/
12 changes: 8 additions & 4 deletions src/maxtext/trainers/post_train/dpo/train_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_tunix_config(mt_config: MaxTextConfig) -> DPOTrainingConfig:
)


def setup_trainer_state(mt_config, goodput_recorder=None):
def setup_trainer_state(mt_config, goodput_recorder=None, test_only_training_hooks_class=None):
"""Set up prerequisites for training loop."""
tunix_config = get_tunix_config(mt_config)

Expand Down Expand Up @@ -143,7 +143,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
ref_model = nnx.clone(model) if mt_config.dpo.algo == "dpo" else None

with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
training_hooks = hooks.DPOTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
if test_only_training_hooks_class is None:
test_only_training_hooks_class = hooks.DPOTrainingHooks

training_hooks = test_only_training_hooks_class(mt_config, mesh, learning_rate_schedule, goodput_recorder)
data_hooks = hooks.DPODataHooks(mt_config, mesh, goodput_recorder)

# Provide rules context so logical axes (e.g. 'norm') are translated to mesh axes during maybe_restore
Expand All @@ -164,14 +167,15 @@ def train_model(mt_config: MaxTextConfig, trainer, mesh):
return trainer


def train(mt_config, goodput_recorder=None):
def train(mt_config, goodput_recorder=None, test_only_training_hooks_class=None):
"""Main method for DPO training.

Args:
mt_config: MaxText config.
goodput_recorder: An optional GoodputRecorder to record performance metrics.
test_only_training_hooks_class: An optional DPOTrainingHooks subclass to override hooks.
"""
trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
trainer, mesh = setup_trainer_state(mt_config, goodput_recorder, test_only_training_hooks_class)
_job_completed_gracefully = False
try:
trainer = train_model(mt_config, trainer, mesh)
Expand Down
18 changes: 18 additions & 0 deletions tests/assets/golden_logits/golden_dpo_correctness.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"explicit_prompt_len_3_column": {
"loss_step_1": 0.6931471824645996,
"margin_step_1": 0.0,
"loss": 0.6573728919029236,
"margin": 0.0728759765625,
"chosen_logps": -854.4208984375,
"rejected_logps": -530.0099487304688
},
"default_prompt_len_2_column": {
"loss_step_1": 0.6931471824645996,
"margin_step_1": 0.0,
"loss": 0.6954630613327026,
"margin": -0.004626465495675802,
"chosen_logps": -875.7848510742188,
"rejected_logps": -502.052978515625
}
}
90 changes: 90 additions & 0 deletions tests/assets/local_datasets/dpo/dpo_2_column_dataset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
[
{
"chosen": "\n\nHuman: How does gradient descent work?\n\nAssistant: Updating parameters in the opposite direction of the gradient is how gradient descent works.",
"rejected": "\n\nHuman: How does gradient descent work?\n\nAssistant: Climbing up hills to find local maxima is how gradient descent works."
},
{
"chosen": "\n\nHuman: What is a neural network?\n\nAssistant: Interconnected nodes that learn representations from data form a neural network.",
"rejected": "\n\nHuman: What is a neural network?\n\nAssistant: Storing tabular data in a designed schema is the purpose of a neural network."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
},
{
"chosen": "\n\nHuman: What is preference optimization?\n\nAssistant: Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "\n\nHuman: What is preference optimization?\n\nAssistant: Database operations to choose preferred options are called preference optimization."
}
]
112 changes: 112 additions & 0 deletions tests/assets/local_datasets/dpo/dpo_3_column_dataset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
[
{
"prompt": "How does gradient descent work?",
"chosen": "Updating parameters in the opposite direction of the gradient is how gradient descent works.",
"rejected": "Climbing up hills to find local maxima is how gradient descent works."
},
{
"prompt": "What is a neural network?",
"chosen": "Interconnected nodes that learn representations from data form a neural network.",
"rejected": "Storing tabular data in a designed schema is the purpose of a neural network."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
},
{
"prompt": "What is preference optimization?",
"chosen": "Aligning LLMs using pairs of chosen and rejected responses is called preference optimization.",
"rejected": "Database operations to choose preferred options are called preference optimization."
}
]
Loading
Loading