Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion benchmarks/recipes/user_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class UserConfig:
zone: str = "us-east5-b"
device_type: str = "v6e-256"
priority: str = "medium"
base_output_directory: str = None

# Images for env
server_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
Expand Down Expand Up @@ -97,7 +98,7 @@ def __post_init__(self):
self.worker_flags,
)
self.headless_workload_name = f"{self.user[:3]}-headless"
self.base_output_directory = f"gs://{self.user}-{self.region}/{self.user}-"
self.base_output_directory = self.base_output_directory or f"gs://{self.user}-{self.region}/{self.user}-"

device_base_type = self.device_type.split("-", maxsplit=1)[0]
self.models = build_user_models(
Expand All @@ -124,4 +125,7 @@ def __post_init__(self):
selected_model_framework=["pathways"],
selected_model_names=["llama3_1_8b_8192"],
priority="medium",
base_output_directory=None, # GCS Bucket path
# Optional parameters, useful for single controller data loading optimizations
# proxy_flags="--sidecar_name=external",
)
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tensorflow
tiktoken
tokamax
transformers
uvloop
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.18.0
optype>=0.14.0
orbax-checkpoint>=0.11.28
orbax-checkpoint>=0.11.33
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
Expand Down Expand Up @@ -245,6 +245,7 @@ tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.38.0
uvloop>=0.19.0
virtualenv>=20.35.4
wadler-lindig>=0.1.7
websockets>=15.0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.18.0
optype>=0.14.0
orbax-checkpoint>=0.11.28
orbax-checkpoint>=0.11.33
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
Expand Down Expand Up @@ -237,6 +237,7 @@ tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.38.0
uvloop>=0.19.0
virtualenv>=20.35.4
wadler-lindig>=0.1.7
websockets>=15.0.1
Expand Down
14 changes: 14 additions & 0 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def create_orbax_checkpoint_manager(
enable_continuous_checkpointing: bool = False,
max_num_checkpoints_to_keep: int = 10,
checkpoint_storage_concurrent_gb: int = 96,
enable_single_controller: bool = False,
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
Expand Down Expand Up @@ -269,6 +272,17 @@ def create_orbax_checkpoint_manager(
logger=orbax_logger,
)

# Use Colocated Python checkpointing optimization (Single Controller only).
if enable_single_controller and colocated_python_checkpointing:
max_logging.log("Registering colocated python array handler")
checkpointing_impl = ocp.pathways.CheckpointingImpl.from_options(
use_colocated_python=True,
)
ocp.pathways.register_type_handlers(
use_single_replica_array_handler=enable_single_replica_ckpt_restoring,
checkpointing_impl=checkpointing_impl,
)

max_logging.log("Checkpoint manager created!")
return manager

Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ enable_orbax_v1: False
checkpoint_conversion_fn: none
# optional checkpoint context to use for loading. options: "orbax", "safetensors"
source_checkpoint_layout: "orbax"

# Only applicable to Single Controller/Pathways on Cloud. This is an experimental feature that is still under testing.
colocated_python_checkpointing: False
############################### end checkpointing ##################################


Expand Down
66 changes: 49 additions & 17 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ class Checkpointing(BaseModel):
True, description="If True, saves a final checkpoint upon training completion."
)
enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
colocated_python_checkpointing: bool = Field(
False,
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
)


class OrbaxStorage(BaseModel):
Expand Down Expand Up @@ -599,7 +603,8 @@ class MoEGeneral(BaseModel):
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(
True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
True,
description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul.",
)
use_ring_of_experts: bool = Field(
False,
Expand Down Expand Up @@ -1003,7 +1008,8 @@ class GrainDataset(BaseModel):
grain_train_files: PathStr = Field("", description="Path to Grain training files.")
grain_eval_files: PathStr = Field("", description="Path to Grain evaluation files.")
grain_train_mixture_config_path: PathStr = Field(
"", description="Path to a JSON file specifying the mixture weights for Grain training data."
"",
description="Path to a JSON file specifying the mixture weights for Grain training data.",
)
grain_file_type: str = Field("arrayrecord", description="File type for Grain data.")
grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
Expand Down Expand Up @@ -1049,10 +1055,12 @@ class Distillation(BaseModel):
# These dictionaries allow flexible configuration injection for Student/Teacher
# without needing to duplicate the entire MaxText schema here.
student_overrides: dict[str, Any] = Field(
default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})."
default_factory=dict,
description="Overrides specific to the Student model (e.g., {'num_query_heads': 16}).",
)
teacher_overrides: dict[str, Any] = Field(
default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})."
default_factory=dict,
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
)

