Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(WEBGPU_SRCS
runtime/ops/update_cache/UpdateCache.cpp
runtime/ops/sdpa/Sdpa.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
244 changes: 244 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>
#include <stdexcept>

namespace executorch::backends::webgpu {

namespace {

// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes).
struct Q4gswParams {
uint32_t M;
uint32_t N;
uint32_t K;
uint32_t K_packed;
uint32_t group_size;
uint32_t padded_N;
uint32_t has_bias;
uint32_t _pad;
};
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
const int weight_id = args.at(1);
const int scales_id = args.at(2);
const int group_size_id = args.at(3);
const int bias_id = args.at(4);
const int out_id = args.at(5);

WGPUDevice device = graph.device();

const auto& in = graph.get_tensor(in_id);
const auto& weight = graph.get_tensor(weight_id);
const auto& scales = graph.get_tensor(scales_id);
const auto& out = graph.get_tensor(out_id);

if (in.dims.empty() || weight.dims.size() < 2 || scales.dims.size() < 2) {
throw std::runtime_error("WebGPU linear_q4gsw: malformed input dims");
}

// Shapes from the tensors' own dims (no dtype field at runtime).
const uint32_t K = static_cast<uint32_t>(in.dims.back());
if (K == 0) {
throw std::runtime_error("WebGPU linear_q4gsw: K == 0");
}
uint64_t in_numel = 1;
for (int64_t d : in.dims) {
in_numel *= static_cast<uint64_t>(d);
}
const uint32_t M = static_cast<uint32_t>(in_numel / K);
if (in_numel % K != 0) {
throw std::runtime_error(
"WebGPU linear_q4gsw: input numel not a multiple of K");
}
const uint32_t N = static_cast<uint32_t>(weight.dims[0]);
const uint32_t K_packed = static_cast<uint32_t>(weight.dims[1]);
const uint32_t num_groups = static_cast<uint32_t>(scales.dims[0]);
const uint32_t padded_N = static_cast<uint32_t>(scales.dims[1]);
if (M == 0 || N == 0) {
throw std::runtime_error("WebGPU linear_q4gsw: M or N == 0");
}
// int4 packing is 2 nibbles/byte, so K_packed must be ceil(K/2) (guards OOB).
if (K_packed != (K + 1) / 2) {
throw std::runtime_error("WebGPU linear_q4gsw: K_packed must be ceil(K/2)");
}
// Weight is read as array<u32>; a non-multiple-of-4 byte count over-reads.
if ((static_cast<uint64_t>(N) * K_packed) % 4u != 0u) {
throw std::runtime_error(
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
}

// One workgroup per output row (M); validate dispatch before any alloc.
const uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
const uint64_t weight_numel =
static_cast<uint64_t>(N) * static_cast<uint64_t>(K_packed);
if (in.nbytes != in_numel * sizeof(float) ||
out.nbytes != static_cast<uint64_t>(M) * N * sizeof(float) ||
scales.nbytes != scales_numel * sizeof(float) ||
weight.nbytes != weight_numel) {
throw std::runtime_error(
"WebGPU linear_q4gsw: fp32-only (byte-size mismatch)");
}

int64_t group_size = 0;
if (graph.get_value_type(group_size_id) == WebGPUGraph::ValueType::Int) {
group_size = graph.get_int(group_size_id);
}
if (group_size <= 0) {
throw std::runtime_error("WebGPU linear_q4gsw: group_size <= 0");
}
// scales is indexed [(k/group_size)*padded_N + n]; guard the table bounds.
const uint32_t gs = static_cast<uint32_t>(group_size);
if (num_groups < (K + gs - 1u) / gs || padded_N < N) {
throw std::runtime_error(
"WebGPU linear_q4gsw: scales dims too small for K/N");
}

// Optional bias: real buffer if present, else a dummy for the fixed layout.
uint32_t has_bias = 0;
WGPUBuffer bias_buffer = nullptr;
uint64_t bias_size = 4;
if (graph.get_value_type(bias_id) == WebGPUGraph::ValueType::Tensor) {
const auto& bias = graph.get_tensor(bias_id);
if (bias.buffer == nullptr || bias.nbytes < N * sizeof(float)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: bias present but null/undersized");
}
has_bias = 1;
bias_buffer = bias.buffer;
bias_size = bias.nbytes;
}
if (bias_buffer == nullptr) {
bias_buffer = graph.create_scratch_buffer(4);
}

Q4gswParams params = {};
params.M = M;
params.N = N;
params.K = K;
params.K_packed = K_packed;
params.group_size = gs;
params.padded_N = padded_N;
params.has_bias = has_bias;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(Q4gswParams);
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
uniform_desc.mappedAtCreation = true;
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
void* mapped =
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(Q4gswParams));
std::memcpy(mapped, &params, sizeof(Q4gswParams));
wgpuBufferUnmap(uniform_buffer);
graph.add_uniform_buffer_bytes(sizeof(Q4gswParams));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// Bind group layout: out (rw) + in/weight/scales/bias (ro storage) + uniform.
WGPUBindGroupLayoutEntry entries[6] = {};
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_Storage;
for (uint32_t i = 1; i <= 4; i++) {
entries[i].binding = i;
entries[i].visibility = WGPUShaderStage_Compute;
entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
}
entries[5].binding = 5;
entries[5].visibility = WGPUShaderStage_Compute;
entries[5].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 6;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[6] = {};
bg_entries[0].binding = 0;
bg_entries[0].buffer = out.buffer;
bg_entries[0].size = out.nbytes;
bg_entries[1].binding = 1;
bg_entries[1].buffer = in.buffer;
bg_entries[1].size = in.nbytes;
bg_entries[2].binding = 2;
bg_entries[2].buffer = weight.buffer;
bg_entries[2].size = weight.nbytes;
bg_entries[3].binding = 3;
bg_entries[3].buffer = scales.buffer;
bg_entries[3].size = scales.nbytes;
bg_entries[4].binding = 4;
bg_entries[4].buffer = bias_buffer;
bg_entries[4].size = bias_size;
bg_entries[5].binding = 5;
bg_entries[5].buffer = uniform_buffer;
bg_entries[5].size = sizeof(Q4gswParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 6;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBufferRelease(uniform_buffer);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.linear_q4gsw.default, q4gsw_linear_impl);
}

} // namespace executorch::backends::webgpu
64 changes: 64 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

struct Params {
M: u32,
N: u32,
K: u32,
K_packed: u32,
group_size: u32,
padded_N: u32,
has_bias: u32,
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
return;
}
let in_base = m * params.K;

var n: u32 = lid.x;
loop {
if (n >= params.N) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
Loading
Loading