Tiny-GEMM: Packed INT4 Triton GEMM for Decode-Heavy LLM Inference
Small-batch LLM decoding is dominated by narrow GEMMs that stress memory bandwidth and launch overhead rather than peak FLOPs. Tiny-GEMM is a packed INT4 GEMM kernel in Triton for decode-heavy shapes, together with a measurement-driven, hardware-counter-backed analysis of when weight-only INT4 helps and when it hurts.
Motivation
LLM inference and training are fundamentally different workloads. Training runs large, square-ish matrices that saturate tensor cores. Decoding runs one token at a time — batch size 1 to 8, skinny weight projections, tight latency budgets. In this regime the bottleneck shifts entirely: it's memory bandwidth and kernel launch overhead, not peak FLOPs.
This creates a real problem for quantization. The naive story is "INT4 halves your weight size so you get 2× bandwidth and 2× speed." But in the decode regime that math breaks down — you also have to unpack those 4-bit values back to float inside the kernel, and that dequantization cost is fixed per launch regardless of how much work you do. For narrow projections like KV, that overhead dominates and INT4 ends up slower than FP16.
Tiny-GEMM is my attempt to pin this down precisely: build the fused kernel, run it against FP16 and dequantized-FP16 baselines across the actual decode shapes that matter, use Nsight Compute to see what's really happening in hardware, and derive a concrete rule for when INT4 is worth using.
The Kernel
The kernel is written in Triton with per-tensor quantization and bit-packed weight tensors. INT4 values are stored two-per-byte and unpacked into FP32 accumulators inside the kernel. Tile configurations are static and keyed by shape family and batch bucket — decode shapes cluster tightly enough that a small lookup table beats dynamic autotuning at runtime.
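The two-per-byte packing and in-kernel unpack can be sketched in NumPy. This is a minimal model of one plausible layout (low nibble = even index, high nibble = odd index); the actual kernel's tile-level layout may differ.

```python
import numpy as np

def pack_int4(w_q: np.ndarray) -> np.ndarray:
    """Pack signed INT4 values (range [-8, 7]) two per byte.
    Even-indexed values go in the low nibble, odd-indexed in the high
    nibble -- one plausible layout; the real kernel may order differently."""
    assert w_q.size % 2 == 0
    u = (w_q.astype(np.int8) & 0x0F).astype(np.uint8)  # two's-complement nibbles
    return (u[0::2] | (u[1::2] << 4)).astype(np.uint8)

def unpack_int4(packed: np.ndarray, scale: float) -> np.ndarray:
    """Unpack to FP32 with per-tensor dequantization, mirroring what the
    fused kernel does per tile."""
    lo = (packed & 0x0F).astype(np.int8)
    hi = (packed >> 4).astype(np.int8)
    # Sign-extend 4-bit two's complement: nibbles >= 8 encode negatives.
    lo = np.where(lo >= 8, lo - 16, lo)
    hi = np.where(hi >= 8, hi - 16, hi)
    out = np.empty(packed.size * 2, dtype=np.float32)
    out[0::2] = lo
    out[1::2] = hi
    return out * scale

w = np.array([-8, -1, 0, 7], dtype=np.int8)
packed = pack_int4(w)                      # 2 bytes hold 4 weights
restored = unpack_int4(packed, scale=0.5)  # dequantized FP32 values
```

The sign-extension step is the part that is easy to get wrong in the fused kernel: the nibble arrives as an unsigned value and must be remapped before scaling.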
Three baselines are compared across every shape: FP16 (torch.matmul / cuBLAS, vendor-optimized), dequantized FP16 (INT4 quantize → dequant → FP16 GEMM as a two-step pipeline), and the fused INT4 kernel. The dequantized baseline is important — it isolates whether the problem is the quantization format or the fused computation, and it's what most deployed systems actually do before switching to a fused kernel.
One important caveat: the kernel accumulates in FP32, not using INT4 tensor core MMA instructions. Exploiting Ampere/Hopper INT4 MMA is future work — the current bottleneck on decode shapes is memory, not compute, so the MMA throughput gap doesn't matter yet.
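A numeric reference for the two quantized paths, assuming per-tensor symmetric quantization as described above. This is a sketch of the semantics, not the kernels: the fused path dequantizes inline and accumulates in FP32, while the two-step baseline materializes an FP16 weight tensor first (an extra pass over the weights and an extra launch in the real pipeline).

```python
import numpy as np

def quantize_per_tensor(w: np.ndarray):
    """Per-tensor symmetric INT4 quantization: a single scale for the
    whole tensor, values clamped to [-8, 7]."""
    scale = np.abs(w).max() / 7.0
    w_q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
    return w_q, scale

def fused_ref(x, w_q, scale):
    # Fused path: dequantize inside the "kernel", FP32 accumulation throughout.
    return x.astype(np.float32) @ (w_q.astype(np.float32) * scale)

def two_step_ref(x, w_q, scale):
    # Dequantized-FP16 baseline: materialize FP16 weights in memory,
    # then run a standard GEMM over them.
    w_fp16 = (w_q.astype(np.float32) * scale).astype(np.float16)
    return x.astype(np.float32) @ w_fp16.astype(np.float32)
```

Both paths compute the same result up to FP16 rounding of the materialized weights, which is exactly why the two-step baseline isolates format cost from fusion cost.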
Setup
All experiments run on an NVIDIA A10G. Shapes are derived from Llama-style models — Q/K/V projections (K=N=4096), KV projections (K=4096, N=1024), FFN up-projections (K=4096, N=14336), and FFN down-projections (K=14336, N=4096). Batch sizes M ∈ {1…8} cover the decode regime. Each latency is the median of 50 runs after 10 warmup iterations; profiling uses Nsight Compute for hardware counters and Nsight Systems for kernel time breakdowns.
| Layer | M | K | N |
|---|---|---|---|
| Q/K/V proj | 1–8 | 4096 | 4096 |
| KV proj | 1–8 | 4096 | 1024 |
| FFN up | 1–8 | 4096 | 14336 |
| FFN down | 1–8 | 14336 | 4096 |
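The timing protocol (10 warmup iterations, median of 50 timed runs) can be sketched as below. The real harness times CUDA kernels with device-side synchronization; this host-clock version just shows the protocol.

```python
import statistics
import time

def median_latency_ms(fn, warmup: int = 10, iters: int = 50) -> float:
    """Measurement protocol used in the sweep: discard warmup iterations,
    then report the median of the timed runs. Median (not mean) keeps
    scheduler hiccups and clock-ramp outliers from skewing the number."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)  # ms
    return statistics.median(samples)
```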
Results
The headline numbers are at M=1, and the split is immediate: FFN up gets 3.58×, KV proj gets 0.62× (INT4 is slower). This isn't a subtle effect or a tuning artifact; it's a structural consequence of shape geometry.
| Shape | FP16 (ms) | INT4 (ms) | Speedup | Bottleneck |
|---|---|---|---|---|
| KV proj (K=4096, N=1024) | 0.027 | 0.043 | 0.62× | Dequant overhead |
| Q proj (K=4096, N=4096) | 0.075 | 0.047 | 1.58× | Mixed |
| FFN up (K=4096, N=14336) | 0.239 | 0.067 | 3.58× | Memory bandwidth |
| FFN down (K=14336, N=4096) | 0.258 | 0.152 | 1.69× | Memory bandwidth |
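The split tracks weight-tensor size. In decode, each weight byte is read roughly once per token, so the weight footprint is a good proxy for memory traffic, and the FFN matrices are 14× larger than the KV projection. A quick size check for the shapes in the sweep:

```python
def weight_mb(k: int, n: int, bits: int) -> float:
    """Weight-tensor size in MB for a K x N projection at a given precision."""
    return k * n * bits / 8 / 1e6

shapes = {"KV proj": (4096, 1024), "Q proj": (4096, 4096),
          "FFN up": (4096, 14336), "FFN down": (14336, 4096)}
for name, (k, n) in shapes.items():
    print(f"{name}: FP16 {weight_mb(k, n, 16):.1f} MB -> INT4 {weight_mb(k, n, 4):.1f} MB")
```

For KV proj the FP16 tensor is only ~8.4 MB, so the absolute bandwidth saved by INT4 is small relative to the fixed unpack cost; for FFN up it is ~117 MB, and halving that traffic dominates everything else.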
Prefill vs. Decode
The same kernel, same weights, different batch size — the story changes completely. In prefill you're running M in the hundreds or thousands, so dequantization overhead gets amortized across a huge amount of output work. INT4 helps across nearly all shapes. In decode, M is 1–8 and the fixed overhead per launch is a much larger fraction of total runtime.
This means a blanket quantization policy that's tuned for prefill throughput can actively hurt decode latency on the same model. Deployment decisions need to be mode-aware, not just shape-aware.
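The prefill/decode flip falls out of a toy additive cost model. All numbers below are illustrative, chosen only to show the trend: `t_w` stands for FP16 weight-traffic time, `t_dq` for the fixed per-launch dequant cost, and `c` for per-row activation/compute cost that scales with M.

```python
def speedup(m: int, t_w: float, t_dq: float, c: float) -> float:
    """Toy model (hypothetical units): INT4 halves weight traffic but pays
    a fixed dequant cost; per-row work m*c is the same for both formats."""
    fp16 = t_w + m * c
    int4 = t_w / 2 + t_dq + m * c
    return fp16 / int4

# Narrow projection: dequant cost comparable to weight traffic.
narrow_decode = speedup(1, t_w=1.0, t_dq=0.9, c=0.01)      # below 1: INT4 loses
narrow_prefill = speedup(2048, t_w=1.0, t_dq=0.9, c=0.01)  # near 1: overhead amortized
# Wide FFN projection: weight traffic dwarfs the dequant cost.
wide_decode = speedup(1, t_w=3.5, t_dq=0.2, c=0.01)        # above 1: INT4 wins
```

At large M the `m * c` term dominates both numerator and denominator, so the dequant penalty washes out; at M=1 it is fully exposed.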
The Regime Model
To make the pattern precise, I decompose kernel runtime into four additive costs: launch overhead, memory traffic, arithmetic, and unpacking.

T_total = T_launch + T_mem + T_compute + T_dequant

INT4 reduces T_mem by roughly 2× (half the bits to move), but it also adds T_dequant. The kernel wins when the bandwidth savings exceed the unpack cost:

T_mem(FP16) - T_mem(INT4) ≈ T_mem(FP16) / 2 > T_dequant
This inequality is equivalent to an arithmetic intensity threshold. Below a certain α (FLOPs/byte), dequantization dominates and INT4 loses. Above it, bandwidth savings dominate and INT4 wins. In the sweep that boundary falls at roughly α ≈ 8 FLOPs/byte.
Hardware Counter Attribution
The regime model is clean but abstract — Nsight Compute lets me check it against actual hardware behavior. FP16 decode GEMMs on the A10G reach ~75–77% of peak DRAM bandwidth while compute utilization stays low. This is the textbook memory-bound regime: the GPU is waiting on DRAM, not doing arithmetic.
INT4 ends up at ~23% peak SM throughput vs. ~32% for FP16, i.e. about 28% less compute pressure in relative terms. That's not because INT4 is more efficient, but because it does less useful work per cycle (more of the SM time goes to the unpack path). The memory traffic numbers confirm it: INT4 halves weight reads, consistently, across all shapes. The variable is whether you can convert that into latency savings.
Systems View
Decode latency and serving throughput are different objectives that sometimes point in opposite directions. Interactive serving wants minimum single-token latency (M=1). Batch serving wants maximum tokens/second (larger M). INT4 behaves differently in each.
The speedup profile is stable across the decode batch range (M = 1–8), which is good news for deployment: the INT4/FP16 decision is static per layer, not dynamic per request. You don't need to re-evaluate at runtime; just apply the α > 8 FLOPs/byte rule at model load time.
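A load-time policy under this rule is a few lines. The layer names and α values below are hypothetical placeholders; in practice the per-layer α would come from offline profiling of the target model and GPU.

```python
from typing import Dict

ALPHA_THRESHOLD = 8.0  # FLOPs/byte boundary measured in the A10G sweep

def choose_weight_format(layer_alpha: Dict[str, float]) -> Dict[str, str]:
    """Static, load-time policy: layers above the measured arithmetic-
    intensity threshold use the fused INT4 kernel, the rest stay FP16."""
    return {name: ("int4" if alpha > ALPHA_THRESHOLD else "fp16")
            for name, alpha in layer_alpha.items()}

# Hypothetical per-layer alphas, for illustration only.
plan = choose_weight_format({"kv_proj": 2.0, "ffn_up": 24.0})
```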
Takeaways
The practical upshot: don't apply INT4 uniformly. The arithmetic intensity threshold (α ≈ 8 FLOPs/byte) is a reliable decision boundary. Above it — wide FFN projections — INT4 wins by 1.5–3.7× in decode. Below it — narrow KV projections — keep FP16. The layers you most want to quantize (FFN, because they're the largest) are also the ones where INT4 actually helps.
There's a broader systems lesson here too: reducing bandwidth pressure doesn't automatically improve latency if the freed capacity gets consumed by something else. Quantization is only effective when arithmetic intensity is high enough to amortize the dequantization overhead — and that threshold is measurable. The model isn't hard to derive; you just have to actually measure it instead of assuming.
What's next
- INT4 tensor core MMA: the kernel currently accumulates in FP32, skipping Ampere/Hopper INT4 MMA instructions. On compute-bound shapes this matters.
- Split-K for M=1 to improve SM occupancy on the narrowest projections by splitting the K dimension across thread blocks.
- FP8 on Blackwell: tcgen05.mma.kind::f8f6f4 changes the roofline substantially; re-evaluating the regime boundary on B200 is the next step.
- Multi-GPU and serving stack integration, connecting kernel-level gains to end-to-end serving latency under concurrent requests.