Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
c7c1a76
Implemented the kernel with split dbias
Oleg-Goncharov Feb 11, 2026
7abbc7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
f820b21
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2026
0c05632
Relaxed constraints on the last dimension
Oleg-Goncharov Feb 13, 2026
4a85dea
Added notes on group tensor restrictions into documentation
Oleg-Goncharov Feb 13, 2026
aedd53d
Fixes per the review
Oleg-Goncharov Feb 27, 2026
38288b1
Fixed pointer
Oleg-Goncharov Feb 27, 2026
ce3a137
More fixes
Oleg-Goncharov Feb 27, 2026
bddd804
Fixed kernel grid size
Oleg-Goncharov Mar 2, 2026
a894d1a
Merge branch 'main' into pr_split_dbias
Oleg-Goncharov Mar 2, 2026
87352bd
Enabled persistency with WorkID Query feature
Oleg-Goncharov Mar 4, 2026
e23f553
Added a struct with tunable parameters
Oleg-Goncharov Mar 4, 2026
d185299
Added persistency with static scheduling
Oleg-Goncharov Mar 4, 2026
5e15f57
Fixed test cases
Oleg-Goncharov Mar 4, 2026
98e9558
Ready for benchmarking
Oleg-Goncharov Mar 4, 2026
ab816cb
Fixed out-of-boundary error
Oleg-Goncharov Mar 4, 2026
8a429ad
Tuned kernel parameters
Oleg-Goncharov Mar 4, 2026
ab3f911
Refactoring
Oleg-Goncharov Mar 4, 2026
92720ac
Refactoring 2
Oleg-Goncharov Mar 4, 2026
46d9811
Refactoring 3
Oleg-Goncharov Mar 4, 2026
7172400
Removed the dynamic (WorkID Query) persistency
Oleg-Goncharov Mar 5, 2026
4344627
Ready for PR
Oleg-Goncharov Mar 5, 2026
ede33b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
219e925
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 5, 2026
325181b
Fixes per the review
Oleg-Goncharov Mar 6, 2026
04609b1
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 6, 2026
5815335
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
0bd837c
Added the test suite
Oleg-Goncharov Mar 6, 2026
0c5849c
Initial kernel draft
Oleg-Goncharov Mar 6, 2026
178a7c4
Refactoring
Oleg-Goncharov Mar 6, 2026
b035b43
Added the kernel to the quantization dispatcher
Oleg-Goncharov Mar 6, 2026
9d72757
Isolated only the Group Quantize NVFP4 for compilation
Oleg-Goncharov Mar 6, 2026
da8da89
Fixed test suite and bug in scaling factors padding
Oleg-Goncharov Mar 6, 2026
1ab4adb
Uncommented unit tests
Oleg-Goncharov Mar 9, 2026
a439454
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
c7fba6c
Merge branch 'main' into pr_persistent_grouped_nvfp4_kernel
Oleg-Goncharov Mar 9, 2026
f7a00ce
Conditionally print the detailed unit tests summary
Oleg-Goncharov Mar 9, 2026
1e926d5
Fix
Oleg-Goncharov Mar 9, 2026
47be9b2
Cache scales base offsets
Oleg-Goncharov Mar 9, 2026
fef9220
Fixes per the review
Oleg-Goncharov Mar 9, 2026
9e37b4c
Fixed NVFP4 numerics
Oleg-Goncharov Mar 9, 2026
6a7409d
Fixed the number of launch threads per block
Oleg-Goncharov Mar 9, 2026
1f1ab92
Numerics fix 2
Oleg-Goncharov Mar 9, 2026
6c5cc7f
Added Quantize Configs to grouped Qauntization
Oleg-Goncharov Mar 9, 2026
f5e2ba0
Uncommented code
Oleg-Goncharov Mar 9, 2026
ab45c1c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
273b358
Fix of logic
Oleg-Goncharov Mar 9, 2026
eace4a6
Test suite fix
Oleg-Goncharov Mar 9, 2026
47b350c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_executable(test_operator
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_nvfp4_transpose_grouped.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
Expand Down
162 changes: 88 additions & 74 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu

Large diffs are not rendered by default.

16 changes: 6 additions & 10 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,12 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
}

// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
constexpr float fp8_max = 448.0f; // 448.0f;
constexpr float fp4_max = 6.0f; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return the max normalized value
const float max_norm_clamp = use_fast_math
? Numeric_Traits<bf16>::maxNorm
: Numeric_Traits<float>::maxNorm;

global_encode_scale = fminf(global_encode_scale, max_norm_clamp);
global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
Expand All @@ -84,7 +80,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const bool use_fast_math) {

// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);

constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
Expand Down Expand Up @@ -163,7 +159,7 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
std::vector<std::vector<fp8e4m3>>& math_scales,
const bool use_fast_math) {

const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
Expand Down Expand Up @@ -214,7 +210,7 @@ void quantize_nvfp4_2d(float (*OP)(const float),
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);

const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
Expand Down Expand Up @@ -738,7 +734,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16),
::testing::Values(false)),
::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
Expand Down
Loading
Loading