Tiny-GEMM

Tiny-GEMM is a collection of fused Triton kernels that make decode-time transformer inference fast on resource-constrained GPUs by minimizing memory traffic and fusing sublayers.

Triton · Kernel Fusion · INT4 · Transformer Inference · Profiling

Overview

Modern transformer inference is often bottlenecked not by FLOPs but by memory traffic, kernel launch overhead, and poor cache utilization, especially in the small-batch, low-latency regime (batch sizes 1-8).

Tiny-GEMM targets the two dominant transformer compute paths: multi-head attention and the feed-forward network (MLP/FFN). The goal is to make decode-time inference fast by fusing operations, maximizing reuse in SRAM/cache, and exploiting packed INT4 weights.

Why small-batch inference is hard

Most optimized GPU kernels are tuned for training-like throughput: large batch sizes, long steady-state compute, and high arithmetic intensity. Real deployment looks different: batch size is around 1, decode steps are sequential, memory traffic dominates compute, and launch overhead matters. Tiny-GEMM attacks this regime with four techniques; a back-of-the-envelope calculation after the list shows why memory, not compute, sets the limit:

  • fusing whole transformer sublayers
  • IO-aware tiling
  • weight-only quantization
  • cache-aligned layouts
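
A rough arithmetic-intensity estimate for a single batch-1 projection makes the imbalance concrete. The hardware numbers in this sketch are assumptions (roughly A100-class), not measurements.

# arithmetic intensity of a batch-1 decode projection: y = W @ x, with W of shape (N, K)
N, K = 4096, 4096                          # hidden dims of one projection layer (illustrative)
flops = 2 * N * K                          # one multiply-add per weight element
bytes_moved = 2 * (N * K + K + N)          # FP16 weights + input vector + output vector

print(flops / bytes_moved)                 # ~1 FLOP per byte

# roofline check with assumed A100-class numbers: ~2e12 B/s of HBM bandwidth,
# ~3e14 FLOP/s of FP16 compute. A kernel is compute-bound only above
# peak_flops / hbm_bw (~150 FLOP/byte), so at ~1 FLOP/byte the decode GEMV
# is limited by HBM traffic, not by FLOPs.
hbm_bw, peak_flops = 2e12, 3e14
print(peak_flops / hbm_bw)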

Figure 1 — Transformer Inference Bottleneck Map

Small-batch inference is constrained less by compute and more by memory movement and kernel launch overhead.

Diagram prompt

Show an attention + MLP block with arrows labeled 'HBM traffic dominates'. Emphasize memory movement and launch overhead over compute.

Fused multi-head attention kernel

Transformer attention is conceptually Attn(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. Implemented naively, this pipeline materializes large intermediate matrices in HBM: the QK^T scores, the masked scores, and the softmax probabilities.

Tiny-GEMM computes attention in one fused Triton kernel using a FlashAttention-style tiling approach: the kernel is IO-aware, minimizing reads and writes to HBM by keeping working tiles in SRAM and registers. Key properties, with a PyTorch-level sketch of the tiling recurrence after the list:

  • block tiling for batch=1 decode workloads
  • fused causal masking (autoregressive-safe)
  • locality-aware Q/K/V access
  • optional dropout support
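
The sketch below illustrates the blockwise online-softmax recurrence that FlashAttention-style tiling relies on, written for a single head at the PyTorch level. It is illustrative only and does not reflect the kernel's actual block sizes, layouts, or dropout handling.

import torch

def tiled_causal_attention(q, k, v, block=128):
    """Blockwise causal attention with an online softmax (FlashAttention-style).
    q, k, v: (seq_len, head_dim) tensors for one head; shapes are illustrative."""
    seq, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq, 1), float("-inf"), device=q.device, dtype=q.dtype)
    row_sum = torch.zeros((seq, 1), device=q.device, dtype=q.dtype)

    for start in range(0, seq, block):                    # stream over K/V tiles
        kb, vb = k[start:start + block], v[start:start + block]
        scores = (q @ kb.T) * scale                       # one (seq, block) tile of QK^T

        # causal mask: query i may only attend to keys j <= i
        qi = torch.arange(seq, device=q.device)[:, None]
        kj = torch.arange(start, start + kb.shape[0], device=q.device)[None, :]
        scores = scores.masked_fill(kj > qi, float("-inf"))

        # online softmax: rescale the running output/sum whenever the row max grows
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)
        probs = torch.exp(scores - new_max)
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        out = out * correction + probs @ vb
        row_max = new_max

    return out / row_sum                                  # equals softmax(QK^T/sqrt(d)) V

# sanity check against the unfused reference
q = k = v = torch.randn(256, 64)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(tiled_causal_attention(q, k, v), ref, atol=1e-4))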

Figure 2 — Naive vs Fused Attention Pipeline

Tiny-GEMM computes attention in one fused Triton kernel, avoiding intermediate writes.

Diagram prompt

Left: QK^T -> mask -> softmax -> V with four kernel boxes. Right: single fused block. Use minimal arrows and labels.

Figure 3 — FlashAttention-Style Tiling in SRAM

IO-aware tiling keeps score computation and softmax normalization on-chip, reducing HBM reads/writes.

Diagram prompt

Block matrix tiles inside GPU SRAM with arrows showing on-chip reuse. Emphasize 'on-chip' vs 'HBM'.

Fused feed-forward network (FFN)

The transformer MLP block is typically Y = sigma(X W1 + B1) W2 + B2. Standard implementations launch a separate kernel for each step: GEMM, bias add, activation, GEMM, bias add. Tiny-GEMM fuses the full pipeline to reduce kernel launches, intermediate writes, and memory bandwidth pressure.
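
To make the fusion concrete, here is a minimal Triton sketch of the first FFN stage: a tiled GEMM whose bias add and activation run as an in-kernel epilogue instead of as separate launches. It is a simplified illustration (FP32 tensors, ReLU, no autotuning), not the project's actual kernel; the second GEMM is handled analogously.

import triton
import triton.language as tl

@triton.jit
def gemm_bias_relu_kernel(
    x_ptr, w_ptr, b_ptr, y_ptr,                 # computes Y = relu(X @ W + b)
    M, N, K,
    stride_xm, stride_xk, stride_wk, stride_wn, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):              # accumulate K-tiles while they are on-chip
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        w = tl.load(w_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk

    # fused epilogue: bias add + activation happen before the only write to HBM
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0)
    acc = tl.maximum(acc + bias[None, :], 0.0)

    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    tl.store(y_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

Launched on a (ceil(M / BLOCK_M), ceil(N / BLOCK_N)) grid, this writes the activated hidden state to HBM once instead of three times (GEMM output, bias output, activation output).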

Figure 4 — FFN Fusion: GEMM -> Act -> GEMM

FFN fusion eliminates bandwidth-heavy intermediate activations.

Diagram prompt

Show two GEMMs with activation between, crossed-out intermediate buffers, and a single fused box on the right.

Packed INT4 quantization framework

For inference, weights dominate the memory footprint. Tiny-GEMM implements per-channel INT4 weight packing, in-kernel dequantization, and packed INT4 GEMM primitives. INT4 provides roughly 8x compression over FP32 and boosts throughput in memory-bound regimes.
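
The packing idea fits in a few lines of PyTorch: two 4-bit codes share one byte, with a per-output-channel scale used to dequantize. This is an illustrative symmetric-quantization sketch, not the exact layout or scheme used by the kernels.

import torch

def pack_int4_per_channel(w):
    """Quantize (out_features, in_features) FP weights to packed INT4; illustrative layout."""
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 7.0   # symmetric, per output channel
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
    q = (q + 8).to(torch.uint8)                      # shift to unsigned nibbles 0..15
    lo, hi = q[:, 0::2], q[:, 1::2]                  # pair adjacent input channels
    return lo | (hi << 4), scale                     # two 4-bit values per byte

def unpack_int4_per_channel(packed, scale):
    """Dequantize back to floating point -- what the GEMM kernel does on the fly."""
    lo = (packed & 0x0F).to(torch.int8) - 8
    hi = (packed >> 4).to(torch.int8) - 8
    q = torch.stack((lo, hi), dim=2).flatten(1)      # re-interleave the column pairs
    return q.to(scale.dtype) * scale

w = torch.randn(4096, 4096)
packed, scale = pack_int4_per_channel(w)
print(packed.numel() / w.numel())                    # 0.5 bytes per weight vs 4 bytes for FP32
print((w - unpack_int4_per_channel(packed, scale)).abs().max())  # error on the order of scale/2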

Figure 5 — Packed INT4 Weight Layout

Packed INT4 weights reduce memory footprint and improve cache residency, enabling faster weight-only inference.

Diagram prompt

Diagram showing two INT4 packed into one byte. Use a simple 8-bit box split into two 4-bit halves.

PyTorch operator integration

Tiny-GEMM registers the fused attention and FFN kernels as first-class PyTorch ops using torch.library. This enables integration with torch.compile graphs, transformer backends, and higher-level inference runtimes.

import torch
import tiny_gemm.ops  # importing the module registers the custom ops

# q, k, v shapes are illustrative: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
out = torch.ops.tiny_gemm.fused_attention(q, k, v, causal=True)
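
For reference, a registration of this shape can be written with the torch.library custom-op API (PyTorch 2.4+). The sketch below mirrors the op name used above but is only an illustration; the actual registration in tiny_gemm/ops.py may use a different schema or dispatch path.

import torch

# illustrative registration sketch -- not the project's actual code
@torch.library.custom_op("tiny_gemm::fused_attention", mutates_args=())
def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    causal: bool = True) -> torch.Tensor:
    # placeholder eager implementation; the real op would dispatch to the Triton kernel
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)

# a fake (meta) implementation lets torch.compile trace shapes without launching the kernel
@fused_attention.register_fake
def _(q, k, v, causal=True):
    return torch.empty_like(q)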

Figure 6 — PyTorch Op Registration Stack

Custom operator registration makes fused kernels composable inside modern PyTorch inference graphs.

Diagram prompt

Stacked diagram: torch.compile -> torch.library -> Triton kernel. Show flow arrows.

Profiling + bottleneck discovery

Optimization work is only meaningful when guided by measurement. Tiny-GEMM includes PyTorch profiler integration, TensorBoard traces, and kernel-level bottleneck surfacing. The working loop, with a minimal profiling sketch after it:

  • profile -> identify IO wall -> fuse -> retile -> benchmark -> repeat
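
A minimal profiling setup along these lines uses the standard torch.profiler API with a TensorBoard trace handler; the shapes, step schedule, and output directory below are illustrative.

import torch
import tiny_gemm.ops  # registers the fused ops
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=8, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./tb_traces"),  # view the trace in TensorBoard
    record_shapes=True,
) as prof:
    for _ in range(16):                                        # repeated decode-style steps
        torch.ops.tiny_gemm.fused_attention(q, k, v, causal=True)
        prof.step()

# sort by CUDA time to surface the kernels sitting on the IO wall
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))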

Benchmark highlights

Benchmarks compare the baseline PyTorch attention/FFN, the fused Triton kernels, and INT4-quantized weights. Gains are largest for batch sizes 1-4, sequence lengths up to 2k, and decode-style inference workloads.
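
A minimal timing harness for this kind of comparison uses CUDA events around repeated calls; the shapes and iteration counts below are illustrative, not the settings used by benchmark_fused_transformer.py.

import torch
import tiny_gemm.ops  # registers the fused ops

def time_cuda_ms(fn, iters=100, warmup=10):
    """Mean milliseconds per call, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
baseline = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
fused = lambda: torch.ops.tiny_gemm.fused_attention(q, k, v, causal=True)
print(f"baseline: {time_cuda_ms(baseline):.3f} ms   fused: {time_cuda_ms(fused):.3f} ms")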

Figure 7 — Benchmark Plot

Fused kernels + INT4 quantization provide the largest speedups in batch=1 decode workloads.

Diagram prompt

Line chart: PyTorch FP16 baseline, Tiny-GEMM fused, Tiny-GEMM INT4. Emphasize batch=1 decode gains.

Project structure

  • triton_fused_transformer.py -- fused attention + FFN kernels
  • triton_gemm.py -- packed INT4 GEMM
  • quantize_utils.py -- quant/dequant utilities
  • benchmark_fused_transformer.py -- benchmarking harness
  • tiny_gemm/ops.py -- torch.library op registration
  • docker/ -- reproducible CUDA runtime

Future work

  • FlashAttention-2-style scheduling improvements
  • additional fused blocks: LayerNorm + residual
  • broader INT4 support across hidden-dimension patterns
  • compiler-level integration into full transformer runtimes