FlashAttention & Kernel Development on Atalla Ax01
Atalla is a research-grade AI accelerator built end-to-end at Purdue's SoCET lab — a weight-stationary 32×32 BF16 systolic array with programmer-managed scratchpad SRAM, VLIW scheduling, and no hardware cache. I own the systems software workstream: FlashAttention kernel mapping, implicit im2col convolution, tiled GEMM, and PyTorch frontend integration.
Overview
Atalla is a student-led effort within Purdue's SoCET lab to design a research-grade AI accelerator stack from scratch — RTL through kernel software through FPGA emulation. The core is a parameterizable 32×32 BF16 systolic array with three dataflow implementations (naïve, MEISSA-inspired, TPU-inspired), a 1MB dual-partition scratchpad SRAM, and a VLIW scheduler.
I work on the Systems Software team, owning the kernel and PyTorch integration layer. The design philosophy is closer to a TPU than a GPU: there's no hardware cache, no SIMT abstraction, and no hardware scoreboard. All data movement between DRAM and on-chip SRAM is explicit via SDMA instructions. All dependency tracking is the programmer's responsibility. This makes the programming model hard — and the optimization opportunities interesting.
Architecture & Programming Model
The Atalla programming model is tile-centric and single-threaded. Instead of writing scalar SIMT code that the runtime fans out, you write explicit tile descriptors and intrinsic calls that directly orchestrate the systolic array and SRAM. The key abstractions are GlobalTile (tensor in DRAM), ScpadTile (tensor in on-chip SRAM), and VectorReg (register in the VEGGIE file).
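To make the abstractions concrete, here is a minimal Python sketch of the three tile descriptors. The field names and the `bytes()` helper are illustrative assumptions, not the real Atalla API; the even 512KB-per-partition split is inferred from the 1MB dual-partition scratchpad.

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical descriptor shapes -- illustrative, not the real Atalla API.

@dataclass
class GlobalTile:
    """Tensor resident in DRAM, reachable only via SDMA."""
    base_addr: int
    shape: Tuple[int, int]

@dataclass
class ScpadTile:
    """Tensor resident in one scratchpad partition (SCPAD0 or SCPAD1)."""
    partition: int          # 0 or 1
    offset: int             # byte offset within the partition
    shape: Tuple[int, int]

    def bytes(self) -> int:
        rows, cols = self.shape
        return rows * cols * 2   # BF16 = 2 bytes per element

@dataclass
class VectorReg:
    """Register in the VEGGIE vector register file."""
    index: int
```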
Memory Hierarchy
| Level | Size | Access | Latency |
|---|---|---|---|
| DRAM (Global) | 8GB+ | SDMA only (scpad.ld / scpad.st) | High |
| Scratchpad (SCPAD) | 1MB SRAM, 2 partitions | SDMA only | Low |
| Vector Registers (VEGGIE) | On-chip | VM instructions (vreg.ld / vreg.st) | Very low |
| Scalar Registers | On-chip | Hardware-managed L1 cache | Very low |
| SA Accumulation Buffers | Hardware-controlled | Not programmable | — |
The dual scratchpad partition (SCPAD0, SCPAD1) is a key design choice. The compiler can issue loads to both partitions in the same VLIW bundle, enabling overlap between SCPAD0 loads and SCPAD1 compute. The GEMMV execution pattern exploits this: A tiles load into SCPAD0, B and C tiles into SCPAD1, allowing the systolic array to consume one pair while the next is loading.
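A back-of-the-envelope model shows why the overlap matters. The cycle counts below are illustrative placeholders, not measured Atalla latencies; the point is that double-buffering turns per-tile cost from `load + compute` into `max(load, compute)` once the pipeline is primed.

```python
# Toy model of dual-partition overlap. "load" and "compute" are
# illustrative per-tile cycle counts, not measured Atalla numbers.

def pipeline_cycles(n_tiles: int, load: int, compute: int, overlap: bool) -> int:
    if not overlap:
        # Serial: every tile pays full load then full compute.
        return n_tiles * (load + compute)
    # Double-buffered: only the first load is exposed; afterwards the
    # load of tile i+1 hides under the compute of tile i.
    return load + n_tiles * max(load, compute)

serial     = pipeline_cycles(8, load=100, compute=120, overlap=False)  # 1760
overlapped = pipeline_cycles(8, load=100, compute=120, overlap=True)   # 1060
```

With these placeholder numbers the overlapped schedule recovers roughly 40% of the serial runtime, which is the behavior the GEMMV A-in-SCPAD0 / B,C-in-SCPAD1 split is designed to exploit.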
ISA Highlights
The ISA has 7-bit opcodes across instruction types: scalar integer and BF16 arithmetic, vector-vector and vector-scalar masked operations, SDMA bulk DMA, VM vector-register loads and stores, and the GEMMV and CONV compute intrinsics. Notably, expi.vi (element-wise exp) costs 15 cycles — a significant consideration for softmax implementations. The vector reduction tree (rmax.vi, rsum.vi) costs 13 cycles. These latencies make the case for polynomial exp emulation in attention kernels.
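The arithmetic behind that case can be sketched directly. A degree-3 polynomial in Horner form costs three multiply-add pairs, versus 15 cycles for `expi.vi`. The plain Taylor coefficients below are an assumption for illustration; the kernel's actual coefficients and any range reduction may differ.

```python
import math

# Degree-3 Horner polynomial standing in for expi.vi. Coefficients are
# the Taylor series of exp around 0 -- illustrative, not necessarily the
# kernel's tuned values.

def exp_poly3(x: float) -> float:
    # Horner form: 1 + x*(1 + x*(1/2 + x*(1/6)))
    # maps to a short mul.vv + add.vv chain on the vector unit.
    return 1.0 + x * (1.0 + x * (0.5 + x * (1.0 / 6.0)))

# After max-subtraction, softmax arguments are <= 0 and cluster near 0,
# where the low-degree approximation is tight.
err = abs(exp_poly3(-0.5) - math.exp(-0.5))
```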
FlashAttention Kernel
I own the FlashAttention kernel workstream. The central challenge: attention has a split personality. The Q·Kᵀ and attn·V matmuls map cleanly onto GEMMV — tiles of Q, K, V are loaded into SCPAD, the systolic array handles the matmul, and partial sums accumulate in the hardware accumulation buffers before transfer to VEGGIE. But softmax is inherently scalar and sequential — it runs on the scalar unit using rmax.vi (13 cycles), expi.vi (15 cycles), and rsum.vi (13 cycles) per tile.
This split creates an interesting co-design question: can the scalar unit begin softmax rescaling computations on the previous tile's output while the systolic array's PSUM writeback to VEGGIE is happening for the current tile? The two operations use distinct hardware units — the scalar unit and the accumulation buffer writeback path. If they can be pipelined across tiles using the dual-SCPAD architecture, that's free latency hiding. Verifying this is an open question being explored in the emulator.
Key Optimizations
- Polynomial exp emulation: degree-3 Horner's method replaces the 15-cycle expi.vi with a ~4–5 cycle sequence of mul.vv + add.vv at acceptable accuracy for inference.
- Conditional softmax rescaling: skip the rescale step when the running tile max doesn't change, reducing scalar unit pressure on locally stable attention distributions.
- Tiling strategy: explicit SDMA prefetch of Q/K/V tiles into alternating SCPAD partitions, targeting overlap between SCPAD loads and systolic array compute.
- Online softmax (FlashAttention-style): maintain running (max, denominator, output) accumulators across tiles — never materialize the full N×N attention weight matrix in SRAM.
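The online-softmax recurrence from the last bullet can be written out as a scalar sketch. This is the standard FlashAttention-style update, simplified to one output value per row for readability; the real kernel carries these accumulators per row across SCPAD tiles.

```python
import math

# Per-row online softmax state: running max m, running denominator d,
# and output accumulator o, folded over one tile of scores at a time.

def online_softmax_update(m, d, o, scores, values):
    """Fold one tile of (score, value) pairs into the running state."""
    m_new = max(m, max(scores))
    scale = math.exp(m - m_new)      # rescale old state if the max moved
    d_new = d * scale
    o_new = o * scale
    for s, v in zip(scores, values):
        w = math.exp(s - m_new)
        d_new += w
        o_new += w * v
    return m_new, d_new, o_new

# Processing tiles incrementally matches a one-shot softmax:
m, d, o = float("-inf"), 0.0, 0.0
tiles = [([1.0, 2.0], [10.0, 20.0]), ([3.0, 0.5], [30.0, 5.0])]
for scores, values in tiles:
    m, d, o = online_softmax_update(m, d, o, scores, values)
result = o / d
```

The `scale` factor is also where the conditional-rescaling optimization bites: when `m_new == m`, `scale` is exactly 1 and the two multiplies can be skipped.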
Implicit Im2col Convolution
Standard convolution can be lowered to GEMM via im2col: rearrange the input tensor so each convolution window becomes a column, then run a single GEMM. The naive approach materializes the im2col buffer explicitly — for a 3×3 conv over a 224×224 feature map, that's a 9× expansion of the input, turning a 48MB tensor into 432MB. This is impractical for a 1MB scratchpad.
The implicit channel-first variant avoids materialization entirely. The im2col addressing is computed on-the-fly inside the SDMA load — the scratchpad sees a logically rearranged tensor computed from index arithmetic, not a pre-expanded buffer. This keeps the memory footprint bounded by the tile size. I built a Streamlit-based simulator and visualizer for this kernel to make the addressing logic debuggable and to help teammates understand the co-design constraints.
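The index arithmetic at the heart of the implicit scheme can be shown in a few lines. This is a simplified sketch assuming channel-first layout, stride 1, and no padding; `im2col_src_index` is a hypothetical name for the mapping the SDMA descriptor computes on the fly.

```python
# On-the-fly im2col addressing (channel-first, stride 1, no padding).
# Rather than materializing the 9x-expanded buffer, the SDMA load
# computes the source DRAM offset for each (patch-element, output-pixel)
# pair directly from index arithmetic.

def im2col_src_index(col, row, C, H, W, KH, KW):
    """Map im2col (row, col) back to a flat CHW input offset.

    row indexes the patch element (c, kh, kw); col indexes the output
    pixel (oh, ow).
    """
    OH, OW = H - KH + 1, W - KW + 1
    c, r = divmod(row, KH * KW)     # which channel, which kernel element
    kh, kw = divmod(r, KW)
    oh, ow = divmod(col, OW)        # which output pixel
    return (c * H + oh + kh) * W + (ow + kw)
```

Because the mapping is pure index arithmetic, the scratchpad footprint stays bounded by the tile size regardless of kernel size.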
Tiled GEMM
The GEMMV intrinsic operates on tiles ≤ 32×32. A full M×N×K GEMM is decomposed into MT×NT×KT tile groups, each loaded into SCPAD before a GEMMV call. The C tile accumulates across all k-slices before being written back to DRAM. Tile sizing is a constraint satisfaction problem: TM × TK + TK × TN + TM × TN must fit within a single SCPAD partition's budget, leaving the other partition free for double-buffering.
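That constraint can be sanity-checked with a one-liner. The 512KB partition budget below assumes an even split of the 1MB scratchpad between SCPAD0 and SCPAD1.

```python
# Tile-sizing feasibility check: A (TMxTK), B (TKxTN), and C (TMxTN)
# must fit in one partition's budget. 512KB assumes an even split of
# the 1MB scratchpad across the two partitions.

BF16_BYTES = 2
PARTITION_BYTES = 512 * 1024

def tiles_fit(tm: int, tn: int, tk: int) -> bool:
    footprint = (tm * tk + tk * tn + tm * tn) * BF16_BYTES
    return footprint <= PARTITION_BYTES
```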
// Outer tile loop (output tiles)
for each (i, j) in output tile grid:
    load C_tile → SCPAD1
    // Inner K-reduction loop
    for each k-slice:
        load A_tile[i, k] → SCPAD0   // SDMA to partition 0
        load B_tile[k, j] → SCPAD1   // SDMA to partition 1
        GEMMV(sc_C, sc_A, sc_B)      // blocks until SCPAD_C updated
    store C_tile → DRAM
PyTorch Backend Integration
To run real PyTorch models on Atalla, I built a custom backend using torch.export and FX graph capture. The flow: capture a model as an FX graph, walk the graph nodes, map recognized op patterns (linear, conv2d, relu, etc.) to Atalla kernel calls via an op registry, emit the SDMA + GEMMV instruction sequence for each mapped op, and hand off to the simulator for cycle-accurate execution. This gives the team a path from 'standard PyTorch model' to 'runs on Atalla' without hand-writing kernels for every model.
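The registry-dispatch core of that flow can be sketched without pulling in torch. The node and registry shapes below are simplified stand-ins for the real FX node objects and the backend's actual registry, and the emitted instruction strings are illustrative.

```python
# Simplified sketch of the op-registry dispatch: captured graph nodes are
# matched by target name and lowered to Atalla instruction sequences.
# Node dicts stand in for real torch.fx Node objects.

ATALLA_OP_REGISTRY = {}

def register_op(name):
    def wrap(fn):
        ATALLA_OP_REGISTRY[name] = fn
        return fn
    return wrap

@register_op("linear")
def lower_linear(node):
    # SDMA operand tiles in, run the tiled GEMMV, SDMA the result out.
    return ["scpad.ld A", "scpad.ld B", "GEMMV", "scpad.st C"]

@register_op("relu")
def lower_relu(node):
    # Elementwise max against zero on the vector unit.
    return ["vreg.ld", "max.vs 0", "vreg.st"]

def lower_graph(nodes):
    """Walk captured graph nodes in order and emit one instruction stream."""
    program = []
    for node in nodes:
        target = node["target"]
        if target not in ATALLA_OP_REGISTRY:
            raise NotImplementedError(f"no Atalla lowering for {target}")
        program += ATALLA_OP_REGISTRY[target](node)
    return program
```

The decorator-based registry keeps op lowerings decoupled from the graph walker, so adding support for a new op is a single registered function.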
Open Co-Design Questions
Working at the HW/SW boundary of a research chip surfaces questions that only arise when you own both layers. Some of the open co-design directions the team is exploring:
- PSUM overlap: can PSUM writeback to VEGGIE be pipelined with scalar unit softmax computation? This requires careful timing analysis in the emulator.
- Sparsity support: structured sparsity in specific model families could give 2× weight compression with hardware support — what ISA changes would enable this?
- Multi-datatype: extending beyond BF16 to INT8 or FP8 would open up quantized inference; the scalar core already has type conversion instructions (stbf.s, bfts.s).
- Transpose units: attention requires Kᵀ — currently handled in software via SDMA swizzle; dedicated transpose hardware would eliminate that overhead.
- Compiler packetization: the VLIW scheduler leaves performance on the table; better packetization heuristics could meaningfully improve utilization.