Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_executable(test_operator
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_dequantize_mxfp8_grouped.cu
test_transpose.cu
test_cast_transpose.cu
test_cast_transpose_current_scaling.cu
Expand Down
487 changes: 487 additions & 0 deletions tests/cpp/operator/test_dequantize_mxfp8_grouped.cu

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,82 @@ def test_group_quantize_cudagraph_capturable(self) -> None:
assert torch.equal(static_output.data, expected.data)
assert torch.equal(static_output.scale_inv, expected.scale_inv)

@pytest.mark.parametrize(
    "shape",
    [[(512, 1024), (512, 1024)], [(256, 512), (512, 512), (768, 512)]],
)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_group_dequantize(self, shape: List[Tuple[int, int]]) -> None:
    """Round-trip BF16 -> MXFP8 -> BF16 through the grouped quantize/dequantize path."""
    num_tensors = len(shape)

    # Build one contiguous BF16 buffer out of per-group random tensors.
    per_group = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
    grouped_input = torch.cat(per_group, dim=0)

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    quantizer.set_usage(rowwise=True, columnwise=False)
    first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device="cuda")

    # MXFP8 quantize, then grouped dequantize back to BF16.
    quantized = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims)
    dequantized = tex.group_dequantize(quantized, tex.DType.kBFloat16)

    # Output metadata must mirror the quantized grouped tensor.
    assert dequantized.num_tensors == num_tensors
    assert dequantized.logical_shape == quantized.logical_shape
    assert torch.equal(dequantized.first_dims, quantized.first_dims)
    assert torch.equal(dequantized.tensor_offsets, quantized.tensor_offsets)

    # Values should match the original inputs up to MXFP8 rounding error.
    dequantized_bf16 = dequantized.data.reshape(grouped_input.shape)
    torch.testing.assert_close(dequantized_bf16, grouped_input, atol=0.125, rtol=0.1)

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_group_dequantize_cudagraph_capturable(self) -> None:
    """Ensure group_dequantize is CUDA graph capturable."""
    num_tensors = 2
    shapes = [(512, 1024)] * num_tensors
    grouped_input = torch.cat(
        [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes],
        dim=0,
    )

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    quantizer.set_usage(rowwise=True, columnwise=False)
    # All groups share the same first dimension here (512 each).
    first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device="cuda")

    # Quantize once to obtain the MXFP8 grouped tensor used for capture.
    quantized = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims)

    # Warmup pass outside the graph so any lazy allocation happens up front.
    torch.cuda.synchronize()
    _ = tex.group_dequantize(quantized, tex.DType.kBFloat16)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = tex.group_dequantize(quantized, tex.DType.kBFloat16)

    # Swap fresh data into the captured buffers, then replay the graph.
    fresh_input = torch.cat(
        [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes],
        dim=0,
    )
    fresh_quantized = tex.group_quantize(fresh_input, quantizer, num_tensors, first_dims)
    quantized.data.copy_(fresh_quantized.data)
    quantized.scale_inv.copy_(fresh_quantized.scale_inv)

    graph.replay()
    torch.cuda.synchronize()

    # The replayed output must match an eager dequantize of the same data.
    expected = tex.group_dequantize(quantized, tex.DType.kBFloat16)
    assert torch.equal(static_output.data, expected.data)

def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str
stream);
}

void nvte_group_dequantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                           cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dequantize);
  using namespace transformer_engine;
  // Validate both opaque grouped-tensor handles, then hand off to the dispatcher.
  auto *in = convertNVTEGroupedTensorCheck(input);
  auto *out = convertNVTEGroupedTensorCheck(output);
  dispatch::group_dequantize_helper(*in, out, stream);
}

void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_configs,
const size_t num_tensors, cudaStream_t stream) {
Expand Down
21 changes: 21 additions & 0 deletions transformer_engine/common/cast/dispatch/dequantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../mxfp8/group_dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"

namespace transformer_engine {
Expand Down Expand Up @@ -50,6 +51,26 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
}
}

inline void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *output,
                                    cudaStream_t stream) {
  // Sanity-check both grouped tensors before dispatching any kernel.
  CheckInputGroupedTensor(input, "group_dequantize_input");
  CheckOutputGroupedTensor(*output, "group_dequantize_output");

  if (input.scaling_mode == NVTE_MXFP8_1D_SCALING) {
    // Grouped MXFP8 dequantize requires compute capability 10.0+ per is_supported_by_CC_100.
    if (!is_supported_by_CC_100()) {
      NVTE_ERROR("MXFP8 Grouped Dequantization is NOT supported by architectures < 10.0");
    }
    mxfp8::group_dequantize(&input, output, stream);
  } else {
    NVTE_ERROR("Grouped dequantize not implemented for scaling mode: " +
               to_string(input.scaling_mode) + ".");
  }
}

} // namespace dispatch
} // namespace transformer_engine

Expand Down
Loading
Loading