Commit b4ad546

added cuBLASMp backend option to JAX unit tests for CollectiveGEMM
Signed-off-by: Alp Dener <[email protected]>
1 parent 69cf235 commit b4ad546

7 files changed: 68 additions & 59 deletions


examples/jax/collective_gemm/common.py

Lines changed: 7 additions & 0 deletions
@@ -154,6 +154,7 @@ def _initialize_distributed(args):
         num_devices_per_process=devices_per_process,
         process_id=args.process_id,
         tensor_parallel_size=args.tensor_parallel_size,
+        use_cublasmp=args.use_cublasmp,
     )
 
 
@@ -241,5 +242,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
     parser.add_argument(
         "--enable-result-check", action="store_true", default=True, help="Enable result checking"
     )
+    parser.add_argument(
+        "--use-cublasmp",
+        action="store_true",
+        default=False,
+        help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
+    )
 
     return parser
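
For reference, a minimal self-contained sketch of how the new store_true flag behaves once parsed. It reproduces only the argparse pattern added above and does not import the TE example module, so the parser object here is illustrative.

import argparse

# Standalone illustration of the flag added to cgemm_parser() above.
parser = argparse.ArgumentParser(description="Collective GEMM test (illustrative)")
parser.add_argument(
    "--use-cublasmp",
    action="store_true",
    default=False,
    help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
)

args = parser.parse_args(["--use-cublasmp"])
assert args.use_cublasmp is True                      # flag present -> True
assert parser.parse_args([]).use_cublasmp is False    # flag omitted -> default False
# _initialize_distributed() then forwards args.use_cublasmp to the bootstrap call,
# as shown in the first hunk of this file.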

examples/jax/collective_gemm/run_test_cgemm.sh

Lines changed: 46 additions & 38 deletions
@@ -65,50 +65,58 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
     # Clear PIDs array for this test file
     PIDS=()
 
-    for i in $(seq 0 $(($NUM_GPUS - 1))); do
-        # Define output file for logs
-        LOG_FILE="${TEST_FILE}_gpu_${i}.log"
-
-        if [ $i -eq 0 ]; then
-            # For process 0: show live output AND save to log file using tee
-            echo "=== Live output from process 0 ==="
-            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-                -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
-                "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-                --num-processes=$NUM_GPUS \
-                --process-id=$i 2>&1 | tee "$LOG_FILE" &
-            PID=$!
-            PIDS+=($PID)
-        else
-            # For other processes: redirect to log files only
-            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-                -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-                --num-processes=$NUM_GPUS \
-                --process-id=$i > "$LOG_FILE" 2>&1 &
-            PID=$!
-            PIDS+=($PID)
-        fi
+    PYTEST_ARGS=(
+        "-vs"
+        "-c $TE_PATH/tests/jax/pytest.ini"
+        "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE"
+        "--num-processes=$NUM_GPUS"
+    )
+
+    BACKENDS=("cublasmp" "userbuffers")
+    for backend in "${BACKENDS[@]}"; do
+        for i in $(seq 0 $(($NUM_GPUS - 1))); do
+            # Define output file for logs
+            LOG_FILE="${TEST_FILE}_gpu_${i}_${backend}.log"
+
+            if [ $i -eq 0 ]; then
+                # For process 0: show live output AND save to log file using tee
+                echo "=== Live output from process 0 with ${backend} ==="
+                pytest --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
+                    "${PYTEST_ARGS[@]}" \
+                    --process-id=$i 2>&1 | tee "$LOG_FILE" &
+                PID=$!
+                PIDS+=($PID)
+            else
+                # For other processes: redirect to log files only
+                pytest "${PYTEST_ARGS[@]}" \
+                    --process-id=$i > "$LOG_FILE" 2>&1 &
+                PID=$!
+                PIDS+=($PID)
+            fi
+        done
     done
 
     # Wait for all processes to finish
     wait
 
     # Check and print the log content from process 0 (now has log file thanks to tee)
-    if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE SKIPPED"
-    elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE FAILED"
-        HAS_FAILURE=1
-    elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE PASSED"
-    else
-        echo "... $TEST_FILE INVALID"
-        HAS_FAILURE=1
-    fi
-
-    # Remove the log files after processing them
-    wait
-    rm ${TEST_FILE}_gpu_*.log
+    for backend in "${BACKENDS[@]}"; do
+        if grep -q "SKIPPED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE SKIPPED for ${backend} backend"
+        elif grep -q "FAILED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE FAILED for ${backend} backend"
+            HAS_FAILURE=1
+        elif grep -q "PASSED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE PASSED for ${backend} backend"
+        else
+            echo "... $TEST_FILE INVALID for ${backend} backend"
+            HAS_FAILURE=1
+        fi
+
+        # Remove the log files after processing them
+        wait
+        rm ${TEST_FILE}_gpu_*.log
+    done
 done
 
 wait

transformer_engine/common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ if (NVTE_WITH_CUBLASMP)
   target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
   target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
   find_library(CUBLASMP_LIB
-               NAMES cublasmp libcublasmp
+               NAMES cublasmp libcublasmp.so.0
               PATHS ${CUBLASMP_DIR}
               PATH_SUFFIXES lib
               REQUIRED)

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
 
 CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size,
                                  int num_comm_sm, bool is_p2p, bool atomic_gemm) {
+  NVTE_CHECK(nvte_built_with_cublasmp(),
+             "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1.");
   _with_cublasmp = true;
 
   nvte_comm_gemm_ctx_create(reinterpret_cast<ncclComm_t>(nccl_comm_ptr), tp_size, tp_rank);

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 #define NVTE_COMM_OVERLAP_MAX_STREAMS 3
 
-/* \brief Check if TE is built with cuBlasMp.
+/* \brief Check if TE is built with cuBLASMp.
  *
  * \return True if TE is built with cuBlasMp.
  */

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 8 additions & 18 deletions
@@ -70,6 +70,7 @@
 
 
 num_cublas_streams = get_num_compute_streams()
+collective_gemm_with_cublasmp = False
 
 
 def get_cublas_workspace_size_bytes() -> None:
@@ -198,6 +199,7 @@ def collective_gemm_bootstrap(
     num_sm_for_communication=2,
     use_ce=True,
     aggregate_all_gather=False,
+    use_cublasmp=False,
 ):
     """Initialize NCCL communicators for Collective GEMM operations.
@@ -281,6 +283,8 @@ def collective_gemm_bootstrap(
         f" num_devices_per_process={num_devices_per_process}"
     )
     assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}"
+    global collective_gemm_with_cublasmp
+    collective_gemm_with_cublasmp = use_cublasmp
     initialize_cgemm_communicator(
         num_total_devices,
         num_devices_per_process,
@@ -292,6 +296,7 @@ def collective_gemm_bootstrap(
         num_sm_for_communication,
         use_ce,
         aggregate_all_gather,
+        use_cublasmp,
     )
 
 
@@ -386,7 +391,7 @@ class GemmPrimitive(BasePrimitive):
 
     name = "te_gemm_ffi"
    multiple_results = True
-    impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)
+    impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
     inner_primitive = None
     outer_primitive = None
 
@@ -411,7 +416,6 @@ def abstract(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
     ):
         del use_split_accumulator, transpose_batch_sequence
 
@@ -539,7 +543,7 @@ def _dims_are_consecutive(dims):
         if scaling_mode.is_nvfp4_scaling:
             workspace_size += lhs_scale_inv.size + rhs_scale_inv.size
         if not collective_op.is_none:
-            if use_cublasmp:
+            if collective_gemm_with_cublasmp:
                 # cuBlasMp manages its own cuBlasLt workspaces per stream
                 workspace_size = 0
             else:
@@ -578,7 +582,6 @@ def lowering(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
     ):
         del out_dtype, transpose_batch_sequence, sequence_dim, is_outer
 
@@ -623,7 +626,7 @@ def lowering(
             "grad": grad,
             "use_split_accumulator": use_split_accumulator,
             "collective_op": int(collective_op.value),
-            "use_cublasmp": use_cublasmp,
+            "use_cublasmp": collective_gemm_with_cublasmp,
         }
 
         operand_output_aliases = {}
@@ -658,7 +661,6 @@ def impl(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
     ):
         if scaling_mode.is_1d_block_scaling():
             lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
@@ -726,7 +728,6 @@ def impl(
             transpose_batch_sequence=transpose_batch_sequence,
             sequence_dim=sequence_dim,
             is_outer=is_outer,
-            use_cublasmp=use_cublasmp,
         )
         # Alter output blocks for CGEMM AG
         if (
@@ -778,7 +779,6 @@ def outer_impl(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
     ):
         return GemmPrimitive.impl(
             lhs,
@@ -800,7 +800,6 @@ def outer_impl(
             sequence_dim,
             is_outer,
             collective_op,
-            use_cublasmp,
         )
 
     @staticmethod
@@ -818,7 +817,6 @@ def batcher(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
     ):
         del transpose_batch_sequence, sequence_dim, is_outer
         assert GemmPrimitive.outer_primitive is not None
@@ -852,7 +850,6 @@ def batcher(
                 transpose_batch_sequence=transpose_batch_sequence,
                 sequence_dim=sequence_dim,
                 is_outer=is_outer,
-                use_cublasmp=use_cublasmp,
             ),
             (out_bdims, bias_bdims, pre_gelu_bdims),
         )
@@ -1015,7 +1012,6 @@ def infer_sharding_from_operands(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
         mesh,
         arg_infos,
         result_infos,
@@ -1027,7 +1023,6 @@ def infer_sharding_from_operands(
             result_infos,
             is_outer,
             sequence_dim,
-            use_cublasmp,
         )
 
         (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
@@ -1062,7 +1057,6 @@ def partition(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
         mesh,
         arg_infos,
         result_infos,
@@ -1141,7 +1135,6 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alph
                 sequence_dim=inferred_sequence_dim,
                 is_outer=False,
                 collective_op=collective_op,
-                use_cublasmp=use_cublasmp,
             )
 
             if reduce_spec is not None:
@@ -1173,7 +1166,6 @@ def shardy_sharding_rule(
         sequence_dim,
         is_outer,
         collective_op,
-        use_cublasmp,
         mesh,
         operand_types,
         result_types,
@@ -1268,7 +1260,6 @@ def _te_gemm(
     use_split_accumulator: bool = None,
     transpose_batch_sequence: bool = False,
     collective_op: CollectiveOp = CollectiveOp.NONE,
-    use_cublasmp: bool = False,
 ) -> Tuple[jax.Array, ...]:
 
     if grad or fuse_gelu:
@@ -1372,7 +1363,6 @@ def _te_gemm(
         sequence_dim=-1, # Dummy value and will be set in the primitive
         is_outer=True,
         collective_op=collective_op,
-        use_cublasmp=use_cublasmp,
     )
 
 
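Taken together, these gemm.py hunks stop threading use_cublasmp through GemmPrimitive as a static argument: the choice is recorded once by collective_gemm_bootstrap() in the module-level collective_gemm_with_cublasmp flag and read wherever the backend matters (workspace sizing, FFI lowering). Below is a standalone sketch of that pattern; the function names and bodies are simplified illustrations, not the actual TE implementation.

# Illustrative sketch only: mirrors the "set once at bootstrap, read later" pattern.
collective_gemm_with_cublasmp = False


def bootstrap_sketch(use_cublasmp=False):
    """Record the backend choice once, as collective_gemm_bootstrap() now does."""
    global collective_gemm_with_cublasmp
    collective_gemm_with_cublasmp = use_cublasmp


def workspace_size_sketch(base_workspace_size):
    """Mimic the abstract() rule above: cuBLASMp manages its own cuBLASLt workspaces."""
    return 0 if collective_gemm_with_cublasmp else base_workspace_size


bootstrap_sketch(use_cublasmp=True)
print(workspace_size_sketch(32 * 1024 * 1024))  # prints 0 on the cuBLASMp path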

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 3 additions & 1 deletion
@@ -516,7 +516,9 @@ class CommOverlapHelper : torch::CustomClassHolder {
   void ub_barrier(ExtComm comm);
 
   int64_t get_nccl_comm_ptr(std::string comm_name) {
-    NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
+    NVTE_CHECK(backend_is_nccl,
+               "Comm+GEMM overlap with cuBLASMp backend requires a tensor-parallel process ",
+               "group with NCCL backend.");
     return reinterpret_cast<c10d::ProcessGroupNCCL *>(pgs[comm_name])->getCommPtr();
   }
 };
