2 changes: 1 addition & 1 deletion src/MaxText/layers/engram.py
@@ -29,7 +29,7 @@
 from jax.sharding import Mesh
 from flax import nnx
 
-from MaxText.tokenizer import HFTokenizer
+from maxtext.input_pipeline.tokenizer import HFTokenizer
 from MaxText.common_types import MODEL_MODE_TRAIN, Array, Config
 from MaxText.layers.embeddings import Embed
 from MaxText.layers.initializers import nd_dense_init, NdInitializer
3 changes: 1 addition & 2 deletions src/MaxText/rl/train_rl.py
@@ -77,7 +77,7 @@
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from MaxText.rl.evaluate_rl import evaluate
 from MaxText.rl import utils_rl
-from MaxText.input_pipeline.instruction_data_processing import load_template_from_file
+from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
 from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
 
 
@@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
 max_logging.log("Creating policy model with same config as reference model on trainer mesh")
 actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
 
-
 if trainer_config.debug.rl:
 max_logging.log("Policy Model initialized successfully")
 nnx.display(actor_model)
4 changes: 2 additions & 2 deletions src/maxtext/common/checkpointing.py
@@ -23,8 +23,8 @@
 from flax.training import train_state
 import jax
 from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
-from MaxText.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
-from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
+from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
+from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
 from maxtext.utils import exceptions
 from maxtext.utils import max_logging
 import numpy as np
7 changes: 3 additions & 4 deletions src/maxtext/examples/sft_train_and_evaluate.py
@@ -85,13 +85,12 @@
 
 from flax import nnx
 
-from MaxText.globals import MAXTEXT_REPO_ROOT
 from MaxText import pyconfig
-from MaxText.input_pipeline import instruction_data_processing
+from MaxText.globals import MAXTEXT_REPO_ROOT
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
+from maxtext.input_pipeline import instruction_data_processing
 from maxtext.trainers.post_train.sft import train_sft
-from maxtext.utils import max_logging
-from maxtext.utils import max_utils
+from maxtext.utils import max_logging, max_utils
 
 # Suppress vLLM logging with a severity level below ERROR
 os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
10 changes: 5 additions & 5 deletions src/maxtext/experimental/rl/grpo_input_pipeline.py
@@ -32,8 +32,8 @@
 
 import grain.python as grain
 
-from MaxText.input_pipeline import input_pipeline_interface
-from MaxText.input_pipeline import _input_pipeline_utils
+from maxtext.input_pipeline import input_pipeline_interface
+from maxtext.input_pipeline import input_pipeline_utils
 
 
 class SingleHostDataLoader:
@@ -141,7 +141,7 @@ def preprocessing_pipeline(
 )
 
 dataset = dataset.map(
-_input_pipeline_utils.tokenization,
+input_pipeline_utils.tokenization,
 batched=True,
 fn_kwargs={
 "hf_tokenizer": tokenizer,
@@ -151,7 +151,7 @@
 },
 )
 dataset = dataset.select_columns(data_column_names)
