Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile Summary: This PR is a well-motivated refactoring of the JAX GEMM stack: it removes dead code paths (fused GeLU, bias gradient, grad mode) and consolidates the C++ FFI parameters into a `GemmConfig` struct. Key verified findings: the core logic is sound, and previous review issues have been properly addressed. The identified issues are fixable improvements to API documentation and type safety. Confidence Score: 4/5
for more information, see https://pre-commit.ci
/te-ci JAX L1
jberchtold-nvidia left a comment:
Overall LGTM, thanks for this PR! Left some small comments and questions
std::vector<size_t> buffer_shape{1, 1};
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32,
                                                                JAXX_Collective_Op::ALL_GATHER);
[[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(
This wasn't introduced in this PR, but why do we need the `auto _ =`? Would it not work to just call `CollectiveGemmPlanRegistry::getInstance().get_executor(...)` directly?
It will work without capturing the unused return value in modern C++, but I think silently discarding it is not good practice.
transformer_engine::jax::GemmConfig,
::xla::ffi::StructMember<transformer_engine::jax::JAXX_Scaling_Mode>("scaling_mode"),
::xla::ffi::StructMember<transformer_engine::jax::JAXX_Collective_Op>("collective_op"),
::xla::ffi::StructMember<int64_t>("lhs_axis_boundary"),
Good idea: this struct-based approach lets us set default values for newly added fields, right? That makes it easier to stay backwards compatible with older HLO while still having the flexibility to add new attributes, as long as the default value keeps the same behavior as before. I recall doing something similar for attention.
The one that we have in attention is slightly different. There, we don't define a struct that XLA automatically decodes; instead, the attribute is queried dynamically at runtime, which is harder to debug:
TransformerEngine/transformer_engine/jax/csrc/extensions/attention.cpp
Lines 338 to 347 in d40b9de
I think this struct approach is better: it's flexible enough that we don't need to introduce a new API whenever we want to add an attribute, it allows optional attributes (i.e., a struct member with a default value), and it is less bug-prone.
/te-ci JAX L1
Description
This PR removes unused, untested, and partially supported features from the public GEMM primitive: fused GeLU, bias gradient, and grad mode — these were dead code paths not exercised by any JAX-side caller.
The PR also removes all the `FP8_2X_ACC_XGRAD` flags from the `QuantizeConfig`, as they are no longer inferable from the recipes. Users can set the precision via the new env variable `TE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION` instead.

Change details:
- `GemmV2FFI` replaces the old `GemmFFI`; the untested/unused boolean flags (`fuse_bias`, `fuse_gelu`, `grad`) are removed, and the remaining individual attributes are consolidated into a `GemmConfig` struct.
- Bias fusion is now inferred from `bias.size > 0` rather than a separate `fuse_bias` flag.
- The `TE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION` env var is introduced to set the precision of the accumulation in MatMul, i.e., whether to promote to a higher dtype for storing the intermediate accumulation result.
- The `assert_cublas_requirements` checks are moved from lowering to abstract, providing earlier shape validation.
- In `test_distributed_dense.py`, an output sharding constraint is added to ensure the correct sharding pattern for the input gradients in the bprop.

Type of change
Checklist: