
Commit d26eb87

enable DSV3 manual bucketing
1 parent fcc5643 commit d26eb87

File tree: 5 files changed (+38, −12 lines)

torchtitan/experiments/simple_fsdp/backend.py
Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,7 @@
 
 import torch
 import torch._functorch.config as functorch_config
+from functorch.compile import min_cut_rematerialization_partition
 from torchtitan.tools.logging import logger
 
 from .job_config import Compile as CompileConfig
@@ -60,6 +61,7 @@ def aot_eager_autobucketing_reordering_pass(
         backend = aot_autograd_backend(
             fw_compiler=aot_eager_autobucketing_reordering_pass,
             bw_compiler=aot_eager_autobucketing_reordering_pass,
+            partition_fn=min_cut_rematerialization_partition,
             keep_inference_input_mutations=True,
         )
     elif compile_config.backend == "inductor":
@@ -108,6 +110,7 @@ def aot_eager_transformer_block_bucketing_reordering_pass(
         backend = aot_autograd_backend(
             fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
             bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
+            partition_fn=min_cut_rematerialization_partition,
             keep_inference_input_mutations=True,
         )
     elif compile_config.backend == "inductor":

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
Lines changed: 6 additions & 3 deletions

@@ -30,17 +30,20 @@ def get_transformer_block_buckets(model) -> list[list[str] | str]:
         [model.norm, model.output],
     ]
     for layer_id, transformer_block in model.layers.items():
-        # [TODO](ruisizhang123) add EP support for transformer block bucketing
         module_list.append(transformer_block)
 
     def convert_modules_to_fqns(modules, module_to_fqn_mapping):
         """Convert a (possibly nested) list of modules to FQN strings."""
         result = []
         for m in modules:
             if isinstance(m, list):
-                result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
+                # check if fqn_list is valid. In PP, bucketed module may
+                # not be in the current rank, and fqn_list is None.
+                if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
+                    result.append(fqn_list)
             else:
-                result.append(module_to_fqn_mapping.get(m, None))
+                if fqn := module_to_fqn_mapping.get(m):
+                    result.append(fqn)
         return result
 
     module_to_name = {m: n for n, m in model.named_modules()}
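The effect of the walrus-guarded appends can be seen in isolation. Below is a self-contained sketch of the same `convert_modules_to_fqns` pattern (it applies equally to the llama3 copy further down): nested module lists become nested FQN lists, and modules absent from the mapping, e.g. layers owned by another pipeline-parallel rank, are dropped instead of surfacing as `None` entries. The toy `Sequential` model is illustrative only.

```python
import torch.nn as nn

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
    """Convert a (possibly nested) list of modules to FQN strings."""
    result = []
    for m in modules:
        if isinstance(m, list):
            # In PP, a bucketed module may not live on the current rank;
            # then the recursive call returns an empty (falsy) list.
            if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                result.append(fqn_list)
        else:
            if fqn := module_to_fqn_mapping.get(m):
                result.append(fqn)
    return result

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
module_to_name = {m: n for n, m in model.named_modules()}
other_rank_layer = nn.Linear(4, 4)  # not part of `model`

print(convert_modules_to_fqns(
    [[model[0], model[1]], other_rank_layer], module_to_name
))
# [['0', '1']] -- the foreign module is filtered out, no None entries
```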

torchtitan/experiments/simple_fsdp/llama3/parallelize.py
Lines changed: 2 additions & 0 deletions

@@ -51,6 +51,8 @@ def convert_modules_to_fqns(modules, module_to_fqn_mapping):
         result = []
         for m in modules:
             if isinstance(m, list):
+                # check if fqn_list is valid. In PP, bucketed module may
+                # not be in the current rank, and fqn_list is None.
                 if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                     result.append(fqn_list)
             else:

torchtitan/experiments/simple_fsdp/reshard_after_forward.py
Lines changed: 27 additions & 7 deletions

@@ -65,20 +65,40 @@ def force_recompute_node(node):
         # nearby ac_graph_id values
         node.meta["ac_graph_id"] = 100000
 
+    def tag_view_recompute_nodes(start_node, seen=None):
+        """
+        Recursively walk only *single-user* paths from start_node,
+        tagging view nodes for recompute.
+        """
+        if seen is None:
+            seen = set()
+        if start_node in seen:
+            return
+        seen.add(start_node)
+
+        # If current node has multiple users, stop descending
+        if len(start_node.users) != 1:
+            return
+
+        user = next(iter(start_node.users))
+
+        # Only continue if this user is a view op
+        if user.op == "call_function" and user.target.is_view:
+            force_recompute_node(user)
+            # Recurse deeper only on this one user
+            tag_view_recompute_nodes(user, seen)
+
     # Make all-gather nodes (and related nodes) recomputable, to circumvent
     # https://github.com/pytorch/pytorch/issues/136433
     for node in graph.nodes:
         if is_wait_tensor_from_fsdp(node):
             ag_node = node.args[0]
             force_recompute_node(ag_node)  # all_gather
             force_recompute_node(node)  # wait_tensor
-            # Force-recompute slice that comes after wait
-            for user in node.users:
-                if (
-                    user.op == "call_function"
-                    and user.target == torch.ops.aten.slice.Tensor
-                ):
-                    force_recompute_node(user)
+
+            # Recursively tag view ops from all_gather_wait
+            tag_view_recompute_nodes(node)
+
             # Force-recompute potential dtype casts from all_gather
             if (
                 ag_node.all_input_nodes[0].op == "call_function"

torchtitan/models/deepseek_v3/__init__.py
Lines changed: 0 additions & 2 deletions

@@ -97,8 +97,6 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        attn_type="flex",
-        attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
