Tiny-GEMM
Tiny-GEMM is a collection of fused Triton kernels that make decode-time transformer inference fast on resource-constrained GPUs by minimizing memory traffic and fusing sublayers.

Overview
Modern transformer inference is often bottlenecked not by FLOPs but by memory traffic, kernel launch overhead, and poor cache utilization, especially in the small-batch, low-latency regime (batch size 1-8).
Tiny-GEMM targets the two dominant transformer compute paths: multi-head attention and feed-forward networks (MLPs / FFNs). The goal is to make decode-time inference fast by fusing operations, maximizing reuse in SRAM/cache, and exploiting packed INT4 weights.
Why small-batch inference is hard
Most optimized GPU kernels are tuned for training-like throughput: large batch sizes, long steady-state compute, and high arithmetic intensity. Real deployment looks different: batch size is roughly 1, decode steps are sequential, memory movement dominates compute, and launch overhead matters. Tiny-GEMM attacks this regime with:
- fusing whole transformer sublayers
- IO-aware tiling
- weight-only quantization
- cache-aligned layouts
Figure 1 — Transformer Inference Bottleneck Map
Small-batch inference is constrained less by compute and more by memory movement and kernel launch overhead.
Diagram prompt
Show an attention + MLP block with arrows labeled 'HBM traffic dominates'. Emphasize memory movement and launch overhead over compute.
Fused multi-head attention kernel
Transformer attention is conceptually Attn(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. Computed naively, this pipeline materializes large intermediate matrices (QK^T, masked scores, softmax probabilities).
Tiny-GEMM computes attention in one fused Triton kernel using a FlashAttention-style tiling approach. The kernel is IO-aware: it minimizes reads and writes to HBM by keeping working tiles inside SRAM and registers. The fused kernel provides:
- block tiling for batch=1 decode workloads
- fused causal masking (autoregressive safe)
- locality-aware Q/K/V access
- optional dropout support
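To make the tiling idea concrete, here is a minimal PyTorch sketch of FlashAttention-style online-softmax accumulation for a single decode query (batch = 1). The function name, tensor layout, and tile size are illustrative assumptions, not the actual Triton kernel:

    import math
    import torch

    def tiled_decode_attention(q, k, v, block_n=128):
        # q: [heads, d]; k, v: [heads, seq, d] -- one decode step, batch = 1.
        # No causal mask is needed here: the newest query may attend to every cached token.
        heads, d = q.shape
        scale = 1.0 / math.sqrt(d)
        m = torch.full((heads, 1), float("-inf"), dtype=torch.float32, device=q.device)  # running max
        l = torch.zeros((heads, 1), dtype=torch.float32, device=q.device)                # running normalizer
        acc = torch.zeros((heads, d), dtype=torch.float32, device=q.device)              # running weighted sum
        for start in range(0, k.shape[1], block_n):
            k_blk = k[:, start:start + block_n].float()  # K tile (lives in SRAM in the real kernel)
            v_blk = v[:, start:start + block_n].float()  # matching V tile
            s = torch.einsum("hd,hnd->hn", q.float(), k_blk) * scale  # partial scores for this tile
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)                     # tile probabilities, rescaled to the new max
            corr = torch.exp(m - m_new)                  # correction factor for previously seen tiles
            l = l * corr + p.sum(dim=-1, keepdim=True)
            acc = acc * corr + torch.einsum("hn,hnd->hd", p, v_blk)
            m = m_new
        return (acc / l).to(q.dtype)

The running statistics (m, l) let each K/V tile be read once while the full score matrix is never written out, which is exactly the HBM traffic the fused kernel eliminates.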
Figure 2 — Naive vs Fused Attention Pipeline
Tiny-GEMM computes attention in one fused Triton kernel, avoiding intermediate writes.
Diagram prompt
Left: QK^T -> mask -> softmax -> V with four kernel boxes. Right: single fused block. Use minimal arrows and labels.
Figure 3 — FlashAttention-Style Tiling in SRAM
IO-aware tiling keeps score computation and softmax normalization on-chip, reducing HBM reads/writes.
Diagram prompt
Block matrix tiles inside GPU SRAM with arrows showing on-chip reuse. Emphasize 'on-chip' vs 'HBM'.
Fused feed-forward network (FFN)
The transformer MLP block is typically Y = sigma(XW1 + B1) W2 + B2. Standard implementations launch five separate kernels: GEMM, bias add, activation, GEMM, bias add. Tiny-GEMM fuses the full pipeline into a single kernel to reduce kernel launch boundaries, intermediate writes, and memory traffic.
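A minimal reference sketch of the fused computation, assuming a GELU activation and illustrative function and parameter names. It tiles the hidden (d_ff) dimension so the full intermediate activation is never materialized, which is the property the fused Triton kernel exploits:

    import torch
    import torch.nn.functional as F

    def fused_ffn_reference(x, w1, b1, w2, b2, block_k=256, act=F.gelu):
        # x: [tokens, d_model]; w1: [d_model, d_ff]; b1: [d_ff]; w2: [d_ff, d_model]; b2: [d_model].
        # The hidden dimension is processed tile by tile, so the [tokens, d_ff] activation
        # never exists in full -- only one block_k-wide slice at a time.
        out = torch.zeros(x.shape[0], w2.shape[1], dtype=torch.float32, device=x.device)
        for start in range(0, w1.shape[1], block_k):
            h = act(x.float() @ w1[:, start:start + block_k].float() + b1[start:start + block_k].float())
            out += h @ w2[start:start + block_k].float()  # partial contribution of this d_ff tile
        return (out + b2.float()).to(x.dtype)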
Figure 4 — FFN Fusion: GEMM -> Act -> GEMM
FFN fusion eliminates bandwidth-heavy intermediate activations.
Diagram prompt
Show two GEMMs with activation between, crossed-out intermediate buffers, and a single fused box on the right.
Packed INT4 quantization framework
For inference, weights dominate the memory footprint. Tiny-GEMM implements per-channel INT4 weight packing, in-kernel dequantization, and packed INT4 GEMM primitives. INT4 provides roughly 8x compression versus FP32 and boosts throughput in memory-bound regimes.
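A minimal sketch of the packing and unpacking scheme, assuming symmetric per-output-channel quantization to the range [-8, 7] and an even number of input channels. Function names are illustrative; the in-kernel Triton dequantization follows the same bit layout:

    import torch

    def quantize_int4_per_channel(w):
        # w: [out_channels, in_channels] FP weight; in_channels is assumed even.
        # Symmetric per-output-channel quantization to [-8, 7]; scales are stored in FP16.
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7.0
        q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
        q_u = (q + 8).to(torch.uint8)                 # shift to [0, 15] so two values fit in one byte
        packed = q_u[:, 0::2] | (q_u[:, 1::2] << 4)   # even channels in the low nibble, odd in the high
        return packed, scale.half()

    def dequantize_int4_per_channel(packed, scale):
        lo = (packed & 0x0F).to(torch.int8) - 8       # recover even input channels
        hi = (packed >> 4).to(torch.int8) - 8         # recover odd input channels
        q = torch.stack((lo, hi), dim=-1).flatten(1)  # interleave back to the original channel order
        return q.float() * scale.float()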
Figure 5 — Packed INT4 Weight Layout
Packed INT4 weights reduce memory footprint and improve cache residency, enabling faster weight-only inference.
Diagram prompt
Diagram showing two INT4 packed into one byte. Use a simple 8-bit box split into two 4-bit halves.
PyTorch operator integration
Tiny-GEMM registers the fused attention and FFN kernels as first-class PyTorch ops using torch.library. This enables integration into torch.compile graphs, transformer backends, and higher-level inference runtimes.
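A hypothetical registration sketch using torch.library.custom_op (PyTorch 2.4+). The op body below uses scaled_dot_product_attention as a stand-in for the actual Triton launch, and the fake implementation gives torch.compile the shape information it needs to trace through the op:

    import torch
    import torch.nn.functional as F

    # Sketch only: the real op body launches the fused Triton kernel;
    # scaled_dot_product_attention is used here as a stand-in.
    @torch.library.custom_op("tiny_gemm::fused_attention", mutates_args=(), device_types="cuda")
    def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    @fused_attention.register_fake
    def _(q, k, v, causal):
        # shape/dtype-only implementation so FakeTensor tracing and torch.compile can see through the op
        return torch.empty_like(q)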
Calling the registered op from Python:

    import torch
    import tiny_gemm.ops  # importing the module registers the fused ops

    # q, k, v: query/key/value tensors on the GPU
    out = torch.ops.tiny_gemm.fused_attention(q, k, v, causal=True)

Figure 6 — PyTorch Op Registration Stack
Custom operator registration makes fused kernels composable inside modern PyTorch inference graphs.
Diagram prompt
Stacked diagram: torch.compile -> torch.library -> Triton kernel. Show flow arrows.
Profiling + bottleneck discovery
Optimization work is only meaningful when guided by measurement. Tiny-GEMM includes PyTorch profiler integration, TensorBoard traces, and kernel-level bottleneck surfacing. The iteration loop (a profiling sketch follows):
- profile -> identify IO wall -> fuse -> retile -> benchmark -> repeat
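A minimal sketch of the profiling step using torch.profiler with a TensorBoard trace handler; decode_step, the tensor shapes, and the output directory are placeholders:

    import torch
    from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 4096, device=device)
    w = torch.randn(4096, 4096, device=device)

    def decode_step():
        # stand-in for one autoregressive decode iteration of the real model
        return x @ w

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        on_trace_ready=tensorboard_trace_handler("./tb_traces"),
        record_shapes=True,
    ) as prof:
        for _ in range(16):
            decode_step()

    # quick per-op table for bottleneck surfacing; the full trace lands in ./tb_traces
    print(prof.key_averages().sort_by("cuda_time_total").table(row_limit=10))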
Benchmark highlights
Benchmarks compare the baseline PyTorch attention/FFN path, the fused Triton kernels, and INT4-quantized weights. Gains are largest for batch sizes 1-4, sequence lengths up to 2k, and decode-style inference workloads.
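A minimal sketch of the kind of CUDA-event timing loop such a comparison relies on (function name, iteration counts, and warmup are illustrative):

    import torch

    def time_decode(fn, iters=100, warmup=10):
        # CUDA-event timing of a decode-style callable (e.g. baseline SDPA vs. the fused op).
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # average milliseconds per call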
Figure 7 — Benchmark Plot
Fused kernels + INT4 quantization provide the largest speedups in batch=1 decode workloads.
Diagram prompt
Line chart: PyTorch FP16 baseline, Tiny-GEMM fused, Tiny-GEMM INT4. Emphasize batch=1 decode gains.
Project structure
- triton_fused_transformer.py -- fused attention + FFN kernels
- triton_gemm.py -- packed INT4 GEMM
- quantize_utils.py -- quant/dequant utilities
- benchmark_fused_transformer.py -- benchmarking harness
- tiny_gemm/ops.py -- torch.library op registration
- docker/ -- reproducible CUDA runtime
Future work
- FlashAttention-2-style scheduling improvements
- additional fused blocks: LayerNorm + residual
- broader INT4 support across hidden dimension patterns
- compiler-level integration into full transformer runtimes