Sparse Attention Kernel for DeepSeek V3.2 on B200
Implements DeepSeek Sparse Attention (DSA) on B200 Blackwell, fusing FP8 indexer scoring, histogram-based top-K selection, and sparse BF16 attention into minimal kernel launches using TMEM, tcgen05.mma, and page-sorted gather.
The Problem: O(L²) Attention at Long Context
Standard full attention scales as O(L²) in both compute and memory. At L=128K tokens, a single attention head requires 128K × 128K ≈ 17 billion attention weights, roughly 32 GB at FP16. That's not a compute problem, it's a physics problem.
DeepSeek V3.2 addresses this with DeepSeek Sparse Attention (DSA): instead of attending to all L tokens, each query identifies the k=2048 most relevant KV tokens and computes full BF16 attention only over that sparse subset. The selection uses a lightweight FP8 indexer — a compressed key cache of 132 bytes/token — to score all L tokens cheaply, then picks the top-2048 by score. The result is O(L·k) compute instead of O(L²), with attention quality close to full attention because attention weight distributions are naturally sparse.
| Metric | Full Attention | DSA Sparse Attention |
|---|---|---|
| Compute complexity | O(L²) | O(L·k), k=2048 |
| KV tokens attended | L (all) | 2048 (top-k) |
| Attention-weight memory per head (L=128K) | ~32 GB at FP16 | ~512 MB at BF16 |
| Scoring dtype | BF16 | FP8 (132 bytes/token) |
| Top-k selection | N/A | Histogram scan O(256) |
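The whole DSA dataflow fits in a few lines of NumPy. Below is a toy reference sketch (names, sizes, and the low-precision stand-in for the FP8 indexer cache are all illustrative, not the production kernel): score cheaply, keep the top-k, attend exactly over that subset.

```python
import numpy as np

def dsa_reference(q, K_lp, K, V, k):
    """Toy reference of the DSA dataflow: score with low-precision
    keys, keep the top-k tokens, then run exact attention over only
    those k tokens. Sizes here are illustrative, not 128K/2048."""
    scores = K_lp @ q                          # stage 1: cheap indexer scoring
    top = np.argpartition(scores, -k)[-k:]     # stage 2: top-k selection
    s = (K[top] @ q) / np.sqrt(q.shape[0])     # stage 3: sparse exact attention
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V[top], np.sort(top)

rng = np.random.default_rng(0)
L, d, k = 64, 16, 8
q = rng.standard_normal(d)
K = rng.standard_normal((L, d))
V = rng.standard_normal((L, d))
K_lp = np.round(K * 8) / 8   # crude stand-in for the compressed FP8 key cache
out, idx = dsa_reference(q, K_lp, K, V, k)
```

The point of the sketch is the shape of the computation: only stage 1 touches all L tokens, and it does so through the cheap compressed cache.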
B200 Blackwell: New Architecture Primitives
The kernel targets NVIDIA B200, which introduces a fundamentally different programming model compared to Hopper/Ampere. Two new primitives are central to this work:
Tensor Memory (TMEM)
TMEM is 256 KB of dedicated per-SM accumulator storage, separate from registers and shared memory. Estimated bandwidth: ~100 TB/s — 10× faster than SMEM and unconstrained by the register file. The new tcgen05.mma instruction writes accumulators directly into TMEM rather than registers, eliminating register pressure from accumulation entirely. This is a major win for attention kernels: accumulating Q·K and attn·V results previously competed with live register state for the finite register file.
tcgen05.mma vs Hopper wgmma
Hopper's wgmma (warp-group MMA) requires all 128 threads in a warp group to participate in a single synchronous MMA operation. Blackwell's tcgen05.mma has single-thread semantics — one thread issues the MMA, freeing the other 31 threads in the warp to do epilogue work (softmax, output writeback) concurrently. This enables clean warp specialization without the tight coupling that made Hopper kernels hard to overlap.
| Feature | Hopper (H100) | Blackwell (B200) |
|---|---|---|
| Peak FP8 compute | ~2.0 PFLOPS | ~4.5 PFLOPS |
| HBM bandwidth | 3.35 TB/s | 8 TB/s |
| L2 cache | 50 MB | 65 MB |
| MMA instruction | wgmma (128-thread) | tcgen05.mma (1-thread) |
| Accumulator storage | Registers | TMEM (256KB/SM, ~100 TB/s) |
End-to-End Kernel Design
The kernel fuses three logical stages into minimal CUDA launches:
Stage 1 — FP8 Indexer Scoring
Use tcgen05.mma.kind::f8f6f4 to compute Q·K scores over all L tokens using the compressed FP8 key cache (132 bytes/token). In the scoring kernel's epilogue, fuse a histogram accumulation pass: since FP8 scores map to only 256 discrete values, a per-bin histogram in shared memory captures the full score distribution. This enables exact top-K selection without sorting.
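A rough NumPy model of the fused histogram epilogue, under the assumption that scores land in 256 discrete bins (a linear quantizer stands in for the FP8 value space, and np.bincount stands in for per-CTA shared-memory atomics):

```python
import numpy as np

def score_histogram(scores, lo, hi, bins=256):
    """Bucket scores into 256 bins, standing in for the discrete FP8
    score space. In the real kernel each CTA increments a shared-memory
    histogram with atomics during the scoring epilogue; np.bincount
    plays that role here."""
    q = np.clip(((scores - lo) / (hi - lo) * (bins - 1)).astype(int), 0, bins - 1)
    return np.bincount(q, minlength=bins), q

rng = np.random.default_rng(1)
scores = rng.standard_normal(1 << 14)
hist, bin_ids = score_histogram(scores, scores.min(), scores.max())
```

Every scored token lands in exactly one bin, so the histogram is a lossless summary of the score distribution at FP8 resolution.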
Stage 2 — Top-K Selection via Histogram Scan
The standard approach to top-K is a comparison sort, O(L log L), or a multi-pass radix select, O(L). Here, FP8 quantization makes an O(256) scan possible: with the histogram already built in stage 1, selection reduces to scanning the 256 bins from the highest downward, accumulating counts until the running sum reaches k=2048. The threshold bin gives the exact top-K boundary. This is the most elegant property of FP8 scoring: the discrete score space collapses selection to a 256-element scan, with the O(L) counting work fused for free into the scoring epilogue. Selected indices are then sorted by page_id for coalesced HBM access in stage 3.
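The scan itself is a few lines. A NumPy sketch (a sequential loop stands in for the parallel scan, and the usage below fabricates random bins rather than real scores):

```python
import numpy as np

def topk_by_histogram(bin_ids, hist, k):
    """Exact top-k from a 256-bin score histogram: walk bins from the
    highest down, accumulating counts until k is reached. Bins above
    the threshold are taken whole; the threshold bin fills the rest."""
    total, b = 0, hist.size - 1
    while total + hist[b] < k:
        total += hist[b]
        b -= 1
    above = np.flatnonzero(bin_ids > b)                 # all strictly-better tokens
    at_thresh = np.flatnonzero(bin_ids == b)[: k - total]
    return np.concatenate([above, at_thresh])

def sort_by_page(token_idx, page_size=64):
    """Reorder selected tokens by page id for stage-3 coalescing."""
    return token_idx[np.argsort(token_idx // page_size, kind="stable")]

rng = np.random.default_rng(2)
bin_ids = rng.integers(0, 256, size=1 << 14)
hist = np.bincount(bin_ids, minlength=256)
sel = topk_by_histogram(bin_ids, hist, k=2048)
order = sort_by_page(sel)
```

The result is exact, not approximate: every token in a bin above the threshold outscores every token below it, and ties inside the threshold bin are broken arbitrarily, which is the same tie behavior a sort would exhibit.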
Stage 3 — Sparse Gather + FlashAttention
Load 32 selected pages (2048 tokens at page_size=64) from HBM via cp.async with double-buffering — while computing attention over page i, page i+1 is loading. Run FlashAttention's online softmax across 64-token tiles using tcgen05.mma for Q·K and attn·V, accumulating in TMEM at FP32 precision. Output is written to HBM in BF16.
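The online-softmax recurrence at the heart of stage 3 can be sketched in NumPy (single query, no FP8/BF16, no pipelining; tile size matches the 64-token pages):

```python
import numpy as np

def flash_attention_tiles(q, K, V, tile=64):
    """Online-softmax attention over KV tiles: keep running
    (max, denominator, output) accumulators and rescale them whenever
    a new tile raises the max, so the full weight matrix never
    materializes. In the kernel these accumulators live in TMEM at FP32."""
    d = q.shape[0]
    m, l, acc = -np.inf, 0.0, np.zeros(d)
    for t0 in range(0, K.shape[0], tile):
        s = K[t0:t0 + tile] @ q / np.sqrt(d)   # scores for this tile
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)              # correct old accumulators
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[t0:t0 + tile]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal(32)
K = rng.standard_normal((256, 32))
V = rng.standard_normal((256, 32))
out = flash_attention_tiles(q, K, V)
```

Because each tile only updates three running accumulators, the per-page compute in stage 3 is independent of how many pages came before it, which is what makes the load/compute double-buffering clean.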
Key Optimizations
Page-sorted gather
The 2048 selected token indices from stage 2 are sorted by page_id before loading. This transforms random HBM accesses (one cache-line per arbitrary token) into sequential page-level reads (64 contiguous tokens per page). Within each page, 32 threads each load 4 bytes of a 128-byte FP8 key — a single perfectly coalesced cache-line transaction. Estimated efficiency improvement: 40–60% on gather bandwidth.
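The effect of the sort is easy to see by counting contiguous same-page runs in the access stream. A small sketch (the 32-page/64-token geometry matches the document; `page_runs` is an illustrative metric, not a kernel function):

```python
import numpy as np

def page_runs(token_idx, page_size=64):
    """Number of contiguous same-page runs in an access stream.
    Fewer runs means longer sequential reads per page, hence better
    HBM coalescing."""
    pages = token_idx // page_size
    return 1 + int(np.count_nonzero(np.diff(pages)))

rng = np.random.default_rng(0)
pages = rng.choice(2048, size=32, replace=False)           # 32 pages of a 128K cache
shuffled = (pages[:, None] * 64 + np.arange(64)).ravel()   # 2048 selected tokens
rng.shuffle(shuffled)                                      # selection order is arbitrary
page_sorted = shuffled[np.argsort(shuffled // 64, kind="stable")]
```

After sorting, the stream collapses to exactly one run per page (32 runs), versus roughly one run per token before sorting.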
Online softmax with exp2
FlashAttention's online softmax maintains running (max, denominator, output) accumulators across tiles, avoiding materializing the full attention weight matrix. Using exp2f() instead of expf() maps directly to the GPU's MUFU.EX2 hardware instruction: expf(x) lowers to EX2(x · log2 e), so every call carries an implicit multiply by log2(e) ≈ 1.4427. Folding that factor into the softmax scale once removes the multiply from the hot loop, an estimated ~1.4× throughput improvement on the softmax step at no accuracy cost.
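The identity being exploited is just exp(x) = 2^(x · log2 e). A minimal check in NumPy (np.exp2 stands in for MUFU.EX2):

```python
import numpy as np

LOG2E = 1.4426950408889634   # log2(e), the factor expf() applies internally

def softmax_exp2(s):
    """Softmax evaluated in base 2: exp(x) == 2**(x * log2(e)).
    In the kernel the LOG2E factor is folded into the softmax scale
    applied to the Q*K scores once, so the hot loop issues bare
    MUFU.EX2 instructions with no extra multiply."""
    t = (s - s.max()) * LOG2E
    p = np.exp2(t)
    return p / p.sum()

s = np.random.default_rng(3).standard_normal(256)
```

The two formulations are bitwise-close in FP32 and identical in exact arithmetic, which is why the substitution costs no accuracy.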
L2 persistence
Total gathered KV data per query: 2048 tokens × 256 bytes/token (128 B each for BF16 keys and values) = 512 KB. This fits easily in B200's 65 MB L2. Marking the index arrays and hot KV pages as cudaAccessPropertyPersisting prevents eviction across attention heads and layers, eliminating repeated HBM round-trips for the same KV pages.
Warp specialization
Four warp roles per thread block: one producer warp (issues cp.async loads for the next page while current page computes), two MMA warps (Q·K scoring and attn·V accumulation via tcgen05.mma), and one epilogue warp (online softmax correction and BF16 output writeback). Unlike Hopper where wgmma requires 128-thread synchrony, Blackwell's single-thread MMA semantics makes this clean warp specialization straightforward.
| Theoretical bound | Value | Assumption |
|---|---|---|
| Min gather latency (512KB @ 8TB/s) | ~64 ns | 100% HBM bandwidth utilization |
| Target end-to-end per step | 5–15 µs | Well-optimized kernel, 128K context |
| Dense attention (128K ctx) | ~300 µs | Full O(L²) FlashAttention-3 estimate |
| Theoretical speedup | ~20–60× | DSA vs dense at L=128K |
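The table's figures follow from simple arithmetic. A quick check (decimal units; the ~64 ns entry reads 512 KB as 512,000 B, while 2048 × 256 B gives ≈66 ns):

```python
# Sanity-check the bounds table above (decimal units).
bytes_gathered = 2048 * 256        # 2048 tokens × 256 B of BF16 keys+values
hbm_bw = 8e12                      # B200 HBM: 8 TB/s
min_gather_ns = bytes_gathered / hbm_bw * 1e9   # lower bound on gather time

dense_us, target_lo_us, target_hi_us = 300, 5, 15
speedup = (dense_us / target_hi_us, dense_us / target_lo_us)  # (20x, 60x)
```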