FlashAttention-4 by Ted Zadouri X GPU MODE
Why It Matters
FlashAttention‑4 unlocks substantially faster transformer inference and training on Blackwell GPUs, reducing compute costs and setting a new performance baseline for large‑language‑model deployments.
Key Takeaways
- FlashAttention‑4 redesigns the forward and backward passes for NVIDIA Blackwell GPUs.
- Tensor memory relieves register pressure, and 2‑CTA MMA instructions cut shared‑memory traffic.
- Ping‑pong query tiles overlap MMA with softmax for higher throughput.
- A software‑emulated exponential (75/25 hardware/software split) mitigates the SFU bottleneck in attention.
- Conditional scaling in the online softmax cuts non‑math operations, improving efficiency.
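To make the 2‑CTA point concrete, here is a back-of-the-envelope model of the shared-memory bytes one CTA's MMA reads for a single tile. It rests on an assumption (a simplification, not taken from the lecture): in the paired M = 256 mode, each CTA keeps its own 128 rows of the A operand and stages only half of the B operand's N columns, with the hardware sharing the halves across the pair. The function `smem_read_bytes_per_cta` and the 128×128×128 BF16 tile are hypothetical choices. Under this model only the B traffic shrinks, so the full halving applies when A is not read from shared memory (for example, when it already sits in tensor memory).

```python
def smem_read_bytes_per_cta(tile_m, tile_n, tile_k, a_in_smem, two_cta, elem_bytes=2):
    """Shared-memory bytes one CTA's MMA reads for one tile (simplified, hypothetical model).

    tile_m    : rows of A (and of the accumulator) owned by this CTA
    tile_n    : columns of B / accumulator
    tile_k    : reduction dimension
    a_in_smem : False if A is sourced from tensor memory instead of shared memory
    two_cta   : True for the paired mode, where each CTA stages only half of B
    """
    if two_cta and tile_n % 2:
        raise ValueError("tile_n must be even when B is split across two CTAs")
    a_bytes = tile_m * tile_k * elem_bytes if a_in_smem else 0
    n_local = tile_n // 2 if two_cta else tile_n  # assumed split of B along N
    b_bytes = tile_k * n_local * elem_bytes
    return a_bytes + b_bytes


if __name__ == "__main__":
    for a_in_smem in (False, True):
        one = smem_read_bytes_per_cta(128, 128, 128, a_in_smem, two_cta=False)
        two = smem_read_bytes_per_cta(128, 128, 128, a_in_smem, two_cta=True)
        print(f"A in SMEM={a_in_smem}: 1-CTA {one} B, 2-CTA {two} B, ratio {two / one:.2f}")
```

With A in tensor memory, the model gives 16384 bytes per CTA in 2‑CTA mode versus 32768 for a single CTA; with A also read from shared memory, the saving drops to 25%.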
Summary
The lecture introduced FlashAttention‑4, an attention kernel co‑designed for NVIDIA’s Blackwell architecture (the B200 GPU). By re‑examining the forward and backward passes, the authors aligned the algorithm's structure with Blackwell’s new tensor memory and fifth‑generation tensor cores, addressing the mismatches that kept earlier FlashAttention versions from exploiting these features.
Key technical insights include exploiting the 2‑CTA MMA mode, which pairs two CTAs to expand the MMA's M dimension to 256 and halves shared‑memory traffic, and leveraging tensor memory to alleviate register pressure. The forward pass now runs in a ping‑pong fashion, assigning two query tiles per CTA so that the matrix‑multiply‑accumulate (MMA) work of one tile overlaps with the softmax computation of the other. A hybrid exponential (75% on the hardware SFU, 25% via a software polynomial approximation) removes the exponential‑unit bottleneck, while a conditional scaling trick in the online softmax reduces non‑math operations.
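The hybrid exponential can be sketched in NumPy. The software path follows the standard recipe for emulating `exp2` with ordinary arithmetic: split x into an integer part n and a fraction f, approximate 2^f with a short polynomial, then multiply by 2^n by adding n directly into the float32 exponent field. The degree-3 least-squares coefficients, the clamp to [-126, 126], and sending the *last* 25% of each row to the software path are assumptions for illustration, not FlashAttention‑4's actual choices; `np.exp2` merely stands in for the hardware SFU.

```python
import numpy as np

# Hypothetical coefficients: a degree-3 least-squares fit of 2**f on [0, 1].
# FlashAttention-4's own polynomial (degree and coefficients) is not reproduced here.
_F = np.linspace(0.0, 1.0, 1025)
_COEFFS = np.polyfit(_F, np.exp2(_F), deg=3).astype(np.float32)


def exp2_emulated(x):
    """Approximate 2**x in float32 with a polynomial plus an exponent-field add."""
    # Clamp so the result stays a normal float32 (exponent field between 1 and 254).
    x = np.clip(np.atleast_1d(np.asarray(x, dtype=np.float32)), -126.0, 126.0)
    n = np.floor(x)                                 # integer part of the exponent
    f = x - n                                       # fractional part, in [0, 1)
    p = np.polyval(_COEFFS, f).astype(np.float32)   # Horner evaluation of 2**f (mul/add only)
    # Scale by 2**n by adding n to the IEEE-754 exponent bits (bit 23 upward).
    bits = p.view(np.int32) + (n.astype(np.int32) << 23)
    return bits.view(np.float32)


def exp2_hybrid(x, sw_fraction=0.25):
    """Send (1 - sw_fraction) of each row through np.exp2 (standing in for the SFU)
    and the remaining columns through the polynomial emulation."""
    x = np.asarray(x, dtype=np.float32)
    out = np.empty_like(x)
    split = int(round(x.shape[-1] * (1.0 - sw_fraction)))
    out[..., :split] = np.exp2(x[..., :split])          # "hardware" path
    out[..., split:] = exp2_emulated(x[..., split:])    # software path
    return out
```

On inputs in [-30, 10] this sketch stays within 1e-3 relative error of the exact result, comfortably below BF16's rounding error of about 2⁻⁸. That margin is why a low-degree polynomial can take over part of the exponential workload without hurting attention accuracy.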
The presenters highlighted concrete performance details: the redesigned kernel balances softmax latency against MMA execution time so the two overlap cleanly, and the software exponential executes in roughly nine SASS instructions, dramatically faster than a full SFU call. The 2‑CTA approach cuts shared‑memory footprint and bandwidth, and the slack‑based (conditional) scaling maintains determinism with minimal numeric error. It typically lets the running maximum lag by a threshold of 8 in log2 units, so values grow by at most 2⁸ = 256 before a rescale, which keeps BF16 stable.
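The conditional rescaling can be checked on a single query row. In this sketch the scores are assumed to already be in log2 units (as when the softmax scale is folded into `exp2`). The running max `m` is updated only when a block's max exceeds it by more than `tau`, so unnormalized probabilities may reach 2^tau (256 for `tau = 8`) instead of staying at most 1. Because the numerator `acc` and the denominator `l` always share the same reference max, their final ratio is still the exact softmax, while the rescale of `acc` is skipped on most iterations. The function name, block size, and row-at-a-time layout are hypothetical; the kernel itself works on whole tiles.

```python
import numpy as np


def attention_row_online(s, v, block=64, tau=8.0):
    """Compute softmax_2(s) @ v for one query row, one key block at a time.

    s   : (N,) attention scores already expressed in log2 units (so exp2 is used)
    v   : (N, d) value rows
    tau : rescale only when a block's max exceeds the running max by more than tau;
          tau=0 reproduces the usual always-rescale online softmax.
    Returns (output row, number of rescales performed).
    """
    m = -np.inf                 # reference max (may lag the true max by up to tau)
    l = 0.0                     # running denominator, relative to m
    acc = np.zeros(v.shape[1])  # running numerator sum_j p_j * v_j, relative to m
    rescales = 0
    for start in range(0, len(s), block):
        sb, vb = s[start:start + block], v[start:start + block]
        m_blk = sb.max()
        if m_blk - m > tau:
            alpha = np.exp2(m - m_blk)  # 0.0 on the first block, where m = -inf
            acc *= alpha
            l *= alpha
            m = m_blk
            rescales += 1
        p = np.exp2(sb - m)             # every entry is at most 2**tau
        l += p.sum()
        acc += p @ vb
    return acc / l, rescales
```

Running it with `tau=8.0` and `tau=0.0` on scores that trend upward produces the same output as a direct softmax, but the thresholded version performs far fewer rescales.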
These innovations translate into speedups of up to 2‑3× for large‑scale transformer workloads on Blackwell GPUs, lowering training costs and enabling higher token throughput. The broader lesson, mapping under‑utilized GPU functional units to algorithmic steps, offers a template for future kernel optimizations across emerging GPU architectures.