diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index 7fc9d13bf4..0244b182b6 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -8,6 +8,7 @@
 import torch
 import torch._functorch.config as functorch_config
+from functorch.compile import min_cut_rematerialization_partition
 
 from torchtitan.tools.logging import logger
 
 from .job_config import Compile as CompileConfig
@@ -60,6 +61,7 @@ def aot_eager_autobucketing_reordering_pass(
         backend = aot_autograd_backend(
             fw_compiler=aot_eager_autobucketing_reordering_pass,
             bw_compiler=aot_eager_autobucketing_reordering_pass,
+            partition_fn=min_cut_rematerialization_partition,
             keep_inference_input_mutations=True,
         )
     elif compile_config.backend == "inductor":
@@ -108,6 +110,7 @@ def aot_eager_transformer_block_bucketing_reordering_pass(
         backend = aot_autograd_backend(
             fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
             bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
+            partition_fn=min_cut_rematerialization_partition,
             keep_inference_input_mutations=True,
         )
     elif compile_config.backend == "inductor":
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
index 83e24d7dc1..98101bdf12 100644
--- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -30,7 +30,6 @@ def get_transformer_block_buckets(model) -> list[list[str] | str]:
         [model.norm, model.output],
     ]
     for layer_id, transformer_block in model.layers.items():
-        # [TODO](ruisizhang123) add EP support for transformer block bucketing
         module_list.append(transformer_block)
 
     def convert_modules_to_fqns(modules, module_to_fqn_mapping):
@@ -38,9 +37,13 @@ def convert_modules_to_fqns(modules, module_to_fqn_mapping):
         result = []
         for m in modules:
             if isinstance(m, list):
-                result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
+                # Check that fqn_list is valid. With PP, a bucketed module may
+                # not be on the current rank, leaving fqn_list empty.
+                if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
+                    result.append(fqn_list)
             else:
-                result.append(module_to_fqn_mapping.get(m, None))
+                if fqn := module_to_fqn_mapping.get(m):
+                    result.append(fqn)
         return result
 
     module_to_name = {m: n for n, m in model.named_modules()}
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index 484d3d4747..41d4c88a4a 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -51,6 +51,8 @@ def convert_modules_to_fqns(modules, module_to_fqn_mapping):
         result = []
         for m in modules:
             if isinstance(m, list):
+                # Check that fqn_list is valid. With PP, a bucketed module may
+                # not be on the current rank, leaving fqn_list empty.
                 if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                     result.append(fqn_list)
                 else:
diff --git a/torchtitan/experiments/simple_fsdp/reshard_after_forward.py b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py
index dac010bfcd..a6bb24c9a9 100644
--- a/torchtitan/experiments/simple_fsdp/reshard_after_forward.py
+++ b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py
@@ -65,6 +65,29 @@ def force_recompute_node(node):
         # nearby ac_graph_id values
         node.meta["ac_graph_id"] = 100000
 
+    def tag_view_recompute_nodes(start_node, seen=None):
+        """
+        Recursively walk only single-user paths from start_node,
+        tagging view nodes for recompute.
+        """
+        if seen is None:
+            seen = set()
+        if start_node in seen:
+            return
+        seen.add(start_node)
+
+        # If the current node has multiple users, stop descending
+        if len(start_node.users) != 1:
+            return
+
+        user = next(iter(start_node.users))
+
+        # Only continue if this user is a view op
+        if user.op == "call_function" and user.target.is_view:
+            force_recompute_node(user)
+            # Recurse deeper only on this one user
+            tag_view_recompute_nodes(user, seen)
+
     # Make all-gather nodes (and related nodes) recomputable, to circumvent
     # https://github.com/pytorch/pytorch/issues/136433
     for node in graph.nodes:
@@ -72,13 +95,10 @@ def force_recompute_node(node):
             ag_node = node.args[0]
             force_recompute_node(ag_node)  # all_gather
             force_recompute_node(node)  # wait_tensor
-            # Force-recompute slice that comes after wait
-            for user in node.users:
-                if (
-                    user.op == "call_function"
-                    and user.target == torch.ops.aten.slice.Tensor
-                ):
-                    force_recompute_node(user)
+
+            # Recursively tag view ops that follow the all_gather wait_tensor
+            tag_view_recompute_nodes(node)
+
             # Force-recompute potential dtype casts from all_gather
             if (
                 ag_node.all_input_nodes[0].op == "call_function"
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 7e2d35a5d9..fdc414346d 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -97,8 +97,6 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        attn_type="flex",
-        attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
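
Note (not part of the patch): a minimal, self-contained sketch of how the min-cut partitioner gets wired into an aot_autograd backend, mirroring the backend.py change above. `noop_pass` is a hypothetical stand-in for the bucketing/reordering passes; the real passes rewrite the FX graph before returning it.

    # Sketch only; "noop_pass" is a hypothetical stand-in for the
    # bucketing/reordering passes registered in backend.py.
    import torch
    from functorch.compile import min_cut_rematerialization_partition
    from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


    def noop_pass(gm: torch.fx.GraphModule, example_inputs):
        # A real pass would bucket/reorder collectives here; returning the
        # GraphModule unchanged still yields a working backend.
        return gm


    backend = aot_autograd_backend(
        fw_compiler=noop_pass,
        bw_compiler=noop_pass,
        # The min-cut partitioner consults node.meta["recompute"] tags (set by
        # force_recompute_node in reshard_after_forward.py), so tagged
        # all-gathers and their downstream view ops are recomputed in the
        # backward pass instead of being saved as activations.
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )

    compiled = torch.compile(torch.nn.Linear(8, 8), backend=backend)
    compiled(torch.randn(2, 8)).sum().backward()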