diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index c200aaa296..ced99492a4 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -459,6 +459,7 @@ def _maybe_shard_with_logical(self, inputs, logical_name): mesh=self.mesh, shard_mode=self.config.shard_mode, debug_sharding=self.config.debug_sharding, + extra_stack_level=1, ) def _logical_to_mesh_axes(self, logical_name): diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index e5045c88fd..79f71c78ff 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -133,6 +133,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): mesh=self.mesh, rules=self.config.logical_axis_rules, debug_sharding=self.config.debug_sharding, + extra_stack_level=1, ) def _maybe_shard_with_name(self, inputs, sharding_name): diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 4a00fd093a..6d502d92c4 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -184,16 +184,20 @@ def with_logical_constraint(self, x): mesh=self.mesh, shard_mode=self.config.shard_mode, debug_sharding=self.config.debug_sharding, + extra_stack_level=1, ) def dropout_op(self, x, deterministic): - return self.with_logical_constraint(self.dropout(x, deterministic=deterministic)) + dropout = self.dropout(x, deterministic=deterministic) + return self.with_logical_constraint(dropout) def pre_attention_norm_op(self, x): - return self.with_logical_constraint(self.pre_self_attention_layer_norm(x)) + pre_attention_norm = self.pre_self_attention_layer_norm(x) + return self.with_logical_constraint(pre_attention_norm) def post_attention_norm_op(self, x): - return self.with_logical_constraint(self.post_self_attention_layer_norm(x)) + post_attention_norm = self.post_self_attention_layer_norm(x) + return self.with_logical_constraint(post_attention_norm) def attention_op( self, @@ -332,9 +336,8 @@ def __init__( ) def mlp_op(self, x, deterministic): - return self.with_logical_constraint( - self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) - ) + mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) + return self.with_logical_constraint(mlp) def __call__( self, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 3ad5f38f96..252dadc768 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -133,6 +133,7 @@ def __init__( mesh=self.mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + extra_stack_level=1, ) def __call__( diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index b523de3cc2..b890e2f8b4 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -29,9 +29,11 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils +import inspect # for debugging only +from pathlib import Path _LOGGED_ACTIVATION_SHARDINGS = set() -_LOGGED_LOGICAL_AXES = set() +_ACTIVATION_SHARDINGS_DUMP = [] def get_input_data_sharding(config, mesh): @@ -45,21 +47,62 @@ def get_input_data_sharding(config, mesh): return data_sharding -def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0): +def _get_sharding_desc(inputs, extra_stack_level): + """Get the inputs sharding description using inspect module""" + frame = inspect.currentframe() + # Traverse back extra_stack_level 
times: + for _ in range(1 + extra_stack_level): + if frame is not None: + frame = frame.f_back + if frame is not None: + callers_local_vars = frame.f_locals.items() + + x = [var_name for var_name, var_val in callers_local_vars if var_val is inputs] + if len(x) > 0: + caller_path_full = inspect.stack()[1 + extra_stack_level].filename + # Use pathlib.Path to easily extract just the filename from the full path. + caller_filename = Path(caller_path_full).name + return f"{caller_filename[:-3]}/{x[0]}" + return "Unknown" + + +def maybe_shard_with_name( + inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, sharding_desc="", logical_axes=None +): """ In auto shardmode, this function hints inputs follow given named_sharding. In explicit shardmode, this function enforces inputs following named_sharding. + sharding_desc describes the inputs as seen by the caller (or a frame extra_stack_level levels above it), in the form <module_name>/<variable_name>. + It is used as the key in log and dump files when debug_sharding is true. """ if inputs is None: return None if ( debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding) ): # only print pspec for JitTracer + if not sharding_desc: + sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1) + + if not logical_axes: + logical_axes = "Unknown" + elif isinstance(logical_axes, list): + logical_axes = tuple(logical_axes) + pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) - log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level) + log_key = (sharding_desc, str(jax.typeof(inputs)), tuple(pspec), extra_stack_level) if log_key not in _LOGGED_ACTIVATION_SHARDINGS: - max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level) + max_logging.info(f"{sharding_desc} Logical: {log_key[1]:.<60} {logical_axes}.", stacklevel=3 + extra_stack_level) + max_logging.info(f"{sharding_desc} Physical: {log_key[1]:.<60} {log_key[2]}.", stacklevel=3 + extra_stack_level) _LOGGED_ACTIVATION_SHARDINGS.add(log_key) + + _ACTIVATION_SHARDINGS_DUMP.append( + { + f"{sharding_desc}: {log_key[1]}": { + "logic_axes": f"{logical_axes}", + "PartitionSpec": f"P{log_key[2]}", + } + } + ) if shard_mode == ShardMode.EXPLICIT: return reshard(inputs, named_sharding) else: @@ -67,22 +110,20 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0 + inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" ): """ A wrapper of maybe_shard_with_name when logical axes are inputs + sharding_desc describes the inputs as seen by the caller (or a frame extra_stack_level levels above it), in the form <module_name>/<variable_name>. 
+ It is used as the key in log and dump files when debug_sharding is true. """ if inputs is None: return None - named_sharding = create_sharding(mesh, logical_axes, rules=rules) - - if debug_sharding and isinstance(inputs, Tracer): - log_key = (str(jax.typeof(inputs)), tuple(logical_axes), extra_stack_level) + if debug_sharding and not sharding_desc: + sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1) - if log_key not in _LOGGED_LOGICAL_AXES: - max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level) - _LOGGED_LOGICAL_AXES.add(log_key) + named_sharding = create_sharding(mesh, logical_axes, rules=rules) return maybe_shard_with_name( inputs, @@ -90,6 +131,8 @@ def maybe_shard_with_logical( shard_mode, debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1, + sharding_desc=sharding_desc, + logical_axes=logical_axes, ) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e6a4bdcc19..ece393cb06 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -89,7 +89,10 @@ def vocab_tiling_linen_loss( ) _maybe_shard_with_name = functools.partial( - maybe_shard_with_name, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, ) def _reshape(inputs, out_shape, out_sharding): diff --git a/tests/unit/sharding_desc_test.py b/tests/unit/sharding_desc_test.py new file mode 100755 index 0000000000..93979ddc12 --- /dev/null +++ b/tests/unit/sharding_desc_test.py @@ -0,0 +1,306 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for _get_sharding_desc and maybe_shard_with_name in MaxText sharding module.""" + +import os +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from maxtext.common.common_types import ShardMode +from maxtext.utils.sharding import _get_sharding_desc, maybe_shard_with_name + + +def _make_single_device_mesh(): + """Create a single-device mesh using one CPU device for testing.""" + devices = np.array(jax.devices()[:1]).reshape((1,)) + return Mesh(devices, axis_names=("data",)) + + +class TestGetShardingDesc(unittest.TestCase): + """Tests for the _get_sharding_desc function.""" + + # The filename prefix expected in every description returned by _get_sharding_desc. + expected_filename_prefix = os.path.basename(__file__)[:-3] + "/" + + # The tests below exercise _get_sharding_desc directly, at several stack depths and + # with inputs that are (or are not) bound to a local variable in the calling frame. 
+ def test_direct_call_found(self): + """Tests when inputs matches a local variable in the direct calling frame.""" + + my_test_data = {"key": 123} + result = _get_sharding_desc(my_test_data, extra_stack_level=0) + assert result == self.expected_filename_prefix + "my_test_data" + + def test_direct_call_not_found_literal(self): + """Tests when inputs is a literal, which is not a named local variable.""" + result = _get_sharding_desc({"key": 456}, extra_stack_level=0) + assert result == "Unknown" + + def test_direct_call_not_found_expression(self): + """Tests when inputs is a result of an expression, not a variable.""" + data_a = {"a": 1} + data_b = {"b": 2} + result = _get_sharding_desc(data_a | data_b, extra_stack_level=0) + assert result == "Unknown" + + def test_nested_call_found_at_level_1(self): + """Tests when inputs matches a variable one level up the stack.""" + outer_var = ["a", "b"] + + def inner_func(): + return _get_sharding_desc(outer_var, extra_stack_level=1) + + result = inner_func() + assert result == self.expected_filename_prefix + "outer_var" + + def test_double_nested_call_found_at_level_2(self): + """Tests when inputs matches a variable two levels up the stack.""" + deep_var = 12345 + + def middle_func(): + def inner_func(): + return _get_sharding_desc(deep_var, extra_stack_level=2) + + return inner_func() + + result = middle_func() + assert result == self.expected_filename_prefix + "deep_var" + + def test_too_deep_extra_stack_level(self): + """Tests when extra_stack_level exceeds the actual stack depth.""" + some_inputs = {"c": 3} + result = _get_sharding_desc(some_inputs, extra_stack_level=100) + assert result == "Unknown" + + def test_multiple_matches_returns_first(self): + """Tests that if multiple local vars point to inputs, the first is returned.""" + data = {"test": 1} + data_alias = data + result = _get_sharding_desc(data_alias, extra_stack_level=0) + assert result == self.expected_filename_prefix + "data" + + def test_inputs_is_none_as_variable(self): + """Tests when inputs is None and assigned to a variable.""" + none_val = None + result = _get_sharding_desc(none_val, extra_stack_level=0) + assert result == self.expected_filename_prefix + "none_val" + + def test_inputs_is_none_literal(self): + """Tests when inputs is the None literal.""" + result = _get_sharding_desc(None, extra_stack_level=0) + assert result == "Unknown" + + +class TestMaybeShardWithNameNoneInput(unittest.TestCase): + """Tests for maybe_shard_with_name when inputs is None.""" + + def test_returns_none_auto_mode(self): + """When inputs is None in AUTO mode, should return None immediately.""" + result = maybe_shard_with_name(None, named_sharding=None, shard_mode=ShardMode.AUTO) + self.assertIsNone(result) + + def test_returns_none_explicit_mode(self): + """When inputs is None in EXPLICIT mode, should return None immediately.""" + result = maybe_shard_with_name(None, named_sharding=None, shard_mode=ShardMode.EXPLICIT) + self.assertIsNone(result) + + def test_returns_none_with_debug_sharding_enabled(self): + """When inputs is None with debug_sharding=True, should return None.""" + result = maybe_shard_with_name(None, named_sharding=None, shard_mode=ShardMode.AUTO, debug_sharding=True) + self.assertIsNone(result) + + def test_none_input_does_not_call_with_sharding_constraint(self): + """When inputs is None, should not call jax.lax.with_sharding_constraint.""" + with mock.patch("jax.lax.with_sharding_constraint") as mock_wsc: + maybe_shard_with_name(None, named_sharding=None, 
shard_mode=ShardMode.AUTO) + mock_wsc.assert_not_called() + + def test_none_input_does_not_call_reshard(self): + """When inputs is None, should not call reshard even in EXPLICIT mode.""" + with mock.patch("maxtext.utils.sharding.reshard") as mock_reshard: + maybe_shard_with_name(None, named_sharding=None, shard_mode=ShardMode.EXPLICIT) + mock_reshard.assert_not_called() + + +class TestMaybeShardWithNameAutoMode(unittest.TestCase): + """Tests for maybe_shard_with_name in AUTO shard mode.""" + + def setUp(self): + self.mesh = _make_single_device_mesh() + self.named_sharding = NamedSharding(self.mesh, P()) + self.inputs = jnp.ones((4, 4)) + + @mock.patch("jax.lax.with_sharding_constraint") + def test_auto_mode_calls_with_sharding_constraint(self, mock_wsc): + """AUTO mode should delegate to jax.lax.with_sharding_constraint.""" + mock_wsc.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO) + mock_wsc.assert_called_once_with(self.inputs, self.named_sharding) + + @mock.patch("jax.lax.with_sharding_constraint") + def test_auto_mode_returns_wsc_result(self, mock_wsc): + """AUTO mode should return the value produced by with_sharding_constraint.""" + sentinel = object() + mock_wsc.return_value = sentinel + result = maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO) + self.assertIs(result, sentinel) + + @mock.patch("jax.lax.with_sharding_constraint") + def test_auto_mode_passes_inputs_unchanged(self, mock_wsc): + """AUTO mode should forward the original inputs to with_sharding_constraint.""" + mock_wsc.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO) + call_inputs, _ = mock_wsc.call_args[0] + self.assertIs(call_inputs, self.inputs) + + @mock.patch("jax.lax.with_sharding_constraint") + def test_auto_mode_passes_sharding_unchanged(self, mock_wsc): + """AUTO mode should forward the original named_sharding to with_sharding_constraint.""" + mock_wsc.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO) + _, call_sharding = mock_wsc.call_args[0] + self.assertIs(call_sharding, self.named_sharding) + + @mock.patch("maxtext.utils.sharding.reshard") + @mock.patch("jax.lax.with_sharding_constraint") + def test_auto_mode_does_not_call_reshard(self, mock_wsc, mock_reshard): + """AUTO mode should NOT call reshard.""" + mock_wsc.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO) + mock_reshard.assert_not_called() + + +class TestMaybeShardWithNameExplicitMode(unittest.TestCase): + """Tests for maybe_shard_with_name in EXPLICIT shard mode.""" + + def setUp(self): + self.mesh = _make_single_device_mesh() + self.named_sharding = NamedSharding(self.mesh, P()) + self.inputs = jnp.ones((4, 4)) + + @mock.patch("maxtext.utils.sharding.reshard") + def test_explicit_mode_calls_reshard(self, mock_reshard): + """EXPLICIT mode should delegate to reshard.""" + mock_reshard.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.EXPLICIT) + mock_reshard.assert_called_once_with(self.inputs, self.named_sharding) + + @mock.patch("maxtext.utils.sharding.reshard") + def test_explicit_mode_returns_reshard_result(self, mock_reshard): + """EXPLICIT mode should return the value produced by reshard.""" + sentinel = object() + mock_reshard.return_value = sentinel + result = maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.EXPLICIT) + self.assertIs(result, sentinel) + + 
@mock.patch("maxtext.utils.sharding.reshard") + def test_explicit_mode_passes_inputs_unchanged(self, mock_reshard): + """EXPLICIT mode should forward the original inputs to reshard.""" + mock_reshard.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.EXPLICIT) + call_inputs, _ = mock_reshard.call_args[0] + self.assertIs(call_inputs, self.inputs) + + @mock.patch("maxtext.utils.sharding.reshard") + def test_explicit_mode_passes_sharding_unchanged(self, mock_reshard): + """EXPLICIT mode should forward the original named_sharding to reshard.""" + mock_reshard.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.EXPLICIT) + _, call_sharding = mock_reshard.call_args[0] + self.assertIs(call_sharding, self.named_sharding) + + @mock.patch("maxtext.utils.sharding.reshard") + @mock.patch("jax.lax.with_sharding_constraint") + def test_explicit_mode_does_not_call_with_sharding_constraint(self, mock_wsc, mock_reshard): + """EXPLICIT mode should NOT call with_sharding_constraint.""" + mock_reshard.return_value = self.inputs + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.EXPLICIT) + mock_wsc.assert_not_called() + + +class TestMaybeShardWithNameDebugSharding(unittest.TestCase): + """Tests for the debug_sharding logging behavior of maybe_shard_with_name.""" + + def setUp(self): + self.mesh = _make_single_device_mesh() + self.named_sharding = NamedSharding(self.mesh, P()) + self.inputs = jnp.ones((4, 4)) + # Reset the module-level log-deduplication cache before each test. + # sharding_module._LOGGED_ACTIVATION_SHARDINGS.clear() + + def tearDown(self): + # sharding_module._LOGGED_ACTIVATION_SHARDINGS.clear() + pass + + @mock.patch("jax.lax.with_sharding_constraint") + def test_no_log_when_debug_sharding_false(self, mock_wsc): + """When debug_sharding=False, max_logging.info should never be called.""" + mock_wsc.return_value = self.inputs + with mock.patch("maxtext.utils.sharding.max_logging") as mock_ml: + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO, debug_sharding=False) + mock_ml.info.assert_not_called() + + @mock.patch("jax.lax.with_sharding_constraint") + def test_no_log_for_non_tracer_input(self, mock_wsc): + """When debug_sharding=True but input is a concrete array (not a Tracer), should not log.""" + mock_wsc.return_value = self.inputs + with mock.patch("maxtext.utils.sharding.max_logging") as mock_ml: + # A jnp array outside of jit is a concrete value, not a Tracer. + maybe_shard_with_name(self.inputs, self.named_sharding, ShardMode.AUTO, debug_sharding=True) + mock_ml.info.assert_not_called() + + def test_same_key_logged_only_once(self): + """The same (type, pspec, stack_level) combination should only produce one log entry.""" + with mock.patch("maxtext.utils.sharding.max_logging") as mock_ml: + with mock.patch("jax.lax.with_sharding_constraint", side_effect=lambda x, s: x): + + @jax.jit + def first_fn(x): + return maybe_shard_with_name(x, self.named_sharding, ShardMode.AUTO, debug_sharding=True) + + first_fn(self.inputs) + log_count_after_first = mock_ml.info.call_count + + # A distinct jit function forces re-tracing (Python body runs again), but the + # same log_key is already present in _LOGGED_ACTIVATION_SHARDINGS, so no new log. 
+ @jax.jit + def second_fn(x): + return maybe_shard_with_name(x, self.named_sharding, ShardMode.AUTO, debug_sharding=True) + + second_fn(self.inputs) + self.assertEqual(mock_ml.info.call_count, log_count_after_first) + + def test_different_stack_levels_produce_separate_log_entries(self): + """Different extra_stack_level values create different log keys and each logs once.""" + with mock.patch("maxtext.utils.sharding.max_logging") as mock_ml: + with mock.patch("jax.lax.with_sharding_constraint", side_effect=lambda x, s: x): + + @jax.jit + def traced_fn(x): + maybe_shard_with_name(x, self.named_sharding, ShardMode.AUTO, debug_sharding=True, extra_stack_level=0) + return maybe_shard_with_name(x, self.named_sharding, ShardMode.AUTO, debug_sharding=True, extra_stack_level=1) + + traced_fn(self.inputs) + self.assertEqual(mock_ml.info.call_count, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 8152640da0..886dd62efb 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -72,6 +72,8 @@ def run_single_dump(model_name: str, topology: str, num_slice: str) -> None: f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", "weight_dtype=float32", + "log_config=false", + "debug_sharding=true", ], check=True, ) @@ -101,8 +103,7 @@ def main(argv: Sequence[str]) -> None: json_path_logical = base_path / "logical_shardings.json" if json_path_named.exists() and json_path_logical.exists(): - print(" -> Sharding files already exist. Skipping.") - continue + print(" -> Sharding files already exist. Regenerating to overwrite.") try: run_single_dump(model_name, topology, str(num_slice)) diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index 53559526e2..42b9e3ee41 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -28,12 +28,14 @@ from jax.tree_util import tree_flatten_with_path from MaxText import maxtext_utils from MaxText import pyconfig + from maxtext.utils.globals import MAXTEXT_REPO_ROOT +from maxtext.utils.sharding import _ACTIVATION_SHARDINGS_DUMP + from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config - Transformer = models.Transformer MODEL_NAMES = [ @@ -384,6 +386,20 @@ def partition_specs_to_json(logical_tree, shape_tree) -> dict[str, Any]: return logical_dict +def input_sharding_to_json() -> dict[str, Any]: + input_sharding = {} + input_sharding["Activation Sharding Dump"] = _ACTIVATION_SHARDINGS_DUMP + return input_sharding + + +def save_activation_sharding_dict(output_path: str | Path, sharding_dict: dict) -> None: + """Save the activation sharding dict directly to a JSON file.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(sharding_dict, f, indent=2) + + def save_json(output_path: str | Path, sharding_dict: dict) -> None: """Save dict to a JSON file.""" output_path = Path(output_path) @@ -409,6 +425,7 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.initialize(argv) validate_config(config) + print(f"Sharding debug: {config.debug_sharding}") base_path = Path( f"{MAXTEXT_REPO_ROOT}/tests/utils/sharding_info/{config.model_name}/" @@ -416,6 +433,7 @@ def main(argv: Sequence[str]) -> None: ) json_path_named = base_path / "named_shardings.json" json_path_logical = 
base_path / "logical_shardings.json" + json_path_input = base_path / "input_shardings.json" try: topology_mesh = get_topology_mesh(config) @@ -436,12 +454,16 @@ def main(argv: Sequence[str]) -> None: # Logical: Tree of PartitionSpec (direct from get_shaped_inputs) logical_shardings = partition_specs_to_json(logical_annotations, shaped_train_args[0]) - print(f"Got {len(named_shardings)} Physical entries and" f" {len(logical_shardings)} Logical entries.") + # Input + input_shardings = input_sharding_to_json() + + print(f"Got {len(named_shardings)} Physical entries and {len(logical_shardings)} Logical entries.") # 2. Save New Output (Overwrite) print(f"\nSaving updated shardings to {base_path}...") save_json(json_path_named, named_shardings) save_json(json_path_logical, logical_shardings) + save_json(json_path_input, input_shardings) print(f"Finished: {config.model_name} {config.compile_topology}") diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..2ca5429163 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[192,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[192,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[192,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[192,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[192,2048,16,128]": { + 
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/x: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/inputs: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[192,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..c3bec496eb --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + 
"attention_mla/inputs_kv: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[768,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[768,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[768,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[768,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/x: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 
'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/inputs: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[768,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..6bdd341c12 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 
'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/x: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[96,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..a9b0c8c577 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[384,2048,2048]": { + "logic_axes": 
"('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[384,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[384,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[384,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[384,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[384,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[384,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[384,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[384,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/x: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { 
+ "moe/inputs: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[384,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[384,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[384,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..2ca5429163 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[192,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[192,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[192,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[192,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + 
"PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/x: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/inputs: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[192,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[192,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..c3bec496eb --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + 
"PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attention_mla/q_nope: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[768,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/query: bfloat16[768,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[768,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/key: bfloat16[768,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/value: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/x: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/inputs: bfloat16[768,2048,2048]": { + "logic_axes": 
"('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[768,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[768,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..8409398a06 --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json @@ -0,0 +1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[192,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[192,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/value: bfloat16[192,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[192,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "moe/inputs: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[192,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..37aeba83cc --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json @@ -0,0 
+1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[768,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[768,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[768,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[768,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "moe/inputs: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[768,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..5c25c0f2a5 --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json @@ -0,0 +1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[96,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[96,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[96,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[96,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + 
"attentions/value: bfloat16[96,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[96,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[96,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[96,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "moe/inputs: bfloat16[96,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[96,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..26f3df46f2 --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json @@ -0,0 +1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[384,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[384,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[384,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[384,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[384,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[384,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[384,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[384,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[384,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "moe/inputs: bfloat16[384,2048,2880]": { + "logic_axes": "('activation_batch', 
'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[384,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..8409398a06 --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json @@ -0,0 +1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[192,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[192,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/value: bfloat16[192,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[192,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "moe/inputs: bfloat16[192,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "moe/gate_logits: bfloat16[192,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..37aeba83cc --- /dev/null +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json @@ -0,0 +1,70 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + 
"PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[768,2048,64,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[768,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[768,2048,8,64]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,64,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,8,2048,64]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[768,2048,64,64]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "moe/inputs: bfloat16[768,2048,2880]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "moe/gate_logits: bfloat16[768,2048,32]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..0d5b2d8c24 --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[192,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[192,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[192,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/value: bfloat16[192,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,8,2048,128]": 
{ + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..2146f74797 --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[768,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[768,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[768,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[768,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..4a5224cd6d --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[96,2048,1024]": { + "logic_axes": 
"('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[96,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[96,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/value: bfloat16[96,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[96,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[96,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..6bb047297d --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[384,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[384,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[384,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[384,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[384,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 
'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[384,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[384,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[384,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "linears/x: bfloat16[384,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json new file mode 100644 index 0000000000..0d5b2d8c24 --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[192,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[192,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P('fsdp', None, None)" + } + }, + { + "attentions/query: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/key: bfloat16[192,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/value: bfloat16[192,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/query: bfloat16[192,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/key: bfloat16[192,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[192,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attentions/out: bfloat16[192,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "linears/x: bfloat16[192,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P('fsdp', None, None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json new file mode 100644 index 0000000000..2146f74797 --- /dev/null +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json @@ -0,0 +1,64 @@ +{ + "Activation Sharding Dump": [ + { + "attentions/inputs_q: bfloat16[768,2048,1024]": { + 
"logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/inputs_kv: bfloat16[768,2048,1024]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + }, + { + "attentions/query: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/key: bfloat16[768,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/value: bfloat16[768,2048,8,128]": { + "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/query: bfloat16[768,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/key: bfloat16[768,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[768,8,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "attentions/out: bfloat16[768,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" + } + }, + { + "linears/x: bfloat16[768,2048,3072]": { + "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "PartitionSpec": "P(('data', 'fsdp'), None, None)" + } + } + ] +} \ No newline at end of file