Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions examples/models/gemma4_31b/tests/test_cuda_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,13 @@ def _forward(self):
return self.model(tok, pos, temp)

def test_int4_weights_preserved(self):
"""Packing passes Int4Tensor through without conversion."""
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
"""Packing converts Int4Tensor to CudaCoalescedInt4Tensor."""
from executorch.backends.cuda.coalesced_int4_tensor import (
CudaCoalescedInt4Tensor,
)

w = self.model.layers[0].mlp.gate_proj.weight.data
self.assertIsInstance(w, Int4Tensor)
self.assertIsInstance(w, CudaCoalescedInt4Tensor)

def test_inference_produces_valid_output(self):
out = self._forward()
Expand Down Expand Up @@ -243,14 +245,19 @@ def _load(self, tmp):
return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG)

def test_load_converts_weights(self):
"""GGUF -> CUDA: Q4_K -> Int4Tensor, Q6_K -> IntxUnpacked, embedding bf16."""
"""GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> IntxUnpacked,
embedding bf16."""
from executorch.backends.cuda.coalesced_int4_tensor import (
CudaCoalescedInt4Tensor,
)
from torchao.quantization import IntxUnpackedToInt8Tensor
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

with tempfile.TemporaryDirectory() as tmp:
model, _ = self._load(tmp)

self.assertIsInstance(model.layers[0].self_attn.q_proj.weight.data, Int4Tensor)
self.assertIsInstance(
model.layers[0].self_attn.q_proj.weight.data, CudaCoalescedInt4Tensor
)
self.assertIsInstance(
model.layers[0].mlp.down_proj.weight.data, IntxUnpackedToInt8Tensor
)
Expand Down
Loading