-dataset = _input_pipeline_utils.HFDataSource(
+dataset = input_pipeline_utils.HFDataSource(
 dataset,
 dataloading_host_index,
 dataloading_host_count,
@@ -166,7 +166,7 @@ def lists2array(x):
 
 operations = [
 grain.MapOperation(lists2array),
-_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
+input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
 grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
 ]
 
9 changes: 3 additions & 6 deletions src/maxtext/inference/inference_microbenchmark.py
@@ -22,13 +22,10 @@
 from absl import app
 from collections.abc import MutableMapping
 
-from MaxText import maxengine
-from MaxText import prefill_packing
-from MaxText import pyconfig
+from MaxText import maxengine, pyconfig
 from maxtext.common import profiler
-from maxtext.utils import gcs_utils
-from maxtext.utils import max_utils
-from maxtext.utils import maxtext_utils
+from maxtext.input_pipeline.packing import prefill_packing
+from maxtext.utils import gcs_utils, max_utils, maxtext_utils
 
 import warnings
 
4 changes: 2 additions & 2 deletions src/maxtext/inference/mlperf/offline_inference.py
@@ -35,8 +35,8 @@
 # pylint: disable=no-name-in-module
 from MaxText.maxengine import MaxEngine
 from MaxText.maxengine import set_engine_vars_from_base_engine
-from MaxText.prefill_packing import PrefillProcessor
-from MaxText.prefill_packing import BatchedPrefillProcessor
+from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
+from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
 
 DecodeState = Any
 Params = Any
2 changes: 1 addition & 1 deletion src/maxtext/inference/offline_engine.py
@@ -54,7 +54,7 @@
 from jax.experimental import mesh_utils
 
 from MaxText.maxengine import MaxEngine
-from MaxText.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
+from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 
(file path not captured)
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
(file path not captured)
@@ -23,7 +23,7 @@
 
 from dataclasses import dataclass, field
 
-from MaxText.input_pipeline import _input_pipeline_utils
+from maxtext.input_pipeline import input_pipeline_utils
 from maxtext.utils import max_logging
 
 
@@ -83,7 +83,7 @@ def process_dataset(config, dataset): # pylint: disable=redefined-outer-name
 assert any(
 set(data_column_names) == set(supported) for supported in supported_columns
 ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}"
-assert _input_pipeline_utils.is_conversational(
+assert input_pipeline_utils.is_conversational(
 dataset.features, data_column_names
 ), "Dataset is not in conversational format."
 
(file path not captured)
@@ -26,10 +26,10 @@
 from grain.experimental import BestFitPackIterDataset, pick_performance_config
 import grain.python as grain
 
-from MaxText.input_pipeline import _input_pipeline_utils
-from MaxText.input_pipeline import _grain_tokenizer
-from MaxText import multihost_dataloading
-from MaxText import tokenizer
+from maxtext.input_pipeline import input_pipeline_utils
+from maxtext.input_pipeline import grain_tokenizer
+from maxtext.input_pipeline import multihost_dataloading
+from maxtext.input_pipeline import tokenizer
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_logging
 
@@ -199,10 +199,10 @@ def pretrain_preprocessing_pipeline(
 ):
 """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
 if config.grain_file_type == "arrayrecord":
-dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
-dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
+dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
 else:
-dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))
+dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
 
 assert len(data_columns) == 1
 text_column = data_columns[0]
@@ -224,13 +224,13 @@
 
 if tokenize:
 if config.use_truncation:
-dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model))
+dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model))
 else:
-dataset = dataset.apply(_grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model))
+dataset = dataset.apply(grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model))
 
 data_columns = ("inputs", "targets")
 rekey_dict = {col: text_column for col in data_columns}
-dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
+dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
 
 # Pack and Batch examples.
 batch_size = config.global_batch_size_to_load // jax.process_count()
@@ -273,15 +273,15 @@ def pretrain_preprocessing_pipeline(
 "targets_position": "targets_positions",
 "inputs_position": "inputs_positions",
 }
-dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
+dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
 else:
-dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
+dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
 batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
 dataset = dataset.batch(batch_size, batch_fn=batch_fn)
 
 # Shift inputs for teacher-forced training
 dataset = dataset.map(
-_input_pipeline_utils.ShiftData(
+input_pipeline_utils.ShiftData(
 ignored_ids=[pad_id],
 axis=1,
 )
@@ -313,8 +313,8 @@ def dpo_preprocessing_pipeline(
 ):
 """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
 if config.grain_file_type == "arrayrecord":
-dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
-dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
+dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
 tokenizer_model = tokenizer.build_tokenizer(
 config.tokenizer_path,
 config.tokenizer_type,
@@ -331,9 +331,9 @@
 pad_id = -1
 
 if tokenize:
-dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
+dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
 
-dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
+dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
 batch_size = config.global_batch_size_to_load // jax.process_count()
 batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
 dataset = dataset.batch(batch_size, batch_fn=batch_fn)
(file path not captured)
@@ -20,7 +20,7 @@
 from typing import Any
 import grain.python as grain
 import numpy as np
-from MaxText import tokenizer
+from maxtext.input_pipeline import tokenizer
 
 
 @dataclasses.dataclass