[CUDA] fp16/32 x int8 quantized matmul #3137
Conversation
That is fantastic! I think we should merge this (I'll take a closer look and approve a bit later). We should also measure on an A100 and H100, where I think the lack of pipelining is going to hurt us a bit, but I am very keen to start having some proper QMMs on CUDA!
I have run the benchmark on A100 and H100 and the performance is quite bad there; it seems that the low memory bandwidth of the DGX had hidden a lot of problems. Also updated the DGX numbers after fixing an out-of-bounds write bug.
(Benchmark charts: DGX Spark, A100, and H100, each with float16 and float32 activations.)
@zcbenz do you want to hold off until we have an implementation that does shared memory pipelining, or shall we merge this as a first step?
I'm working on an sm80-optimized implementation, let's hold off for a while.
Refs #2536, #3128.
This PR implements QMM for float16/float32 activations and int8 quantization. The kernel is optimized for small M (batch size) and large N/K; it works with arbitrary M but requires N and K to be aligned to the tile size.
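For context, here is a minimal sketch (not from this PR) of how this path is exercised through the Python API, assuming the usual mx.quantize/mx.quantized_matmul entry points; the shapes are illustrative:

```python
import mlx.core as mx

# Illustrative shapes: small M (batch size), large N/K aligned to the tile size.
M, K, N = 4, 4096, 4096

x = mx.random.normal((M, K)).astype(mx.float16)  # float16 activations
w = mx.random.normal((N, K)).astype(mx.float16)  # weights to be quantized

# Affine int8 quantization with group size 64 (the configuration benchmarked below).
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)

# Quantized matmul; on CUDA this is the path the new kernel covers.
y = mx.quantized_matmul(x, w_q, scales, biases,
                        transpose=True, group_size=64, bits=8)
mx.eval(y)  # force evaluation; y has shape (M, N)
```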
The kernel uses the mma.sync.aligned.m16n8k16 tensor op, so the GEMM TILE_SIZE_M is set to 16 and many threads are wasted for small batch sizes, but the performance is still close to ideal: 2x for FP16xINT8 and 4x for FP32xINT8. The kernel is written in CuTe, which I'm still learning; the code follows CuTe's coding style and I have turned off code formatting for it, since it would otherwise be harder to read and maintain.
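As a back-of-the-envelope check on those "ideal" factors (my reasoning, not taken from the PR): at small M the kernel is bandwidth bound and traffic is dominated by reading the N x K weight matrix, so the best-case speedup over an unquantized GEMM is roughly the ratio of weight bytes moved; scale/bias traffic at group size 64 adds only a few percent and is ignored here.

```python
# Rough model: time ~ bytes of W read / memory bandwidth, so the ideal speedup of
# int8 weights (1 byte/element) over fp16 (2 bytes) or fp32 (4 bytes) weights is
# simply the byte ratio. Scales/biases add only a small overhead at group size 64.
N, K = 4096, 4096
int8_bytes = N * K * 1

for elem_bytes, name in [(2, "FP16xINT8 vs FP16 GEMM"), (4, "FP32xINT8 vs FP32 GEMM")]:
    ideal = (N * K * elem_bytes) / int8_bytes
    print(f"{name}: ideal speedup ~ {ideal:.0f}x")  # prints 2x and 4x
```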
Note that this kernel only works well for group sizes 32 and 64 for now; it performs quite badly for group size 128 (0.5x of cuBLAS) and I haven't found the root cause yet.
Performance numbers profiled on a DGX Spark:
(Benchmark charts: activation float16 and float32, bits: 8, group_size: 64.)
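A hedged sketch of the kind of timing loop that produces numbers like these (this is not the author's benchmark script; shapes and iteration counts are made up):

```python
import time
import mlx.core as mx

def bench_qmm(M, N, K, dtype=mx.float16, group_size=64, bits=8, iters=100):
    """Time mx.quantized_matmul for one shape. Illustrative only."""
    x = mx.random.normal((M, K)).astype(dtype)
    w = mx.random.normal((N, K)).astype(dtype)
    w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

    # Warm-up so compilation/allocation does not end up in the measurement.
    mx.eval(mx.quantized_matmul(x, w_q, scales, biases,
                                transpose=True, group_size=group_size, bits=bits))

    start = time.perf_counter()
    for _ in range(iters):
        y = mx.quantized_matmul(x, w_q, scales, biases,
                                transpose=True, group_size=group_size, bits=bits)
        mx.eval(y)  # evaluate each iteration so the work is actually done
    return (time.perf_counter() - start) / iters

for M in (1, 2, 4, 8, 16):
    print(f"M={M}: {bench_qmm(M, 4096, 4096) * 1e6:.1f} us")
```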
An independent C++ file profiling the kernel can be found below:
(Details: collapsed standalone profiling source, not reproduced here.)
It is not hard to extend the kernel to support more types, and I'll add support for bfloat16 activations and sub-byte integer quants in later PRs.
There are also many optimization opportunities, such as the shared memory pipelining discussed in the comments above.
Also, since the features implemented in this PR are limited, I did not enable the qmm tests. If you actually run the TestQuantized.test_qmm test with this PR you will notice that many tests are flaky: that is not because this kernel outputs wrong results (y_q), but because the expected results (y_hat) are sometimes 0. I don't think this is caused by this PR and I will investigate later.
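To make the y_q / y_hat distinction concrete, here is a minimal comparison of the kind such a test performs (assumed shapes and tolerances; this is not the actual test code): if y_hat comes out as 0, a correct y_q still fails the check.

```python
import mlx.core as mx

M, K, N = 8, 512, 512
group_size, bits = 64, 8

x = mx.random.normal((M, K)).astype(mx.float16)
w = mx.random.normal((N, K)).astype(mx.float16)
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

# y_q: output of the quantized kernel under test.
y_q = mx.quantized_matmul(x, w_q, scales, biases,
                          transpose=True, group_size=group_size, bits=bits)

# y_hat: reference built by dequantizing the weights and doing a regular matmul.
w_hat = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
y_hat = x @ w_hat.T

# If y_hat were unexpectedly all zeros (the flakiness described above), this
# comparison would fail even though y_q itself is fine. Tolerances are illustrative.
print(mx.allclose(y_q, y_hat, rtol=1e-2, atol=1e-1).item())
```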