
How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance

going from naive to 94% of cuBLAS, one kernel at a time

SGEMM is probably the single most important computational kernel in modern deep learning — if you could only profile one operation in a transformer, it would be this one. I wanted to understand from first principles why GPUs are so good at it, and what it actually takes to close the gap with cuBLAS. This is my attempt to work through that iteratively, going from a naive kernel to something that hits ~94% of cuBLAS on an A6000.

The thing I find fascinating about this exercise is that the gap between the naive implementation and cuBLAS is about 75x in raw throughput. Almost none of that gap comes from algorithmic cleverness — it's all about understanding the memory hierarchy and feeding the compute units correctly. The math is simple; the hard part is data movement.

Before diving in: CUDA organizes computation into a three-level hierarchy. A kernel launch creates a grid of blocks, each block contains up to 1024 threads, and threads within the same block share a fast on-chip scratchpad called shared memory (SMEM). The hierarchy mirrors the hardware: each block runs on a single streaming multiprocessor (SM), and the SMEM its threads share is part of that SM.

The blockDim variable is a dim3 (a 3D integer vector) specifying how many threads a block has along each dimension. Combined with gridDim (the number of blocks along each dimension of the grid), this determines the full thread layout:

CUDA thread hierarchy diagram showing blockDim and threadIdx relationships
The CUDA thread hierarchy. blockDim.x × blockDim.y gives threads per block; gridDim.x × gridDim.y gives the total block count.

One important mental model shift: the thread hierarchy is primarily a correctness tool, not a performance one. For performance, you have to think in terms of warps — groups of 32 threads that execute in lockstep on the hardware. The block/grid structure tells you who can communicate; the warp structure tells you what the hardware actually schedules.

Kernel 1: Naive Implementation

The simplest possible kernel: assign each thread one output element in C, then loop over the K dimension accumulating the dot product. Dead simple, and predictably slow. Each thread independently reads a full row of A and full column of B, with essentially no data reuse.
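For reference, the same computation as a plain sequential loop. This is a sketch, not code from the original benchmark harness: sgemm_reference and all_close are hypothetical helpers, storage is assumed row-major like the kernels, and the tolerance in all_close is an arbitrary choice, since GPU kernels accumulate in a different order and will not match bit-for-bit. It is the kind of reference the autotuning step later validates against.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: C = alpha * A * B + beta * C, all row-major.
// A is M x K, B is K x N, C is M x N.
void sgemm_reference(int M, int N, int K, float alpha, const float *A,
                     const float *B, float beta, float *C) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      float tmp = 0.0f;
      for (int i = 0; i < K; ++i) {
        tmp += A[row * K + i] * B[i * N + col];  // dot(row of A, col of B)
      }
      C[row * N + col] = alpha * tmp + beta * C[row * N + col];
    }
  }
}

// Element-wise comparison with a mixed absolute/relative tolerance.
bool all_close(const std::vector<float> &a, const std::vector<float> &b,
               float tol = 1e-3f) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > tol * (1.0f + std::fabs(b[i]))) return false;
  }
  return true;
}
```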

Naive kernel visualization showing thread layout and computation assignments
Each thread computes exactly one output entry of C.
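__global__ void sgemm_naive(int M, int N, int K, float alpha, const float *A,
                            const float *B, float beta, float *C) {
  // Global index of this thread: x selects the row of C, y the column.
  const uint x = blockIdx.x * blockDim.x + threadIdx.x;
  const uint y = blockIdx.y * blockDim.y + threadIdx.y;

  // Threads in partially filled edge blocks fall outside C and do nothing.
  if (x < M && y < N) {
    float tmp = 0.0f;
    // Dot product of row x of A (M x K) with column y of B (K x N).
    for (int i = 0; i < K; ++i) {
      tmp += A[x * K + i] * B[i * N + y];
    }
    // C = alpha * (A @ B) + beta * C
    C[x * N + y] = alpha * tmp + beta * C[x * N + y];
  }
}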

CUDA kernels are written from a single-thread perspective: you write what one thread does, and the runtime runs that logic on every thread in the grid. The blockIdx and threadIdx builtins tell each thread where it sits, and the usual pattern for a global index is blockIdx * blockDim + threadIdx.
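To see that this index pattern hands each output element to exactly one thread, the launch can be emulated on the host. A sketch in plain C++, assuming a hypothetical Dim2 struct as a stand-in for CUDA's dim3 and the naive kernel's convention that x indexes rows of C and y indexes columns:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for CUDA's dim3 (only x and y are used here).
struct Dim2 { unsigned x, y; };

// Emulate the naive kernel's indexing: every (block, thread) pair computes
// x = blockIdx.x * blockDim.x + threadIdx.x (and likewise y), and only
// threads passing the bounds check touch C. Returns how many threads
// claimed each element of the M x N output.
std::vector<int> count_writes(unsigned M, unsigned N, Dim2 grid, Dim2 block) {
  std::vector<int> hits(M * N, 0);
  for (unsigned bx = 0; bx < grid.x; ++bx)
    for (unsigned by = 0; by < grid.y; ++by)
      for (unsigned tx = 0; tx < block.x; ++tx)
        for (unsigned ty = 0; ty < block.y; ++ty) {
          unsigned x = bx * block.x + tx;  // row of C
          unsigned y = by * block.y + ty;  // column of C
          if (x < M && y < N) ++hits[x * N + y];
        }
  return hits;
}
```

With a 70 × 45 output and 32 × 32 blocks, a 3 × 2 grid gives every element exactly one owner; a 2 × 2 grid leaves the last six rows uncovered.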

One subtlety worth flagging: if the matrix dimension isn't cleanly divisible by the block size (BLOCKSIZE in later kernels), you need to launch extra blocks to cover the remainder, and some threads in those blocks have no output element to compute (tile quantization). The overhead is small for large matrices but matters at small sizes, which is likely part of why cuBLAS switches kernels depending on matrix dimensions.
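Sizing the grid for the remainder is a ceiling division. A small host-side sketch with hypothetical helper names, assuming square BLOCKSIZE × BLOCKSIZE blocks, that also counts how many launched threads fail the bounds check:

```cpp
#include <cassert>

// Round-up integer division: enough blocks to cover a size that is not
// a multiple of the block size.
constexpr unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

// Launched threads that fall outside an M x N output (idle threads).
constexpr unsigned long long idle_threads(unsigned M, unsigned N,
                                          unsigned BLOCKSIZE) {
  unsigned long long launched = 1ULL * ceil_div(M, BLOCKSIZE) * BLOCKSIZE *
                                ceil_div(N, BLOCKSIZE) * BLOCKSIZE;
  return launched - 1ULL * M * N;
}
```

For the 4092² matrices used below, that is a 128 × 128 grid of 32 × 32 blocks and 32,752 idle threads out of about 16.8 million, which is negligible. For a 33 × 33 output, the same block size leaves 3,007 of 4,096 launched threads idle.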

Tile quantization illustration showing partial block utilization
Tile quantization: extra blocks are launched for the remainder, and not all of their threads are active.

Napkin Math: How Fast Can This Be?

Before profiling, let's bound the problem. For two 4092² matrices, the GEMM requires 2×4092³ ≈ 137 GFLOP. The minimum GMEM transfer is 268MB: reading A, B and C and writing C, i.e. 4 × 4092² × 4B. On an A6000 with 30 TFLOPs/s compute and 768 GB/s bandwidth, compute takes ~4.5ms and memory takes ~0.35ms. Compute time is ~13x the memory time, so the kernel should be compute-bound once we stop wasting memory bandwidth. For comparison, cuBLAS itself moves about 500MB of GMEM traffic during the computation (roughly twice the theoretical minimum), which gives a realistic target.
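The same arithmetic, written out so it can be rerun for other sizes or GPUs. The struct and function names are hypothetical, the peak numbers default to the A6000 figures quoted above, and the O(n²) epilogue FLOPs are ignored:

```cpp
#include <cassert>
#include <cmath>

// Back-of-envelope bounds for C = alpha*A*B + beta*C with square n x n fp32.
struct Bounds {
  double flops;       // 2*n^3 (one multiply and one add per term)
  double bytes;       // read A, B, C once and write C once
  double compute_ms;  // time at peak FLOP/s
  double memory_ms;   // time at peak DRAM bandwidth
};

Bounds napkin(double n, double peak_flops = 30e12, double peak_bw = 768e9) {
  Bounds b;
  b.flops = 2.0 * n * n * n;
  b.bytes = 4.0 * n * n * 4.0;  // 4 matrix transfers x n^2 floats x 4 bytes
  b.compute_ms = b.flops / peak_flops * 1e3;
  b.memory_ms = b.bytes / peak_bw * 1e3;
  return b;
}
```

For n = 4092 this gives about 137 GFLOP, 268MB, 4.57ms of compute and 0.35ms of memory time.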

I find this ratio really useful to keep in mind as we go through each optimization. The question to ask at each step is: are we closer to being compute-bound, or are we still hemorrhaging bandwidth? The roofline model makes this explicit.

Why the Naive Kernel is Terrible

Consider two threads in the same block at threadIdx (0,0) and (1,0): they load the same column of B but different rows of A. Without any caching, each thread loads 2×4092 floats, so the 4092² threads together generate about 548GB of memory traffic for a problem whose minimum is 268MB, an overshoot of roughly 2000×. The L1 and L2 caches absorb part of that traffic in practice, but the kernel is still nowhere near the reuse the problem allows.
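The worst-case figure is easy to reproduce. A sketch that counts bytes per thread, assuming no cache hits at all and, unlike the 548GB figure, also counting the read and write of C (which barely changes the total); the function name is hypothetical:

```cpp
#include <cassert>

// Worst-case (no cache hits) global-memory traffic of the naive kernel:
// each of the M*N threads reads K floats of A and K floats of B, then
// reads and writes one float of C.
double naive_traffic_bytes(double M, double N, double K) {
  double per_thread = (2.0 * K + 2.0) * 4.0;  // bytes moved by one thread
  return M * N * per_thread;
}
```

For M = N = K = 4092 this comes to about 548GB, roughly 2,000 times the 268MB minimum.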

Two threads' data access patterns from matrices A and B
Memory access pattern of the naive kernel for two example threads (red and green).

Result: ~300 GFLOPs on an A6000. For context, that's roughly what a well-tuned 2015 Haswell CPU achieves. A GPU with an order of magnitude more memory bandwidth is performing at CPU level because we're completely ignoring its access pattern requirements.

Kernel 2: Global Memory Coalescing

The key GPU concept here is the warp. Threads within a block are grouped into warps of 32, and a warp is the actual unit of execution on the hardware — all 32 threads in a warp execute the same instruction simultaneously (SIMT). The warp scheduler is what actually dispatches instructions to the CUDA cores.

Warps are formed from consecutive linear thread IDs within a block: threadId = threadIdx.x + blockDim.x*(threadIdx.y + blockDim.y*threadIdx.z). The x dimension varies fastest, then y, then z, so threads with adjacent threadIdx.x values (and the same y and z) end up in the same warp.
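The formula translates directly into code. A host-side sketch with hypothetical function names, assuming the standard warp size of 32:

```cpp
#include <cassert>

// Linear thread id inside a block: x varies fastest, then y, then z.
constexpr unsigned linear_tid(unsigned tx, unsigned ty, unsigned tz,
                              unsigned bdx, unsigned bdy) {
  return tx + bdx * (ty + bdy * tz);
}

// Warps are cut from consecutive linear ids, 32 at a time.
constexpr unsigned warp_id(unsigned tx, unsigned ty, unsigned tz,
                           unsigned bdx, unsigned bdy) {
  return linear_tid(tx, ty, tz, bdx, bdy) / 32;
}
```

With blockDim.x = 32, each row of y is its own warp. With blockDim.x = 16, one warp spans two rows of y, which matters when reasoning about whose accesses get coalesced.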

Illustration of how threadIds map to warps
Thread-to-warp mapping illustrated using an 8-thread warp example (real warps have 32 threads).

Global memory coalescing is the single most important GMEM optimization on GPU. When threads in the same warp issue memory requests to consecutive addresses, the hardware can combine them into one transaction. 32 threads × 4 bytes = 128 bytes, which fits exactly in one L2 cache line — perfect coalescing means 1 transaction per warp instead of 32.

Consecutive memory accesses grouped into single transactions
Global memory coalescing groups consecutive warp accesses into fewer, larger transactions.

An interesting nuance: threads within a warp don't have to access memory in threadIdx order for coalescing to work. What matters is the union of the addresses: if the warp collectively touches one consecutive, aligned 128B region, the hardware can serve it with a single transaction, regardless of which lane asks for which word.
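To make the "union of addresses" idea concrete, here is a host-side sketch that counts how many distinct aligned 128-byte segments a warp's 32 byte addresses touch. It is a simplification (real hardware tracks 32-byte sectors and has further rules), so treat the count as an approximation of the number of transactions; the function name is hypothetical:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

// Distinct, aligned 128-byte segments touched by one warp's byte addresses.
// Fewer segments means better coalescing; lane order does not matter.
std::size_t segments_touched(const std::array<std::uint64_t, 32> &addrs) {
  std::set<std::uint64_t> segments;
  for (std::uint64_t a : addrs) segments.insert(a / 128);
  return segments.size();
}
```

Consecutive floats touch one segment, and so does any shuffle of the same 32 floats across lanes. Shifting the same run by 64 bytes breaks alignment and touches two, and one float per row of a 4092-wide matrix touches 32.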

Non-consecutive within-warp accesses that still coalesce
Non-consecutive within-warp accesses can still coalesce as long as they target consecutive addresses.

In the naive kernel we mapped threadIdx.x to the row of A. Threads with consecutive threadIdx.x (i.e., in the same warp) therefore load from consecutive rows of A, and since A is row-major, those addresses are K floats apart. At every step of the dot product, the warp's 32 loads from A land in 32 different cache lines instead of one. No coalescing at all.

Non-consecutive row loading from matrix A in naive kernel
The naive kernel accesses A non-consecutively, preventing coalescing.

The fix: swap how we assign output elements to threads. Instead of threadIdx.x → row, use threadIdx.x → column. Threads in the same warp now compute the same row of C but consecutive columns — meaning they load the same row of A (broadcast-friendly) and consecutive columns of B (coalesced).
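The effect of the remapping can be checked by counting segments at one step i of the dot product for warp 0 of block (0,0). A host-side sketch with hypothetical helpers, assuming row-major A and B, BLOCKSIZE = 32, a 32 × 32 2D block for the naive kernel and a 1D block of 1024 threads for the remapped one; N = K = 4096 keeps every row 128-byte aligned:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

constexpr unsigned BLOCKSIZE = 32;

// Distinct 128-byte segments touched when lane l loads float float_index(l).
template <typename F>
std::size_t segments(F float_index) {
  std::set<std::uint64_t> segs;
  for (unsigned lane = 0; lane < 32; ++lane)
    segs.insert(4ull * float_index(lane) / 128);
  return segs.size();
}

// Naive kernel: lane = threadIdx.x, so x (row) = lane and y (column) = 0.
std::size_t naive_A(unsigned K, unsigned i) {
  return segments([=](unsigned lane) { return 1ull * lane * K + i; });
}
std::size_t naive_B(unsigned N, unsigned i) {
  return segments([=](unsigned) { return 1ull * i * N; });
}
// Remapped kernel: x = lane / BLOCKSIZE (= 0), y = lane % BLOCKSIZE.
std::size_t coalesced_A(unsigned K, unsigned i) {
  return segments([=](unsigned lane) { return 1ull * (lane / BLOCKSIZE) * K + i; });
}
std::size_t coalesced_B(unsigned N, unsigned i) {
  return segments([=](unsigned lane) { return 1ull * i * N + lane % BLOCKSIZE; });
}
```

The naive mapping touches 32 segments of A per step (and one of B, a broadcast); the remapped one touches a single segment of each.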

Reorganized thread-to-result mapping for coalescing
Reorganizing the thread-to-result mapping enables coalesced global memory access.

To implement this, we only need to change the two index computations and launch each block as a 1D block of BLOCKSIZE × BLOCKSIZE threads, so that threadIdx.x covers the whole tile:

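template <const uint BLOCKSIZE>
__global__ void sgemm_global_mem_coalesce(int M, int N, int K, float alpha,
                                          const float *A, const float *B,
                                          float beta, float *C) {
  // Within a warp, x is fixed and y runs over consecutive columns.
  const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
  const int y = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);
  if (x < M && y < N) {
    float tmp = 0.0f;
    for (int i = 0; i < K; ++i) {
      // A[x * K + i]: same address across the warp (broadcast).
      // B[i * N + y]: consecutive addresses across the warp (coalesced).
      tmp += A[x * K + i] * B[i * N + y];
    }
    C[x * N + y] = alpha * tmp + beta * C[x * N + y];
  }
}
// Host launch with 1D blocks of 32 * 32 threads:
//   sgemm_global_mem_coalesce<32><<<dim3((M + 31) / 32, (N + 31) / 32), 32 * 32>>>(
//       M, N, K, alpha, A, B, beta, C);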

What I found surprising: enabling coalescing leaves the memory instructions unchanged. Apart from the slightly different index arithmetic, the PTX and SASS issue the same LDG loads. That makes sense: coalescing is decided at runtime by the memory system, based on the addresses the 32 threads of a warp actually present, so there is no special "coalesced load" instruction for the compiler to emit.

The payoff is huge: memory throughput jumps from 15 GB/s to 110 GB/s, and FLOP/s goes from ~300 to ~2000 GFLOPS. We haven't changed any math, any shared memory usage, or any arithmetic — just the index assignment. It's a pure access pattern win.

Kernel 3: Shared Memory Cache-Blocking

The GPU memory hierarchy has a crucial middle tier: shared memory (SMEM). It sits on-chip, physically next to the CUDA cores, and is partitioned among the blocks resident on an SM; every thread in a block can read and write that block's SMEM region. On Volta-era hardware, SMEM bandwidth has been measured at ~12 TB/s vs. ~750 GB/s for DRAM, roughly a 16× difference. On an A6000, a block can statically allocate up to 48KB of SMEM.

GPU memory hierarchy showing cache structure for A100
The GPU memory hierarchy. Shared memory is on-chip and offers far higher bandwidth than DRAM.

The cache-blocking idea: instead of every thread independently fetching from GMEM, a group of threads cooperatively loads a tile of A and a tile of B into SMEM. Then everyone computes on the fast local copy. We slide the tile along the K dimension, accumulating partial sums. Each float staged in SMEM is read by BLOCKSIZE different threads, so the number of GMEM loads per FMA drops by roughly a factor of BLOCKSIZE.

Cache-blocking visualization showing chunk-based loading and computation
Cache-blocking: load tiles of A and B into shared memory, compute partial sums, then advance.
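template <const int BLOCKSIZE>
__global__ void sgemm_shared_mem_block(int M, int N, int K, float alpha,
                                       const float *A, const float *B,
                                       float beta, float *C) {
  // Output tile this block computes, and this thread's position inside it.
  const uint cRow = blockIdx.x;
  const uint cCol = blockIdx.y;
  const uint threadCol = threadIdx.x % BLOCKSIZE;
  const uint threadRow = threadIdx.x / BLOCKSIZE;

  // One BLOCKSIZE x BLOCKSIZE tile of A and of B, shared by the block.
  __shared__ float As[BLOCKSIZE * BLOCKSIZE];
  __shared__ float Bs[BLOCKSIZE * BLOCKSIZE];

  // Move the pointers to this block's rows of A, columns of B, tile of C.
  A += cRow * BLOCKSIZE * K;
  B += cCol * BLOCKSIZE;
  C += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE;

  // Launched with 1D blocks of BLOCKSIZE * BLOCKSIZE threads on a
  // (M / BLOCKSIZE, N / BLOCKSIZE) grid; assumes M, N, K are multiples
  // of BLOCKSIZE (no bounds checks).
  float tmp = 0.0f;
  for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) {
    // Each thread loads one element of each tile (coalesced along threadCol).
    As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol];
    Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol];
    __syncthreads();  // tiles fully loaded before anyone reads them

    A += BLOCKSIZE;
    B += BLOCKSIZE * N;

    for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) {
      tmp += As[threadRow * BLOCKSIZE + dotIdx] *
             Bs[dotIdx * BLOCKSIZE + threadCol];
    }
    __syncthreads();  // everyone done reading before the next load overwrites
  }
  C[threadRow * N + threadCol] =
      alpha * tmp + beta * C[threadRow * N + threadCol];
}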
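The same tiling structure can be modeled on the CPU, which is a cheap way to check the indexing before debugging it on a GPU. The sketch below (hypothetical name sgemm_tiled_model) walks the (cRow, cCol) block tiles and the bkIdx loop like Kernel 3, with the per-thread work written as explicit loops and local arrays standing in for As and Bs. It runs sequentially, so the two __syncthreads() calls simply mark the boundaries between the load and compute phases. Like the kernel, it assumes M, N and K are multiples of BLOCKSIZE.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// CPU model of Kernel 3 (row-major A: M x K, B: K x N, C: M x N).
template <int BLOCKSIZE>
void sgemm_tiled_model(int M, int N, int K, float alpha, const float *A,
                       const float *B, float beta, float *C) {
  std::vector<float> As(BLOCKSIZE * BLOCKSIZE), Bs(BLOCKSIZE * BLOCKSIZE);
  std::vector<float> tmp(BLOCKSIZE * BLOCKSIZE);  // one accumulator per thread
  for (int cRow = 0; cRow < M / BLOCKSIZE; ++cRow) {
    for (int cCol = 0; cCol < N / BLOCKSIZE; ++cCol) {
      const float *a = A + cRow * BLOCKSIZE * K;
      const float *b = B + cCol * BLOCKSIZE;
      std::fill(tmp.begin(), tmp.end(), 0.0f);
      for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) {
        // Load phase: "thread" (r, c) copies one element of each tile.
        for (int r = 0; r < BLOCKSIZE; ++r)
          for (int c = 0; c < BLOCKSIZE; ++c) {
            As[r * BLOCKSIZE + c] = a[r * K + c];
            Bs[r * BLOCKSIZE + c] = b[r * N + c];
          }
        // (first __syncthreads() here)
        a += BLOCKSIZE;
        b += BLOCKSIZE * N;
        // Compute phase: each thread uses a row of As and a column of Bs.
        for (int r = 0; r < BLOCKSIZE; ++r)
          for (int c = 0; c < BLOCKSIZE; ++c)
            for (int d = 0; d < BLOCKSIZE; ++d)
              tmp[r * BLOCKSIZE + c] += As[r * BLOCKSIZE + d] * Bs[d * BLOCKSIZE + c];
        // (second __syncthreads() here)
      }
      float *cTile = C + cRow * BLOCKSIZE * N + cCol * BLOCKSIZE;
      for (int r = 0; r < BLOCKSIZE; ++r)
        for (int c = 0; c < BLOCKSIZE; ++c)
          cTile[r * N + c] = alpha * tmp[r * BLOCKSIZE + c] + beta * cTile[r * N + c];
    }
  }
}
```

Counting loads in this model: per tile step, each thread issues 2 global loads and then BLOCKSIZE FMAs, where Kernels 1 and 2 needed 2 global loads for every FMA.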

The result is ~2200 GFLOPS, only about 10% better than Kernel 2. That might seem small given the effort. One reason: Kernel 2 already got decent L1 hit rates thanks to its access pattern, so the explicit SMEM tiling doesn't buy as much as you'd hope. More importantly, the roofline reveals the real problem.

Roofline analysis showing performance vs arithmetic intensity for Kernel 3
Roofline plot for Kernel 3. We're far from the compute roofline — arithmetic intensity is the bottleneck.

The Roofline and the Arithmetic Intensity Problem

The roofline model plots achieved FLOPs/s against arithmetic intensity (FLOPs per byte of memory traffic). It has two ceilings: a horizontal line at peak compute (30 TFLOPs/s) and a diagonal line at peak memory bandwidth times intensity (768 GB/s × FLOPs/byte). They meet at the ridge point, about 39 FLOPs/byte here. A kernel whose intensity lies left of the ridge point is capped by the diagonal and is memory-bound; to the right, it is capped by the horizontal line and is compute-bound. Kernel 3 sits far below both ceilings: its arithmetic intensity is too low to be compute-bound, yet it isn't saturating DRAM bandwidth either. It's stalling on SMEM.
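The two ceilings reduce to a one-line formula. A sketch with hypothetical function names, using the A6000 peak figures above as defaults; the ridge point is the intensity at which the two ceilings meet:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Roofline: attainable FLOP/s is the lower of the compute ceiling and
// bandwidth times arithmetic intensity (FLOPs per byte of DRAM traffic).
double attainable_flops(double intensity, double peak_flops = 30e12,
                        double peak_bw = 768e9) {
  return std::min(peak_flops, peak_bw * intensity);
}

// Intensity where the two ceilings meet; below it, a kernel is memory-bound.
double ridge_point(double peak_flops = 30e12, double peak_bw = 768e9) {
  return peak_flops / peak_bw;
}
```

With these numbers the ridge point is about 39 FLOPs/byte. The full problem at its 268MB minimum (137 GFLOP / 268MB, roughly 500 FLOPs/byte) sits far to the right, which is why it should end up compute-bound.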

Occupancy at BLOCKSIZE=32 is ~66%, limited by thread count rather than SMEM or registers: a block has 32×32 = 1024 threads, an A6000 SM holds at most 1536 resident threads, so only one block fits per SM and 1024/1536 ≈ 66% of the thread slots are used. That's not terrible. But the profiler tells the real story: the instruction mix is dominated by LDS (shared memory loads) rather than FMA. We're spending cycles fetching from SMEM, not doing math. The fix: each thread needs to compute more output elements per SMEM access, reducing the ratio of loads to FMAs by doing more work in registers.
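The thread-limited part of the occupancy calculation, as a sketch. It assumes the compute capability 8.6 limits of 1536 resident threads and 16 resident blocks per SM, and ignores register and SMEM limits, which the text says are not the binding ones here; the function name is hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Occupancy if only the resident-thread and resident-block limits matter.
// Defaults are assumed CC 8.6 (A6000) limits: 1536 threads, 16 blocks per SM.
double thread_limited_occupancy(int threads_per_block,
                                int max_threads_per_sm = 1536,
                                int max_blocks_per_sm = 16) {
  int blocks = std::min(max_threads_per_sm / threads_per_block, max_blocks_per_sm);
  return double(blocks * threads_per_block) / max_threads_per_sm;
}
```

A 1024-thread block gives 2/3. Halving the block to 512 threads would, by this limit alone, allow three resident blocks and full thread occupancy; whether that helps depends on the other limits.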

Kernel 4: 1D Blocktiling for Calculating Multiple Results per Thread

Instead of each thread computing one output element, we assign it a column of TM output elements. Each thread keeps an array of TM partial sums in registers across the K-dimension loop. The critical step is the loop order: make dotIdx (the position within the BK tile) the outer of the two inner loops, and resIdx (which of the TM outputs) the innermost loop. That way, for each dotIdx we load one value of Bs into a register (Btmp) and reuse it across all TM multiply-accumulates: one Bs load now feeds TM FMAs, where Kernel 3 needed a fresh Bs load for every FMA.
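Counting SMEM loads makes the payoff concrete. A sketch of the per-result count under the loop order described above (the helper name is hypothetical); it counts load instructions per thread and ignores broadcast effects:

```cpp
#include <cassert>

// SMEM loads per output element over the whole K loop, for a thread that
// computes TM results down a column. Each dotIdx step loads one Bs value
// (Btmp) plus TM As values and performs TM FMAs.
// TM = 1 recovers Kernel 3: two SMEM loads per FMA.
double smem_loads_per_result_1d(double K, double TM) {
  double loads_per_dot = 1.0 + TM;  // Btmp + TM values of As
  return K * loads_per_dot / TM;    // shared across the TM results
}
```

With TM = 8, SMEM loads per result drop from 2K to 1.125K.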

1D blocktiling showing multiple results per thread computation pattern
1D blocktiling: each thread computes a column of TM output elements, improving register reuse.
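// Fragment of the Kernel 4 main loop. Assumes the full kernel has declared
// As (BM x BK) and Bs (BK x BN) in shared memory, advanced A, B and C to this
// block's tile as in Kernel 3, and computed threadRow/threadCol and the load
// indices innerRowA/innerColA and innerRowB/innerColB, with one thread per
// tile element (BM * BK and BK * BN both equal the thread count).
float threadResults[TM] = {0.0f};  // TM partial sums, kept in registers

for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
  // Cooperative, coalesced load of one tile of A and one tile of B.
  As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
  Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB];
  __syncthreads();

  A += BK;
  B += BK * N;

  for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
    // One Bs value, reused for all TM results of this thread.
    float Btmp = Bs[dotIdx * BN + threadCol];
    for (uint resIdx = 0; resIdx < TM; ++resIdx) {
      threadResults[resIdx] +=
          As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
    }
  }
  __syncthreads();
}
// Afterwards, threadResults[resIdx] goes to row threadRow * TM + resIdx,
// column threadCol of the block's C tile (scaled by alpha, plus beta * C).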

What the Compiler Does for You

With TM known at compile time (it's a template parameter), the compiler can fully unroll the inner resIdx loop and keep the threadResults array entirely in registers. The explicit Btmp mostly documents intent: Bs[dotIdx * BN + threadCol] doesn't change across resIdx iterations, so the compiler would typically hoist that load into a register even if we wrote it inline. This is why templating the tile sizes matters: if TM were a runtime value, threadResults could not live in a fixed set of registers and would likely be placed in much slower local memory.

1D warp tiling benefits showing shared input advantage
The benefit of 1D blocktiling: each Bs value loaded into a register is shared by all TM results of a thread, amortizing SMEM traffic.

Result: ~8,600 GFLOPs, a 4× jump from Kernel 3. The profiler confirms the story — MIO stall cycles (shared memory contention) drop dramatically. We've shifted the bottleneck from 'waiting for SMEM' to actually doing FMAs. But we're still only doing 1D output tiling; we can push further.

Kernel 5: Increasing Arithmetic Intensity via 2D Blocktiling

Extend the idea to 2D: each thread now computes a TM×TN output tile. The arithmetic intensity improvement is multiplicative — instead of reusing a Bs value across TM rows, we now reuse a regM value across TN columns and a regN value across TM rows. The inner loop becomes an outer product: load TM values from As into regM, load TN values from Bs into regN, then compute the full TM×TN outer product and accumulate into threadResults.
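A single dotIdx step of that outer product, isolated as a host function so the load and FMA counts are easy to see. This is a sketch: outer_product_step is a hypothetical name, and AsCol and BsRow are hypothetical pointers to the TM values of As and TN values of Bs that the thread needs at this dotIdx.

```cpp
#include <cassert>

// One dotIdx step of a TM x TN register tile: TM + TN SMEM loads feed
// TM * TN FMAs. Returns the number of loads so the ratio can be checked.
template <int TM, int TN>
int outer_product_step(const float *AsCol, const float *BsRow,
                       float (&threadResults)[TM * TN]) {
  float regM[TM], regN[TN];
  int loads = 0;
  for (int i = 0; i < TM; ++i) { regM[i] = AsCol[i]; ++loads; }
  for (int i = 0; i < TN; ++i) { regN[i] = BsRow[i]; ++loads; }
  for (int m = 0; m < TM; ++m)
    for (int n = 0; n < TN; ++n)
      threadResults[m * TN + n] += regM[m] * regN[n];  // outer product
  return loads;
}
```

With TM = TN = 8 that is 16 SMEM loads for 64 FMAs, or 0.25 loads per FMA, compared with (TM + 1)/TM loads per FMA for the 1D version.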

Arithmetic intensity explanation showing compute-to-bandwidth ratio improvement
2D tiling raises arithmetic intensity by reusing loaded values across both row and column dimensions.
2D blocktiling diagram with three loop levels
2D blocktiling: each thread computes an 8×8 result tile using three nested loops.
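// Fragment of the Kernel 5 main loop. Assumes the full kernel has declared
// As (BM x BK) and Bs (BK x BN) in shared memory, advanced A, B and C to this
// block's tile, and computed threadRow/threadCol, the load indices
// innerRowA/innerColA and innerRowB/innerColB, and the load strides
// strideA = numThreads / BK and strideB = numThreads / BN. There are fewer
// threads than tile elements, so each thread loads several rows of each tile.
float threadResults[TM * TN] = {0.0f};  // TM x TN accumulators in registers
float regM[TM] = {0.0f};                // cached column slice of As
float regN[TN] = {0.0f};                // cached row slice of Bs

for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
  // GMEM -> SMEM: each thread copies BM / strideA elements of A
  // and BK / strideB elements of B.
  for (uint loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
    As[(innerRowA + loadOffset) * BK + innerColA] =
        A[(innerRowA + loadOffset) * K + innerColA];
  }
  for (uint loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
    Bs[(innerRowB + loadOffset) * BN + innerColB] =
        B[(innerRowB + loadOffset) * N + innerColB];
  }
  __syncthreads();

  A += BK;
  B += BK * N;

  for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
    // TM + TN SMEM loads into registers...
    for (uint i = 0; i < TM; ++i) {
      regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
    }
    for (uint i = 0; i < TN; ++i) {
      regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
    }
    // ...feed TM * TN FMAs (the outer product of regM and regN).
    for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
      for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
        threadResults[resIdxM * TN + resIdxN] +=
            regM[resIdxM] * regN[resIdxN];
      }
    }
  }
  __syncthreads();
}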
Multiple-element loading per thread into SMEM
GMEM loading pattern for Kernel 5: each thread loads multiple elements into shared memory.
Register blocking showing dotIdx loop across time
Register blocking: regM and regN cache SMEM values in registers, then compute outer products.

Result: ~16 TFLOPs, another 2× improvement. The GMEM accesses per output element drop to K/64, and SMEM accesses to K/4. We're now much closer to the compute roofline, at a bit over half of the 30 TFLOPs/s peak. The remaining gap to cuBLAS is about vectorized loads (memory efficiency) and warp-level data locality, which Kernels 6 and 10 address.
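The K/64 and K/4 figures can be reproduced by counting loads per thread and dividing by the TM·TN results each thread owns. A sketch with hypothetical names, assuming the tile loads are split evenly across threads and using the configuration BM = BN = 128, BK = 8, TM = TN = 8 (the initial guess mentioned in the autotuning section):

```cpp
#include <cassert>

struct Loads { double gmem, smem; };

// Loads per output element for Kernel 5-style 2D blocktiling.
Loads loads_per_result(double K, double BM, double BN, double BK,
                       double TM, double TN) {
  double threads = (BM * BN) / (TM * TN);
  double results_per_thread = TM * TN;
  double outer_iters = K / BK;
  // GMEM: every outer iteration the block loads BM*BK + BK*BN floats.
  double gmem_per_thread = outer_iters * (BM * BK + BK * BN) / threads;
  // SMEM: every dotIdx step loads TM values of As and TN values of Bs.
  double smem_per_thread = K * (TM + TN);
  return {gmem_per_thread / results_per_thread,
          smem_per_thread / results_per_thread};
}
```

For K = 4096 this gives 64 GMEM loads (K/64) and 1024 SMEM loads (K/4) per result.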

Kernel 6: Vectorize SMEM and GMEM Accesses

Two tricks in one kernel. First, vectorized GMEM loads: use float4 to load 128 bits (4 floats) per instruction, turning LDG.E.32 into LDG.E.128. You have to explicitly promise the compiler the pointer is 128-bit aligned via reinterpret_cast, since it can't infer that from a generic float* argument; for the promise to hold, K and N must be multiples of 4 and the base pointers 16-byte aligned. Second, transpose As during the SMEM load so that the subsequent SMEM reads along the column dimension become sequential (enabling LDS.128 instead of strided scalar loads).

Memory layout changes enabling vectorized SMEM loads for As
Transposing As during the GMEM→SMEM transfer enables vectorized 128-bit SMEM loads.
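// Load 4 consecutive floats of A with one 128-bit LDG (needs 16-byte alignment).
float4 tmp =
    reinterpret_cast<const float4 *>(&A[innerRowA * K + innerColA * 4])[0];
// Scatter them into As transposed: As is stored BK x BM (column-major tile).
As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w;

// Bs keeps B's layout, so one 128-bit load becomes one 128-bit store.
reinterpret_cast<float4 *>(&Bs[innerRowB * BN + innerColB * 4])[0] =
    reinterpret_cast<const float4 *>(&B[innerRowB * N + innerColB * 4])[0];
__syncthreads();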

The As transpose is the subtle part. In Kernel 5, the TM values a thread reads from As for a given dotIdx are BK floats apart, because As is stored row-major. Storing As transposed (column-major) during the loading phase makes those TM values adjacent, so the compiler can fetch them with vectorized LDS.128 instead of separate scalar loads. The extra work during loading is cheap compared to the savings during compute.
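A host-side check of why the transpose helps, using the Kernel 5 and 6 indexing. The values BM = 128, BK = 8, TM = 8 and the helper names are assumptions for the sketch: for a fixed dotIdx, consecutive regM elements are BK floats apart in the row-major layout but adjacent in the transposed one, and they start on a 16-byte boundary, which is what a 128-bit load requires.

```cpp
#include <cassert>

constexpr unsigned BM = 128, BK = 8, TM = 8;

// Kernel 5 layout: As is BM x BK, row-major.
constexpr unsigned idx_rowmajor(unsigned m, unsigned k) { return m * BK + k; }
// Kernel 6 layout: As is stored transposed, BK x BM.
constexpr unsigned idx_transposed(unsigned m, unsigned k) { return k * BM + m; }

// Distance in floats between regM[0] and regM[1] for one thread.
constexpr unsigned regM_stride(bool transposed, unsigned threadRow,
                               unsigned dotIdx) {
  unsigned m0 = threadRow * TM, m1 = threadRow * TM + 1;
  return transposed ? idx_transposed(m1, dotIdx) - idx_transposed(m0, dotIdx)
                    : idx_rowmajor(m1, dotIdx) - idx_rowmajor(m0, dotIdx);
}

// A float4 access needs its first element on a 16-byte boundary.
constexpr bool float4_aligned(unsigned float_index) {
  return (float_index * 4) % 16 == 0;
}
```

The same alignment rule is why the GMEM float4 loads need K and N to be multiples of 4: otherwise A[innerRowA * K + innerColA * 4] is not guaranteed to land on a 16-byte boundary.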

Performance: ~18,200 GFLOPs, roughly 2 TFLOPs/s (about 14%) over Kernel 5. Meaningful but not dramatic; the low-hanging fruit is mostly picked. The remaining gap to cuBLAS at this point comes from SMEM bank conflicts, the lack of double buffering (the GPU stalls waiting for each GMEM-to-SMEM tile load to finish instead of overlapping it with compute), and the warp-level register locality that warptiling exploits.

Kernel 9: Autotuning

By this point, the kernel has accumulated five template parameters: BM and BN for the shared memory tile dimensions, BK for the K-dimension tile, and TM and TN for the per-thread register tile. The initial guess — BM=BN=128, BK=8, TM=TN=8 — is reasonable but almost certainly not optimal. Autotuning is just a grid search with validation: write a bash script, sweep sensible combinations, benchmark each, pick the winner.

The tricky part is keeping the search space honest. Not all combinations are valid — vectorized SMEM loads require BM*BK to be divisible by 4*NUM_THREADS, for instance. Out of ~400 configurations, maybe 200 compile and produce correct results. Validating each against a reference prevents accepting fast-but-wrong kernels.
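A sketch of the filtering step. Config and is_valid are hypothetical; the float4 divisibility rule quoted above is applied to both the A and the B tile, and the 48 KB SMEM budget per block and the 32 to 1024 thread range are assumptions. The real scripts may check more, and any surviving configuration still has to be validated numerically against a reference.

```cpp
#include <cassert>

struct Config { int BM, BN, BK, TM, TN; };

// Reject configurations the vectorized kernel cannot run correctly.
bool is_valid(const Config &c, int max_smem_bytes = 48 * 1024) {
  if (c.BM % c.TM != 0 || c.BN % c.TN != 0) return false;
  int threads = (c.BM * c.BN) / (c.TM * c.TN);
  if (threads < 32 || threads > 1024) return false;
  // float4 loads: each thread moves a whole number of 4-float chunks
  // of both tiles per pass.
  if ((c.BM * c.BK) % (4 * threads) != 0) return false;
  if ((c.BN * c.BK) % (4 * threads) != 0) return false;
  int smem_bytes = (c.BM * c.BK + c.BK * c.BN) * 4;  // As + Bs
  return smem_bytes <= max_smem_bytes;
}
```

Both the initial guess (128, 128, 8, 8, 8) and the A6000 winner (128, 128, 16, 8, 8) pass; (64, 64, 8, 4, 4) fails the float4 rule, and (256, 256, 64, 8, 8) exceeds the SMEM budget.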

On the A6000, the winner was BM=BN=128, BK=16, TM=TN=8: only the K-tile changed. That tweak alone pushed throughput from ~19 to ~20 TFLOPs, a ~5% gain. What's humbling is that we can't fully explain *why* BK=16 beats BK=8 on this GPU. A larger BK halves the number of load-and-synchronize phases per block (the total GMEM traffic per block stays the same), but it doubles the SMEM footprint, may need more registers for staging loads, and can reduce occupancy. The optimal balance is hardware-specific and hard to predict analytically. This is why production libraries like cuBLAS and cuDNN ship many kernel variants and select among them at runtime, and why CUTLASS exposes tile shapes as template parameters: the hardware landscape is too fragmented for any single 'best' configuration.
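The accounting behind that trade-off fits in a few lines. A sketch for one block with fixed BM, BN and K (names hypothetical): total GMEM traffic per block does not depend on BK, while the number of load-and-sync phases and the SMEM footprint do.

```cpp
#include <cassert>

// How BK changes one block's work, for fixed BM, BN and K.
struct BkEffect {
  long gmem_floats;  // floats the block loads from GMEM over the K loop
  long outer_iters;  // GMEM->SMEM phases, each followed by __syncthreads()
  long smem_bytes;   // SMEM footprint of As + Bs
};

BkEffect bk_effect(long BM, long BN, long BK, long K) {
  BkEffect e;
  e.outer_iters = K / BK;
  e.gmem_floats = e.outer_iters * (BM * BK + BK * BN);  // = K * (BM + BN)
  e.smem_bytes = (BM * BK + BK * BN) * 4;
  return e;
}
```

Going from BK = 8 to BK = 16 at BM = BN = 128 halves the phases (512 to 256 for K = 4096) and doubles the SMEM footprint (8KB to 16KB), with identical GMEM traffic.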

Kernel 10: Warptiling

Kernel 10 introduces a third level of tiling between blocktiling and threadtiling: warptiling. This is the sneaky level that CUDA hides from you. Warps don't appear as an explicit concept in your code, but they're very real in the hardware: every 32 consecutive threads form a warp that executes in lockstep, and for the 1D blocks used here the warp ID is just threadIdx.x / 32. The hardware scheduler thinks in warps, not threads.

Block, thread, and warp tiling nesting levels
The three-level loop structure: block tiles → warp tiles → thread tiles.
Four warp schedulers per multiprocessor diagram
Each SM has four warp schedulers. Warptiling lets different warps run on different schedulers concurrently.

Warps matter for three distinct reasons. First, they're the unit of scheduling: the A6000 has four warp schedulers per SM, so up to four warps can issue instructions in the same cycle. If your block only has one warp's worth of useful work at a time, you're leaving 3/4 of the scheduler capacity idle. Second, SMEM bank conflicts are a warp-level phenomenon: when threads in a warp access different addresses that fall in the same bank, those accesses serialize. Third, there is some evidence that recent GPUs have register caches for recently used operands, and tighter thread tiling may improve locality there. Warptiling is what makes the three levels of the memory hierarchy map cleanly to the three levels of the GPU compute hierarchy: GMEM → block, SMEM → warp, registers → thread.
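Bank conflicts can be reasoned about with a small model. A sketch, assuming 32 banks of 4-byte words and 32-bit accesses only (the helper name is hypothetical): distinct words in the same bank serialize, while lanes reading the identical word are served by a broadcast.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <set>

// Serialization degree of one warp-wide 32-bit SMEM access: the largest
// number of distinct word addresses that map to a single bank.
// 1 means conflict-free; 32 means fully serialized.
int conflict_degree(const std::array<unsigned, 32> &word_addrs) {
  std::array<std::set<unsigned>, 32> per_bank;
  for (unsigned w : word_addrs) per_bank[w % 32].insert(w);
  int worst = 0;
  for (const auto &s : per_bank) worst = std::max(worst, int(s.size()));
  return worst;
}
```

Consecutive words and a single shared word are both conflict-free; a stride of 32 words puts every lane in bank 0 (32-way), while padding the stride to 33 words spreads the lanes across all banks again.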

Three-level tiling visualization for Kernel 10
Kernel 10 warptiling: each warp computes a WM × WN chunk of C, where WM = WMITER × WSUBM and WN = WNITER × WSUBN.
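// Inner computation for one BK tile of Kernel 10. Assumes As is stored
// transposed (BK x BM) as in Kernel 6, and that warpRow/warpCol (this warp's
// position in the block tile) and threadRowInWarp/threadColInWarp (this
// thread's position inside a warp sub-tile) were computed earlier.
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
  // Cache this thread's TM values of As for each of the WMITER sub-tile rows.
  for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
    for (uint i = 0; i < TM; ++i) {
      regM[wSubRowIdx * TM + i] =
          As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM +
             threadRowInWarp * TM + i];
    }
  }
  // Cache this thread's TN values of Bs for each of the WNITER sub-tile columns.
  for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
    for (uint i = 0; i < TN; ++i) {
      regN[wSubColIdx * TN + i] =
          Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN +
             threadColInWarp * TN + i];
    }
  }

  // Outer products: (WMITER * TM) x (WNITER * TN) FMAs per dotIdx.
  for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
    for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
          threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
                        (wSubColIdx * TN) + resIdxN] +=
              regM[wSubRowIdx * TM + resIdxM] *
              regN[wSubColIdx * TN + resIdxN];
        }
      }
    }
  }
}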
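The fragment relies on per-thread indices (warpRow, warpCol, threadRowInWarp, threadColInWarp) computed elsewhere in the full kernel. One consistent way to derive them from threadIdx.x is sketched below as host code, assuming a 1D block, warps laid out row-major over the block tile, WSUBM = WM / WMITER and WSUBN = WN / WNITER. The sketch checks that every element of the BM × BN block tile is owned by exactly one (thread, sub-tile, register) slot. WarpTile and covers_block_tile_once are hypothetical names, and the example parameters are assumptions rather than the autotuned values.

```cpp
#include <cassert>
#include <vector>

struct WarpTile { int BM, BN, WM, WN, WMITER, WNITER, TM, TN; };

// True if the warptiled thread -> C mapping covers the block tile exactly once.
bool covers_block_tile_once(const WarpTile &p) {
  const int WSUBM = p.WM / p.WMITER, WSUBN = p.WN / p.WNITER;
  const int threads = (p.BM / p.WM) * (p.BN / p.WN) * 32;
  std::vector<int> hits(p.BM * p.BN, 0);
  for (int tid = 0; tid < threads; ++tid) {
    const int warpIdx = tid / 32, lane = tid % 32;
    const int warpRow = warpIdx / (p.BN / p.WN);
    const int warpCol = warpIdx % (p.BN / p.WN);
    const int threadRowInWarp = lane / (WSUBN / p.TN);
    const int threadColInWarp = lane % (WSUBN / p.TN);
    for (int wSubRowIdx = 0; wSubRowIdx < p.WMITER; ++wSubRowIdx)
      for (int wSubColIdx = 0; wSubColIdx < p.WNITER; ++wSubColIdx)
        for (int m = 0; m < p.TM; ++m)
          for (int n = 0; n < p.TN; ++n) {
            // Same offsets as the As/Bs indices in the fragment above.
            const int row = warpRow * p.WM + wSubRowIdx * WSUBM +
                            threadRowInWarp * p.TM + m;
            const int col = warpCol * p.WN + wSubColIdx * WSUBN +
                            threadColInWarp * p.TN + n;
            if (row >= p.BM || col >= p.BN) return false;
            ++hits[row * p.BN + col];
          }
  }
  for (int h : hits)
    if (h != 1) return false;
  return true;
}
```

For BM = BN = 128, WM = WN = 64, WMITER = 1, WNITER = 4, TM = 8, TN = 4 the check passes. Doubling WMITER without changing anything else breaks the necessary condition 32 · TM · TN = WSUBM · WSUBN, and the check fails.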

After autotuning kernel 10, throughput climbed from ~19.7 to ~21.7 TFLOPs, a 10% jump from warptiling alone. The gap to cuBLAS at large sizes is now small, maybe 5-10%. But look at the small-matrix performance in the chart below: cuBLAS crushes us on small dimensions. The reason is instructive. Inspecting the cuBLAS library (for example with cuobjdump, or by checking which kernel the profiler reports for each problem size) shows that it contains ~16 distinct SGEMM implementations, dispatched at runtime based on matrix shape and size. For some small shapes it appears to use a split-K variant, which partitions the K dimension across thread blocks and so creates more parallelism when M and N alone yield too few blocks. Writing one kernel that's optimal at every shape is essentially impossible, and cuBLAS doesn't try.
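The split-K idea itself is easy to model. This is a CPU sketch of the general technique, not cuBLAS's kernel, and sgemm_split_k_model is a hypothetical name: each slice of K produces a partial M × N product, and a separate reduction pass sums the partials and applies alpha and beta once. On a GPU, each slice would run as its own set of blocks writing to a workspace, which raises parallelism when M·N alone yields too few blocks; the price is the extra workspace and the reduction pass.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Split-K on the CPU. Row-major A (M x K), B (K x N), C (M x N).
void sgemm_split_k_model(int M, int N, int K, int splits, float alpha,
                         const float *A, const float *B, float beta, float *C) {
  std::vector<float> partial(static_cast<std::size_t>(splits) * M * N, 0.0f);
  const int slice = (K + splits - 1) / splits;  // K range per split
  for (int s = 0; s < splits; ++s) {
    const int k0 = s * slice, k1 = std::min(K, k0 + slice);
    float *P = partial.data() + static_cast<std::size_t>(s) * M * N;
    for (int r = 0; r < M; ++r)
      for (int c = 0; c < N; ++c)
        for (int k = k0; k < k1; ++k) P[r * N + c] += A[r * K + k] * B[k * N + c];
  }
  // Reduction pass: sum the partial products, then apply the epilogue once.
  for (int i = 0; i < M * N; ++i) {
    float acc = 0.0f;
    for (int s = 0; s < splits; ++s)
      acc += partial[static_cast<std::size_t>(s) * M * N + i];
    C[i] = alpha * acc + beta * C[i];
  }
}
```

Applying beta only in the reduction pass matters: if every slice applied it, C would be scaled once per split.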

Split-K concept showing K-dimension partitioning across multiple thread blocks
Split-K: partition the K dimension across multiple blocks, useful for small square matrices.
Line graph comparing Kernel 10 performance vs cuBLAS across matrix sizes
Kernel 10 vs cuBLAS across matrix sizes. Near-parity at large dimensions; gap at small sizes.

Conclusion

The thing that surprised me most about this project wasn't any particular optimization; it was the shape of the progress curve. The first two kernels (naive → coalesced) bought a roughly 6-7× speedup, close to half of the total ~75× on a log scale, and took maybe a weekend to understand and implement. The last stretch, from roughly 70% of cuBLAS at Kernel 5 to ~94% at Kernel 10, took weeks more. Every optimization past the low-hanging fruit required deeper hardware knowledge, better profiling intuition, and a higher tolerance for ambiguity. The power law of optimization effort is real.

Looking back, what SGEMM teaches about GPU programming transfers everywhere. Memory bandwidth is almost always the binding constraint, and every layer of the memory hierarchy (GMEM → L2 → SMEM → registers) is there to fight that constraint at a progressively smaller scale. The tiling pattern — identify a bottleneck, tile across the level of memory hierarchy that resolves it, repeat — shows up in virtually every high-performance GPU kernel. Attention kernels (FlashAttention), convolutions, sparse operations: they're all variations on the same theme. Learning SGEMM from scratch is basically learning the vocabulary of GPU optimization.

One thing I'd add beyond the original: if you want to understand why cuBLAS is so hard to beat, spend time with a profiler looking at your kernel's warp efficiency, memory throughput, and SM occupancy simultaneously. These numbers are often in tension with each other: maximizing occupancy can hurt register reuse, and maximizing tile size can hurt occupancy. The art is in the tradeoffs. CUTLASS's design, with its hierarchical tile and policy abstractions and a profiler for choosing among instantiated configurations, is essentially a systematic answer to that multi-objective problem. If you're building production ML infrastructure, using CUTLASS as a foundation is almost certainly the right call. If you're learning, writing kernels from scratch like this is irreplaceable.

Further Resources

  • wangzyon's GitHub repository — the benchmarking harness used here as a starting point. Well-structured for iterating on kernel variants.
  • NVIDIA CUTLASS library — readable, production-grade CUDA for GEMM and related ops. The source of truth for how modern GPU kernels are structured.
  • Official CUDA docs: Toolkit Programming Guide, Best Practices Guide, Kernel Profiling Guide — dense but complete.
  • Onur Mutlu's YouTube lectures on Computer Architecture and Heterogeneous Systems — best free resource for building a mental model of GPU hardware.
  • 'Understanding Latency Hiding on GPUs' (Volkov, 2016) — the canonical deep-dive on occupancy, ILP, and warp scheduling. Required reading if you want to understand why occupancy != utilization.
  • Lei Mao's CUDA blog — pragmatic, code-first coverage of CUDA patterns and pitfalls.
  • ONNX Runtime's CUDA execution provider source, for when you want to see what a production system actually looks like under the hood.
Original article by Simon Boehm

Reproduced with permission from the author.