---
license: apache-2.0
tags:
- kernels
---

# flash-attention — Triton kernel

A pure-[Triton](https://github.com/triton-lang/triton) implementation of **Flash Attention 1** (Dao et al., 2022), packaged for the [Hugging Face Kernel Hub](https://huggingface.co/docs/kernels/).

Unlike hand-written CUDA implementations, this kernel is written entirely in Python/Triton and is JIT-compiled at runtime, which makes it easier to read, modify, and experiment with.

## Algorithm

Flash Attention avoids materialising the full *N × N* attention matrix in HBM. It fuses the score computation, the softmax, and the value-weighted sum into a single tiled pass using the **online softmax** trick (Milakov & Gimelshein, 2018):

```
O_i ← softmax(Q_i · Kᵀ) · V    (tiled over K/V, never storing full S)
```

Memory drops from O(N²) to O(N · d). The quadratic attention matrix is the primary memory bottleneck for long-context training and inference, so removing it is what makes long sequences practical.

## Usage

### Via the `kernels` package

```python
import torch
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attention", version=1)

B, H, N, d = 2, 8, 1024, 64
q = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)

out = fa.flash_attention_forward(q, k, v, causal=False)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

### Local development

```bash
# 1. Clone
git clone https://huggingface.co/kernels-community/flash-attention
cd flash-attention

# 2. Install dependencies
pip install torch triton pytest

# 3. Run tests
pytest tests/ -v

# 4. Run benchmark
python benchmarks/bench_flash_attention.py
```

## Performance

| Sequence length | Flash-Attn Triton | PyTorch reference | Speedup    |
|-----------------|-------------------|-------------------|------------|
| 128             | 0.11 ms           | 0.19 ms           | **1.70×**  |
| 256             | 0.15 ms           | 0.24 ms           | **1.57×**  |
| 512             | 0.17 ms           | 0.43 ms           | **2.47×**  |
| 1024            | 0.29 ms           | 1.78 ms           | **6.15×**  |
| 2048            | 0.79 ms           | 7.11 ms           | **8.98×**  |
| 4096            | 2.54 ms           | 27.01 ms          | **10.63×** |

Speedup is the PyTorch reference time divided by the Flash-Attn Triton time.

## Repository structure

```
flash-attention-1-triton/
├── build.toml                      # kernel-builder configuration
├── flake.nix                       # Nix build environment
├── flash_attention_kernel/
│   └── flash_attention.py          # Triton forward/backward kernels + launcher
├── torch-ext/
│   ├── torch_binding.h             # C++ op declaration
│   ├── torch_binding.cpp           # Torch op registration
│   └── flash_attention/
│       └── __init__.py             # Python-level wrapper (uses _ops alias)
├── tests/
│   └── test_flash_attention.py     # pytest correctness & smoke tests
├── benchmarks/
│   └── bench_flash_attention.py    # triton.testing perf report
└── README.md
```

## References

- Dao et al. (2022). *FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness*.
- Milakov & Gimelshein (2018). *Online normalizer calculation for softmax*.
- Triton tutorials. *Fused Attention*.
- Hugging Face Kernel Hub documentation.
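
## Appendix: online softmax in plain Python

The online softmax recurrence described in the Algorithm section can be checked in a few lines of plain Python. This is a minimal illustrative sketch, not the kernel's code: `naive_attention_row` and `tiled_attention_row` are hypothetical helpers that handle a single query row with Python lists instead of GPU tiles. Like the formula in the Algorithm section, it omits the usual 1/√d score scaling; scale `q` beforehand if you need it.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: softmax(q · Kᵀ) · V for one query row, materialising every score."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


def tiled_attention_row(q, keys, values, block_size=2):
    """Online-softmax version (hypothetical helper): visits K/V in blocks and
    keeps only a running max, a running denominator and a running accumulator."""
    d_v = len(values[0])
    m = -math.inf       # running max of all scores seen so far
    l = 0.0             # running sum of exp(score - m)
    acc = [0.0] * d_v   # running un-normalised weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s_blk))
        # Rescale the old state to the new max (exp(-inf) == 0.0 on the first block).
        scale = math.exp(m - m_new)
        p_blk = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p_blk)
        acc = [a * scale + sum(p * v[j] for p, v in zip(p_blk, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    # Normalise once at the end.
    return [a / l for a in acc]


# Small demo: both versions should agree up to floating-point rounding.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention_row(q, keys, values))
print(tiled_attention_row(q, keys, values, block_size=2))
```

Whenever a block contains a larger score, the running denominator `l` and accumulator `acc` are rescaled by `exp(m_old - m_new)`, so the final `acc / l` equals the exact softmax-weighted sum without ever holding all N scores at once. Flash Attention applies the same update per tile, with blocks of query rows processed in parallel.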