1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

## 🔥 Latest news 🔥

* \[February 19, 2026\] [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported.
* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use the existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load a V3.1 checkpoint to use the model.
* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.
@@ -1,8 +1,11 @@
#!/bin/bash

# This script validates a pre-converted MaxText checkpoint against its original
# HuggingFace counterpart to ensure numerical correctness.
# This file documents how to get started with Qwen3 Next.

# This file runs Step 1 on CPU.
# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16):
# Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.
# ---
# Example Usage:
#
@@ -17,43 +20,41 @@

set -ex

# --- Configuration & Input Validation ---
export MODEL_NAME='qwen3-next-80b-a3b'
export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'

if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then
echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set."
echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights."
exit 1
fi
# Install torch for checkpoint conversion and forward_pass_logit_checker.py.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Set a default for the HF model path if it's not provided by the user
if [ -z "${HF_MODEL_PATH}" ]; then
export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct"
echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}"
# Ensure HF_TOKEN is set
if [ -z "${HF_TOKEN}" ]; then
echo "Error: HF_TOKEN environment variable is not set. Please export your Hugging Face token."
echo "Example: export HF_TOKEN=hf_..."
exit 1
fi

# Install dependencies required for the logit checker.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# --- Run the Forward Pass Logit Checker ---

echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}"
echo "Against original HF model: ${HF_MODEL_PATH}"

# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
tokenizer_type=huggingface \
tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
megablox=False \
sparse_matmul=False \
load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
model_name=qwen3-next-80b-a3b \
checkpoint_storage_concurrent_gb=1024 \
skip_jax_distributed_system=True \
dtype=float32 \
weight_dtype=float32 \
matmul_precision=highest \
--hf_model_path=${HF_MODEL_PATH} \
--max_kl_div=0.03 \
--run_hf_model=True

echo "Validation complete."
if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers: please point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this script uses internal buckets for testing.
# This bucket stores all the files generated by MaxText during a run.
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH was not set; using default: ${BASE_OUTPUT_PATH}"
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo "Using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"

# 1.1 Convert the checkpoint to `scanned` format, which is more suitable for training.
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=qwen3-next-80b-a3b \
base_output_directory=${BASE_OUTPUT_PATH}/scanned \
hf_access_token=${HF_TOKEN} \
scan_layers=true \
use_multimodal=false

# 1.2 Convert the checkpoint to `unscanned` format, which is more suitable for decoding.
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=qwen3-next-80b-a3b \
base_output_directory=${BASE_OUTPUT_PATH}/unscanned \
hf_access_token=${HF_TOKEN} \
scan_layers=false \
use_multimodal=false

@@ -0,0 +1,65 @@
#!/bin/bash

# This file documents how to get started with Qwen3 Next.

# This file runs Step 2 on v5p-128 on a daily basis.
# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16):
# Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

# The golden logits can be generated by:
# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct --output-path=golden_data_qwen3-next-80b-a3b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16

set -ex

export PYTHONPATH=$PYTHONPATH:$(pwd)/src

export MODEL_NAME='qwen3-next-80b-a3b'
export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'

# Install torch for checkpoint conversion and forward_pass_logit_checker.py.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/maxtext
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers: please point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this script uses internal buckets for testing.
# This bucket stores all the files generated by MaxText during a run.
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH was not set; using default: ${BASE_OUTPUT_PATH}"
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo "Using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"

# Step 2:
# Define the checkpoint paths so they are easy to reuse in the `train.py` and `decode.py` commands:
# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
# Use a hard-coded golden checkpoint rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
SCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/scanned/0/items
UNSCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/unscanned/0/items
# Non-Googlers: please point `DATASET_PATH` to the GCS bucket that holds your training data.
export DATASET_PATH=gs://maxtext-dataset

# Test whether the forward-pass logits match the golden logits.
# Default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl.
GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
gcloud storage cp "${GOLDEN_LOGITS_PATH}" "${GOLDEN_LOGITS_DISK_LOCATION}"
fi

python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} \
  run_name=forward_logits_check \
  load_parameters_path=${SCANNED_CKPT_PATH} \
  scan_layers=true \
  attention=dot_product \
  per_device_batch_size=1 \
  model_name=${MODEL_NAME} \
  max_prefill_predict_length=4 \
  max_target_length=4 \
  async_checkpointing=false \
  sparse_matmul=True \
  ici_fsdp_parallelism=1 \
  ici_expert_parallelism=-1 \
  checkpoint_storage_concurrent_gb=1024 \
  weight_dtype=float32 \
  dtype=float32 \
  activations_in_float32=true \
  matmul_precision=highest \
  float32_logits=true \
  float32_qk_product=true \
  --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} \
  --atol=1.5 \
  --rtol=1.5 \
  --max_kl_div=0.1

# Run pre-training - tokamax_gmm implementation
python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} \
  run_name=tokamax_gmm_pre_training \
  model_name=${MODEL_NAME} \
  tokenizer_type=huggingface \
  tokenizer_path=${TOKENIZER_PATH} \
  dataset_type=synthetic \
  enable_checkpointing=false \
  attention=flash \
  sparse_matmul=True \
  use_tokamax_gmm=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=1 \
  steps=5 \
  max_target_length=1024

# Run fine-tuning - tokamax_gmm implementation
python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} \
  run_name=tokamax_gmm_fine_tuning \
  model_name=${MODEL_NAME} \
  tokenizer_type=huggingface \
  tokenizer_path=${TOKENIZER_PATH} \
  dataset_path=${DATASET_PATH} \
  enable_checkpointing=true \
  async_checkpointing=false \
  load_parameters_path=${SCANNED_CKPT_PATH} \
  scan_layers=True \
  attention=flash \
  sparse_matmul=True \
  use_tokamax_gmm=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=1 \
  steps=5 \
  max_target_length=1024 \
  checkpoint_storage_concurrent_gb=1024


# Run decoding - tokamax_gmm implementation
# Note: decoding requires the Hugging Face access token for the tokenizer, even if the model is not gated.
python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} \
  run_name=decode \
  model_name=${MODEL_NAME} \
  tokenizer_type=huggingface \
  tokenizer_path=${TOKENIZER_PATH} \
  hf_access_token=${HF_TOKEN} \
  load_parameters_path=${UNSCANNED_CKPT_PATH} \
  scan_layers=False \
  attention=dot_product \
  sparse_matmul=True \
  use_tokamax_gmm=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=1 \
  max_prefill_predict_length=64 \
  max_target_length=512 \
  ici_fsdp_parallelism=1 \
  ici_tensor_parallelism=1 \
  ici_expert_parallelism=-1 \
  checkpoint_storage_concurrent_gb=1024 \
  prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
119 changes: 98 additions & 21 deletions tests/end_to_end/tpu/qwen/next/run_qwen3_next.md
@@ -7,6 +7,31 @@ For more details on the architecture, see the [Qwen3 Technical Blog](https://qwe

* * * * *

Pre-Training
---------------------
You can train from scratch to generate a new checkpoint. Below is an example command for pre-training Qwen3-Next on a v5p-64 slice.

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
run_name=q3_next_pre_training \
per_device_batch_size=1 \
enable_checkpointing=false \
model_name=qwen3-next-80b-a3b \
ici_fsdp_parallelism=-1 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
sparse_matmul=False \
dataset_type=synthetic
```

Checkpoint Conversion
---------------------

@@ -22,18 +47,20 @@ To get started, you first need a MaxText-compatible checkpoint.
2. **Convert the Checkpoint**: Run the `convert_qwen3_next_scanned.py` script to convert the downloaded Hugging Face weights into the Orbax format required by MaxText.

```
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_qwen3_next_scanned \
--base_model_path /path/to/qwen3_next_hf_checkpoint \
--maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
--model_size qwen3-next-80b-a3b
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=qwen3-next-80b-a3b \
base_output_directory=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
hf_access_token=${HF_TOKEN} \
scan_layers=true \
use_multimodal=false  # Set scan_layers=false for an unscanned checkpoint
```
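The practical difference between the two formats can be pictured as follows. This is a toy illustration, not MaxText's actual Orbax checkpoint layout: a scanned checkpoint stacks per-layer weights along a leading `[num_layers, ...]` axis so training can iterate over layers with a single `jax.lax.scan`, while an unscanned checkpoint keeps one parameter-tree entry per layer, which decoding can address directly.

```python
import numpy as np

num_layers, d = 4, 8
# Unscanned: one parameter-tree entry per decoder layer.
unscanned = {f"layers_{i}": {"w": np.full((d, d), float(i))} for i in range(num_layers)}
# Scanned: the same weights stacked along a leading [num_layers, ...] axis,
# so a single jax.lax.scan can step through the layers during training.
scanned = {"layers": {"w": np.stack([unscanned[f"layers_{i}"]["w"]
                                     for i in range(num_layers)])}}
print(scanned["layers"]["w"].shape)  # -> (4, 8, 8)
```

The weights are identical either way; only the tree structure differs, which is why the converter exposes it as a single `scan_layers` flag.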

* * * * *

Pre-training and Fine-tuning
Fine-tuning
----------------------------

After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument.
After converting the checkpoint, you can use it for fine-tuning. The command below is an example for fine-tuning on a v5p-64 slice.

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
@@ -43,40 +70,90 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
run_name=qwen3_next_finetuning \
per_device_batch_size=1 \
model_name=qwen3-next-80b-a3b \
steps=500 \
max_target_length=8192 \
ici_fsdp_parallelism=256 \
steps=30 \
max_target_length=4096 \
ici_fsdp_parallelism=-1 \
tokenizer_type=huggingface \
tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer
```

## Decoding
Below is an example command to run decoding with Qwen3-Next on a v5p-64 slice, using the unscanned checkpoint for fast decoding.

```sh
python3 -m maxtext.decode src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
load_parameters_path=${CONVERTED_CHECKPOINT} \
run_name=q3-next-decode \
per_device_batch_size=1 \
enable_checkpointing=false \
model_name=qwen3-next-80b-a3b \
max_prefill_predict_length=64 \
max_target_length=1024 \
tokenizer_type=huggingface \
tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \
attention=dot_product \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
sparse_matmul=False \
ici_tensor_parallelism=1 \
ici_fsdp_parallelism=1 \
ici_expert_parallelism=-1 \
prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " \
scan_layers=False
```

* * * * *

Correctness Validation
----------------------

To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model.
We perform two primary checks:

Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model.
* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts.
* **MMLU Score Validation**: We validate the MMLU score against established benchmarks.
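The logit comparison reduces to checking that the KL divergence between the two models' per-token next-token distributions stays under a small threshold (the `--max_kl_div` flag). A minimal sketch of that check, using synthetic logits rather than real model outputs (the actual checker is `tests.utils.forward_pass_logit_checker`):

```python
import numpy as np

def max_kl_div(golden_logits, test_logits):
    """Max per-token KL(golden || test) for [seq_len, vocab] logit arrays."""
    def log_softmax(x):
        x = x - x.max(axis=-1, keepdims=True)  # stabilize before exponentiating
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    log_p = log_softmax(np.asarray(golden_logits, dtype=np.float64))
    log_q = log_softmax(np.asarray(test_logits, dtype=np.float64))
    kl = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)  # KL per token position
    return float(kl.max())

rng = np.random.default_rng(0)
golden = rng.normal(size=(4, 32))               # e.g. 4 tokens, vocab of 32
noisy = golden + 0.01 * rng.normal(size=(4, 32))
print(max_kl_div(golden, golden))               # -> 0.0 for identical logits
print(max_kl_div(golden, noisy))                # small positive divergence
```

A threshold like `--max_kl_div=0.1` therefore tolerates minor numerical drift (dtype, matmul precision) while still catching real implementation mismatches.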

### Qwen3-Next-80B-A3B

An example command to generate golden logits from HuggingFace for Qwen3-Next:

```sh
python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
--model-id=Qwen/Qwen3-Next-80B-A3B-Instruct \
--output-path=golden_Qwen3_Next.jsonl \
--prompts='I love to;Today is a;What is the'
```
# Set the required path to your converted MaxText checkpoint
export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/

# (Optional) Set the path to your local Hugging Face checkpoint
# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint
You should see logs like the following:

```
...
File is stored locally at golden_Qwen3_Next.jsonl.
```

# Execute the validation script
bash tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
Run the command below to compare logits between HuggingFace and MaxText.

```sh
python3 -m tests.utils.forward_pass_logit_checker \
src/maxtext/configs/base.yml \
tokenizer_type=huggingface \
tokenizer_path=Qwen/Qwen3-Next-80B-A3B-Instruct \
load_parameters_path=${CONVERTED_CHECKPOINT} \
run_name=forward_pass_test_qwen3_next \
per_device_batch_size=1 \
model_name=qwen3-next-80b-a3b \
max_prefill_predict_length=4 \
max_target_length=4 \
scan_layers=false \
sparse_matmul=False \
dtype=float32 \
activations_in_float32=true \
matmul_precision=high \
--max_kl_div=2e-4 \
--golden_logits_path=${PWD}/golden_Qwen3_Next.jsonl
```

To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](../../../benchmarks/api_server/README.md).

## Supported MoE Strategies

This model implementation supports both **Token Dropping** and **Dropless** strategies for Mixture of Experts routing. Take a look at the MaxText [documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/moe_configuration.md) on MoE configs and flags to set based on desired strategy.
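As a toy sketch of the difference between the two strategies (an illustration only, not the MaxText routing code): with token dropping, each expert has a fixed capacity and tokens routed beyond it are simply not processed, while a dropless router processes every routed token.

```python
import numpy as np

def route_top1(router_logits, num_experts, capacity):
    """Toy top-1 MoE routing: each expert keeps at most `capacity` tokens."""
    choices = router_logits.argmax(axis=-1)  # expert chosen for each token
    counts = np.zeros(num_experts, dtype=int)
    kept = []
    for token, expert in enumerate(choices):
        if counts[expert] < capacity:        # room left: token is processed
            counts[expert] += 1
            kept.append(token)
        # else: token is dropped (token-dropping strategy)
    return kept

rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 4))             # 8 tokens routed over 4 experts
dropping_run = route_top1(logits, num_experts=4, capacity=2)
dropless_run = route_top1(logits, num_experts=4, capacity=8)  # capacity >= tokens
print(len(dropless_run))                     # -> 8: dropless keeps every token
```

Dropping trades some quality for a static, evenly-shaped computation per expert; dropless keeps all tokens at the cost of variable per-expert workloads, which is where implementations like grouped matmuls come in.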
