Skip to content

Uddeshsingh/q4k fused kernels#20231

Open
uddeshsingh wants to merge 3 commits into
pytorch:mainfrom
uddeshsingh:uddeshsingh/q4k-fused-kernels
Open

Uddeshsingh/q4k fused kernels#20231
uddeshsingh wants to merge 3 commits into
pytorch:mainfrom
uddeshsingh:uddeshsingh/q4k-fused-kernels

Conversation

@uddeshsingh

Copy link
Copy Markdown

Fixes #20172

Summary

  • Add fused Q4_K Metal kernels (linear mat-vec/mat-mat, embedding gather) reading raw GGUF bytes
  • Guard legacy MLX-native repack path behind ET_MLX_EMIT_DIRECT_GGUF=0

Test plan

  • python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run
  • python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run

Replace the export-time GGUF-to-MLX qparam repack path with fused Metal
kernels
Keep the legacy MLX-native repack path available when the env var is set to 0,
per maintainer request on pytorch#20172.
Copilot AI review requested due to automatic review settings June 12, 2026 05:38
@pytorch-bot

pytorch-bot Bot commented Jun 12, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20231

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 13 Awaiting Approval

As of commit 1f7be36 with merge base 7282106 (image):

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Jun 12, 2026

Copy link
Copy Markdown

Hi @uddeshsingh!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 12, 2026

Copy link
Copy Markdown

CLA Signed
The committers listed above are authorized under a signed CLA.

  • ✅ login: uddeshsingh / name: Uddesh Singh (1f7be36)
  • ✅ login: uddeshsingh / name: uddeshsingh (49ac1d2, eda40b8)

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds fused Q4_K Metal kernel support (linear + embedding) to the MLX GGUF lowering path, with an environment-variable switch to fall back to the legacy MLX-native repack implementation.

Changes:

  • Extend GGUF linear/embedding tests to cover both Q6_K and Q4_K.
  • Implement fused Q4_K Metal kernels for linear (mat-vec + mat-mat + dynamic IfNode) and embedding gather.
  • Add ET_MLX_EMIT_DIRECT_GGUF-controlled dispatch between fused-kernel and legacy repack paths.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
backends/mlx/custom_kernel_ops/gguf/test/test_linear.py Updates linear tests to include Q4_K configs and adjusts reference path assumptions.
backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py Updates embedding tests to generate Q4_K blobs and run additional Q4_K cases.
backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py New helper to repack raw Q4_K GGUF blobs into MLX qparams for the legacy path.
backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py New legacy Q4_K lowering via MLX native quantized matmul using repacked qparams.
backends/mlx/custom_kernel_ops/gguf/q4k/linear.py Replaces prior approach with fused Metal mat-vec/mat-mat kernels + dynamic dispatch.
backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py New legacy Q4_K lowering via MLX native quantized gather using repacked qparams.
backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py New fused Metal gather kernel reading raw Q4_K GGUF bytes directly.
backends/mlx/custom_kernel_ops/gguf/q4k/common.py Adds shared Q4_K constants + shared Metal header (block layout + dequant helpers).
backends/mlx/custom_kernel_ops/gguf/q4k/init.py Adds emit_direct_gguf() env-var gate and updated package documentation.
backends/mlx/custom_kernel_ops/gguf/patterns.py Dispatches Q4_K lowering between fused kernels and legacy repack path based on env var.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +154 to +156
constexpr short NL = 16; // Q4_K: QK_K / 32
constexpr short NL0 = NK / 16; // = 2 — dequant iterations per thread for weight
constexpr short NL1 = NK / 8; // = 4 — load iterations per thread for activation
Comment on lines +133 to +138
Both Q6_K and Q4_K kernels dequantize the raw GGUF blob in-kernel; use the
gguf-exact dequant as the reference oracle.
"""
lin = model.linear
weight = lin.weight
if getattr(weight, "ggml_type", None) == "q4_k":
# Q4_K is repacked into bf16 MLX affine qparams (S, Q, B); reconstruct
# exactly what the kernel dequantizes so the oracle isolates kernel
# accumulation (repack precision vs gguf is covered by test_gguf.py).
from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams

intx = weight.to_intx_unpacked_to_int8_tensor()
gs = int(intx.block_size[-1])
Q, B = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4)
qb = Q.view(torch.uint8)
nibbles = torch.stack([(qb & 0xF).float(), ((qb >> 4) & 0xF).float()], dim=-1)
q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1)
scale = intx.scale.float().repeat_interleave(gs, dim=1)
bias_b = B.float().repeat_interleave(gs, dim=1)
w = scale * q_unsigned + bias_b
else:
w = weight.dequantize(torch.float32)
w = weight.dequantize(torch.float32)
K: int,
out: Slot,
) -> None:
in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype)
output_names=["out"],
output_shapes_flat=out_shape_flat,
output_shape_lengths=[len(out_shape_flat)],
output_dtypes=[in_dtype_int],
out_shape_flat = leading + [IntOrVid.from_literal(K)]

# threadgroup.x must divide grid.x (= K, a multiple of 256).
tg_x = 256 if K % 256 == 0 else K
Comment on lines +45 to +50
const uint j = thread_position_in_grid.x; // 0..K-1
const uint r = thread_position_in_grid.y; // gathered row
const int row = (int) indices[r];
const int nb = K / QK_K;
device const block_q4_K * blk =
((device const block_q4_K *) weight) + (uint)row * nb + (j / QK_K);
Comment on lines +121 to +128
if emit_direct_gguf():
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
emit_linear,
)
else:
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import (
emit_linear,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Good first issue] Add Q4K support to MLX backend

2 participants