# --- Loss Params ---
Expand Down Expand Up @@ -1122,16 +1130,22 @@ class Optimizer(BaseModel):
)
learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
lr_schedule_type: LearningRateScheduleType = Field(
LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
LearningRateScheduleType.COSINE,
description="The type of learning rate schedule to use.",
)
learning_rate_final_fraction: float = Field(
0.1, description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules)."
0.1,
description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules).",
)
wsd_decay_steps_fraction: float = Field(
0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
0.1,
ge=0.0,
le=1.0,
description="Fraction of total steps for decay phase in WSD schedule.",
)
wsd_decay_style: WsdDecayStyle = Field(
WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
WsdDecayStyle.LINEAR,
description="The decay style for WSD schedule ('linear' or 'cosine').",
)
warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
learning_rate_schedule_steps: int = Field(
Expand Down Expand Up @@ -1172,10 +1186,12 @@ class Muon(BaseModel):

muon_beta: float = Field(0.95, description="Decay rate for the exponentially weighted average of grads.")
muon_weight_decay: float = Field(
0, description="Strength of the weight decay regularization. This is multiplied with the learning rate."
0,
description="Strength of the weight decay regularization. This is multiplied with the learning rate.",
)
muon_consistent_rms: None | float = Field(
None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2)."
None,
description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).",
)


Expand Down Expand Up @@ -1552,7 +1568,8 @@ class RLHardware(BaseModel):
"than one model replica in rollout.",
)
rollout_tensor_parallelism: int = Field(
-1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
-1,
description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.",
)


Expand All @@ -1567,7 +1584,8 @@ class VLLM(BaseModel):
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
vllm_hf_overrides: dict[str, Any] = Field(
default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
default_factory=dict,
description="Overrides for HuggingFace model config for MaxText model.",
)
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")

Expand Down Expand Up @@ -1646,7 +1664,8 @@ class Engram(BaseModel):
engram_num_heads: int = Field(8, description="Number of heads dedicated to the Engram.")
engram_head_dim: int = Field(1280, description="Head dimension for heads.")
engram_vocab_bases: list[int] = Field(
default_factory=list, description="List of minimum head vocab sizes for each n-gram order."
default_factory=list,
description="List of minimum head vocab sizes for each n-gram order.",
)
engram_max_ngram_size: int = Field(3, description="The max 'n' in N-gram.")
engram_kernel_size: int = Field(4, description="Temporal window size for Engram convolution.")
Expand Down Expand Up @@ -1892,7 +1911,8 @@ class MaxTextConfig(

debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
rl: RL = Field(
default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO)."
default_factory=RL,
description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())

Expand Down Expand Up @@ -1941,7 +1961,11 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
filter(
os.path.exists,
(
os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", os.path.basename(tokenizer_path)),
os.path.join(
MAXTEXT_ASSETS_ROOT,
"tokenizers",
os.path.basename(tokenizer_path),
),
os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", tokenizer_path),
),
),
Expand Down Expand Up @@ -2093,7 +2117,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.global_batch_size_to_eval_on,
self.micro_batch_size_to_eval_on,
) = calculate_global_batch_sizes(
self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
self.eval_per_device_batch_size,
self.expansion_factor_real_data,
self.num_target_devices,
1,
)

# Calculate ramp-up batch size parameters if enabled.
Expand Down Expand Up @@ -2262,6 +2289,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("`local_checkpoint_period` must be > 0 for multi-tier checkpointing.")
if self.multi_tier_checkpointing_backup_interval_minutes <= 0:
raise ValueError("`multi_tier_checkpointing_backup_interval_minutes` must be > 0.")
if self.colocated_python_checkpointing and not self.enable_single_controller:
raise ValueError("`colocated_python_checkpointing` is only supported with `enable_single_controller` set to True.")
if self.enable_emergency_checkpoint:
if not self.local_checkpoint_directory:
raise ValueError("`local_checkpoint_directory` must be set for emergency checkpointing.")
Expand Down Expand Up @@ -2423,7 +2452,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
if self.eval_interval > 0 and not self.grain_eval_files:
raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
if self.tokenizer_type not in (TokenizerType.SENTENCEPIECE, TokenizerType.HUGGINGFACE):
if self.tokenizer_type not in (
TokenizerType.SENTENCEPIECE,
TokenizerType.HUGGINGFACE,
):
raise ValueError(
f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
)
Expand Down
5 changes: 4 additions & 1 deletion src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def create_training_tools(config, model, mesh):
# TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
use_ocdbt = config.checkpoint_storage_use_ocdbt
use_zarr3 = config.checkpoint_storage_use_zarr3
if config.enable_single_controller:
if config.enable_single_controller and not config.colocated_python_checkpointing:
use_ocdbt, use_zarr3 = False, False

checkpoint_dir = ""
Expand All @@ -79,6 +79,9 @@ def create_training_tools(config, model, mesh):
config.enable_continuous_checkpointing,
config.max_num_checkpoints_to_keep,
config.checkpoint_storage_concurrent_gb,
config.enable_single_controller,
config.colocated_python_checkpointing,
config.enable_single_replica_ckpt_restoring,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx
Expand Down
Loading