Kernel Case Study: Flash Attention

Understanding all versions of Flash Attention through a Triton implementation

Arun Jith A
14 min read · Mar 24, 2025

All my work is free to read on my substack. You can read this and more content like this there for free. Education should have no barriers! I publish weekly, mostly technical deep dives and other fun stuff. I invite you to subscribe for more like this!

The attention mechanism is the core of modern-day transformers. But scaling the context window of these transformers was, and still is, a major challenge, even though we are now in the era of million-token-plus context windows (Qwen 2.5 [1]). There are considerable compute- and memory-bound complexities in these models as we scale the context window: a naive attention mechanism scales quadratically in both compute and memory. Revisiting Flash Attention lets us understand the complexities of optimizing the underlying operations on GPUs and, more importantly, gives us a better grip on thinking about what's next.

Let's quickly revisit the naive attention algorithm to see what's going on.
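Here is a minimal PyTorch sketch of that naive computation (the function name is mine, and I leave out the 1/√D scaling and any masking to match the rest of the post):

import torch

def naive_attention(Q, K, V):
    # Q: (N, D), K: (M, D), V: (M, D)
    S = Q @ K.T                   # full N x M score matrix, written out to HBM
    P = torch.softmax(S, dim=-1)  # read back, softmax applied, written out again
    return P @ V                  # read back once more for the final matmul

N, M, D = 1024, 1024, 64
Q, K, V = (torch.randn(N, D) for _ in range(3))
O = naive_attention(Q, K, V)      # O: (N, D)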

As you can see, if we are not careful we end up materializing the full NxM attention matrix in GPU HBM, which means the memory requirement grows quadratically with context length.

If you want to learn more about the GPU memory hierarchy, my previous post on Triton is a good starting point. It will also come in handy later in this post when we implement the flash attention kernel in Triton. The flash attention paper also has a really good introduction to this.
https://aarunjith.substack.com/p/simplifying-cuda-kernels-with-triton

Additionally, when we look at the steps involved in executing this algorithm and its pattern of accessing the slow HBM (which, as explained later in the post, can be a major bottleneck), we notice a few things:

  1. We have Q, K and V in the HBM initially
  2. We need to access Q and K initially from the HBM to compute the dot product
  3. We write the output scores back to the HBM
  4. We access it again to execute the softmax; optionally, for causal attention (as in LLMs), we mask this output before the softmax. The resulting full attention matrix is written to the HBM again
  5. We access the HBM again for the final dot product between the attention weights and the Value matrix, and write the output back to the slow GPU memory

I think you get the point. If we read from and write to the HBM more cleverly, we can avoid redundant trips and make some real gains. This is exactly the primary motivation for the original Flash Attention algorithm.

Flash Attention initially came out in 2022 [2], followed a year later by some much needed improvements as Flash Attention v2 [3], and again in 2024 with additional improvements for Nvidia Hopper and Blackwell GPUs [4] as Flash Attention v3 [5]. The original Flash Attention paper identified that the attention operation is limited by memory rather than compute. (In the past there have been attempts to reduce the computational complexity of attention from O(N²) to O(N log N) and lower through approximate algorithms.)

Flash attention proposed a fused kernel that performs all of the above attention operations in one go, block-wise, producing the final attention output without ever materializing the full N² attention matrix in memory, which makes the algorithm significantly faster. The term `fused` simply means we combine multiple operations in the GPU SRAM before making the much slower round trip to HBM, which is what makes the algorithm performant, all while producing the exact attention output without any approximation.

This lecture from Stanford CS139 demonstrates brilliantly the impact a well-thought-out memory access pattern can have on an algorithm. I highly recommend you check it out if you haven't already.

Before we start diving into flash attention in Triton (it's getting tedious to type this over and over, so let's agree to call it FA, shall we?), there is something else I want to get out of the way.

Numerical Stability in exponents

Let's take the example of FP32 numbers. float32 (standard 32-bit float) uses 1 sign bit, 8 exponent bits, and 23 mantissa bits [6]; the exponent goes up to 2¹²⁷, giving a largest finite value of about 3.4 × 10³⁸. When we look at exponentials, e⁸⁸ ≈ 1.65 × 10³⁸, so with arguments anywhere in the high 80s we are in trouble: the exponential can easily overflow. Here's a very interesting chat with OpenAI o1 shared by the folks at AllenAI in their OpenInstruct repo. Although it is about stabilizing KL divergence calculations in the RLHF/RL setting, the ideas translate exactly to our setting as well, because we also have a softmax/exp to take care of.
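You can see this bite in two lines of PyTorch (just an illustration of float32 overflow, nothing attention-specific):

import torch

print(torch.exp(torch.tensor(88.0)))   # ~1.65e38, still finite in float32
print(torch.exp(torch.tensor(89.0)))   # inf: the exponential overflows

So to deal with this situation, here's what we do: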

TRICK: Observe that if you do the following:
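In symbols (my own rendering, in the same notation as the rest of the post), the trick rests on two facts about the exponential:

softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m) for any constant m, so subtracting the (running) row maximum changes nothing, and

exp(x_i - m_old) * exp(m_old - m_new) = exp(x_i - m_new), so anything computed against an old maximum can be rescaled to a new one after the fact,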

then you can rescale/readjust values without affecting the final softmax value. This is really useful when you have an initial estimate of the maximum value that may change as you encounter new sets of values. I know, I know. Stay with me and let me explain.

Setting the scene

Let's take a small detour into matrix multiplication.

This shows a toy example of blocked matrix multiplication, except we have blocks only along the rows of A (green) and the columns of B (orange? beige?). As you can see above, the outputs O1, O2, O3 and O4 are complete (those positions need no more computation). We just need to fill in the remaining columns of those initial rows using the remaining columns of B, like below:

So we can fill in these positions of the output using one block of columns from B and one block of rows from A at a time.
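A small PyTorch sketch of that access pattern (toy sizes and block shapes chosen arbitrarily for illustration):

import torch

N, K, M = 4, 8, 6            # A is N x K, B is K x M
Br, Bc = 2, 2                # row blocks of A, column blocks of B
A, B = torch.randn(N, K), torch.randn(K, M)

O = torch.zeros(N, M)
for i in range(0, N, Br):        # one block of rows of A at a time
    for j in range(0, M, Bc):    # one block of columns of B at a time
        O[i:i+Br, j:j+Bc] = A[i:i+Br] @ B[:, j:j+Bc]

assert torch.allclose(O, A @ B, atol=1e-5)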

Connecting the dots

When I introduced FA, I said that we never have to compute and store the full attention matrix. So here's what we do:

  1. Compute a block of the attention matrix using a block of rows from Q and a block of columns from K. Once you have this partial attention matrix, compute a few statistics and keep them in memory.

I have greyed out O5 to O12 because we don't know those values yet; they need to come from the subsequent blocks. We then transform Sb like below:

Now you have the setup for a partial softmax.

But, but:

  1. What if the true maximum is in the Oi’s that are yet to come?
  2. The sum is still local, so we need to update it every time we see new Pi's. We know how to keep a running sum, but what about rebasing it to the true maximum?

Recall the trick above. All we have to do is keep track of the maximum value we have encountered for each row, and iteratively update it (and everything computed against it) as we see new maximums in the remaining blocks of columns for the same set of rows.

We still do not want to write our partial softmax matrix to HBM, so we keep it around for the next step.
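Here is a tiny sketch of that bookkeeping for one row block as two column blocks of scores arrive (pure PyTorch, variable names are mine, mirroring the kernel further down):

import torch

Br = 2
m = torch.full((Br, 1), float('-inf'))   # running row maxes
l = torch.zeros((Br, 1))                 # running exp-sums

def update(m, l, s_block):
    # Fold one new block of scores into the running (m, l) pair
    m_blk = s_block.max(dim=1, keepdim=True).values
    m_new = torch.maximum(m, m_blk)
    p_blk = torch.exp(s_block - m_blk)
    l_new = torch.exp(m - m_new) * l + torch.exp(m_blk - m_new) * p_blk.sum(dim=1, keepdim=True)
    return m_new, l_new

s1, s2 = torch.randn(Br, 2), torch.randn(Br, 2)   # two column blocks of scores
m, l = update(m, l, s1)
m, l = update(m, l, s2)

# After all blocks, l equals the true row-wise sum of exp(s - row_max)
s_full = torch.cat([s1, s2], dim=1)
assert torch.allclose(l, torch.exp(s_full - m).sum(dim=1, keepdim=True))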

The final dot product

The last step in our attention computation is the dot product with V. To start, we initialize a matrix full of 0's in HBM as our output, of shape NxD, where N is the number of queries as above. We use the same block size for V as we did for K, except we apply it row-wise, like below (the subscripts just denote that this is only a block and not the full matrix).

Notice how we need the attention scores from all the blocks to get the final product. But if we calculate the local scores and `accumulate` them, just as we did to get the actual L's, we can form the full output once we have processed all the blocks of columns (K_b) for a given row block (Q_b).

Putting it all together

Let's put all these ideas together to form the final algorithm.

To understand the notation: wherever you see _ij, those are the local values for a given block of rows and columns, and _i means it is for the global output rows and query blocks. The only part we haven't explained so far is the final update to O_i. That's where we use all the ideas from above to get the scaling right.
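Written out (my own rendering, matching the o_new line in the kernel below), the update for row block i when column block j arrives is:

m_i_new = max(m_i, m_ij)
l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij
O_i_new = (1 / l_i_new) * ( l_i * exp(m_i - m_i_new) * O_i + exp(m_ij - m_i_new) * P_ij V_j )

The l_i * exp(m_i - m_i_new) factor un-normalizes the previously stored output (which was already divided by the old l_i), rescales it to the new maximum, adds this block's contribution, and then everything is re-normalized by l_i_new.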

The whole code is available as a gist here.

Let's see what these initializations look like in torch:

import torch
import numpy as np

def flash_attn_v1(Q, K, V, Br, Bc):
    """Flash Attention V1 launcher: one program per batch element."""
    B, N, D = Q.shape
    M = K.shape[1]
    Nr = int(np.ceil(N / Br))  # number of row (query) blocks
    Nc = int(np.ceil(M / Bc))  # number of column (key/value) blocks over the M key rows

    Q = Q.to('cuda')
    K = K.to('cuda')
    V = V.to('cuda')

    batch_stride = Q.stride(0)

    O = torch.zeros_like(Q).to('cuda')
    # Running row-wise exp-sums (L) and row maxes (M), one value per output row,
    # grouped into Nr row blocks of size Br
    lis = torch.zeros((B, Nr, int(Br)), dtype=torch.float32).to('cuda')
    mis = torch.ones((B, Nr, int(Br)), dtype=torch.float32).to('cuda') * -torch.inf

    grid = (B, )
    flash_attn_v1_kernel[grid](
        Q, K, V,
        N, M, D,
        Br, Bc,
        Nr, Nc,
        batch_stride,
        Q.stride(1),
        K.stride(1),
        V.stride(1),
        lis, mis,
        O,
        O.stride(1),
    )
    return O

If you are unsure about the launch grid, check out my introduction to Triton.

Take a closer look at how we initialized our Ls and Ms. We keep one vector of size Br for each row block of the Output/Query, and there are Nr such blocks in total.

In the example above I simply used Br = 2 and Bc = 2, but in the code the block sizes should be based on the device's capacity. I sized them for a T4 GPU; for any other GPU, look up its SRAM capacity and adjust these numbers accordingly.
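If you want a starting point, here is a rough sketch of the paper-style sizing (the helper function and the ~64 KB shared-memory budget for a T4 are my assumptions; measure on your own device):

import math

def pick_block_sizes(sram_bytes, d, dtype_bytes=4):
    # Paper-style heuristic: with an on-chip budget of M elements,
    # Bc = ceil(M / (4d)) and Br = min(Bc, d)
    m_elems = sram_bytes // dtype_bytes
    bc = math.ceil(m_elems / (4 * d))
    return min(bc, d), bc

print(pick_block_sizes(64 * 1024, d=64))   # assuming ~64 KB of SRAM per SM on a T4

Now for the actual kernel implementation: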

# Flash Attention V1
import triton
import triton.language as tl
import torch
import numpy as np

@triton.jit
def flash_attn_v1_kernel(
        Q, K, V,
        N: tl.constexpr, M: tl.constexpr, D: tl.constexpr,
        Br: tl.constexpr,
        Bc: tl.constexpr,
        Nr: tl.constexpr,
        Nc: tl.constexpr,
        batch_stride: tl.constexpr,
        q_rstride: tl.constexpr,
        k_rstride: tl.constexpr,
        v_rstride: tl.constexpr,
        lis, mis,
        O,
        o_rstride: tl.constexpr):
    """Flash Attention V1 kernel: one program per batch element."""
    pid = tl.program_id(0)

    for j in range(Nc):
        # Offsets for the j-th block of K and V rows (Bc x D), shifted into this batch element.
        # Using k_rstride and v_rstride as we are looking at an entire row at once for each K/V block.
        k_offset = ((tl.arange(0, Bc) + j * Bc) * k_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D
        v_offset = ((tl.arange(0, Bc) + j * Bc) * v_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D
        k_mask = k_offset < (pid + 1) * M * D
        v_mask = v_offset < (pid + 1) * M * D
        k_load = tl.load(K + k_offset, mask=k_mask, other=0)
        v_load = tl.load(V + v_offset, mask=v_mask, other=0)
        for i in range(Nr):
            # Load the i-th block of Q rows (Br x D)
            q_offset = ((tl.arange(0, Br) + i * Br) * q_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D
            q_mask = q_offset < (pid + 1) * N * D
            q_load = tl.load(Q + q_offset, mask=q_mask, other=0)

            # Local attention block: scores, row maxes, exponentials and exp-sum
            s_ij = tl.dot(q_load, tl.trans(k_load))
            m_ij = tl.max(s_ij, axis=1, keep_dims=True)
            p_ij = tl.exp(s_ij - m_ij)
            l_ij = tl.sum(p_ij, axis=1, keep_dims=True)

            # Running (global) row statistics for this row block
            ml_offset = tl.arange(0, Br) + Br * i + pid * Nr * Br
            m = tl.load(mis + ml_offset)[:, None]
            l = tl.load(lis + ml_offset)[:, None]

            m_new = tl.where(m < m_ij, m_ij, m)
            l_new = tl.exp(m - m_new) * l + tl.exp(m_ij - m_new) * l_ij

            o_ij = tl.dot(p_ij, v_load)

            # Rescale the existing partial output and fold in this block's contribution
            output_offset = ((tl.arange(0, Br) + i * Br) * o_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D
            output_mask = output_offset < (pid + 1) * N * D
            o_current = tl.load(O + output_offset, mask=output_mask)

            o_new = (1 / l_new) * (l * tl.exp(m - m_new) * o_current + tl.exp(m_ij - m_new) * o_ij)

            tl.store(O + output_offset, o_new, mask=output_mask)
            tl.store(mis + ml_offset, tl.reshape(m_new, (Br,)))
            tl.store(lis + ml_offset, tl.reshape(l_new, (Br,)))

Let's understand what's happening here:

  1. We launch one kernel instance for each NxD matrix in the batch. In reality we would have one more dimension to parallelize across, the head dimension, but for understanding the implementation this will suffice.
  2. In each kernel we do the following:

a. For each block of columns in K and V, we load the relevant part of the matrix (Bc x D) into the GPU SRAM (current total SRAM usage = 2BcD). This stays in SRAM until we are done with all the row blocks.

b. For each row block of Q, we load that block onto SRAM as well (current total SRAM usage = 2BcD + BrD).

c. On chip we compute the dot product (s_ij), the local row maxes (m_ij), the exponentials (p_ij), and the exp-sum (l_ij).

d. We load the running stats for the i-th row block: two vectors of size Br x 1 holding the current global row maxes (m_i) and exp-sums (l_i). (Current SRAM usage: 2BcD + BrD + 2Br)

e. We compute the new estimates for the global m_i and l_i.

f. We load the part of the output corresponding to this block of Q, update it using the new running stats and the exponent trick, and write it back to HBM. (Current SRAM usage: 2BcD + 2BrD + 2Br)

g. We write the updated running stats back to HBM as well.

3. For a matrix of any size, i.e. any context length, we never materialize the full attention matrix at any one time, only a part of it.

4. We managed to fuse all the ops into a single kernel, reducing HBM accesses considerably.

5. The final SRAM usage, though, stands at 4BD + 2B, where B was originally calculated as M/4d with M being the SRAM capacity. Not sure if I am missing something here. Please comment if you know why this is the case!

Block Sparse Attention, V2 and V3

I will keep this short, as these versions keep the core idea but find better and better ways of doing the same thing.

For Block Sparse Attention,

  1. Suppose we have masks for each block, as in the case of causal attention. If, for a given block, the mask is all zeros, we can simply skip the entire block without computing anything, saving FLOPs. This is where the major gains were seen. To put this into perspective: for BERT pretraining the algorithm gets a 15% boost over the best-performing training setup, while for GPT-2 we get roughly 3x over the Hugging Face training implementation and ~2x over the Megatron setup.

You can get literally the same performance on GPT-2 in a fraction of the time, shaving days off the training run, which is awesome!

In V2,

  1. Notice how, so far, we can only parallelize across the batch and head dimensions. But if you simply flip the loop order, looking at all the column blocks for a given row block, we get the following advantages (sketched in code after this list):
  2. Each row block becomes embarrassingly parallel. You can see this from the illustrations above: you need all the column blocks for a given row block to fully form the attention output. If you were to run all the column blocks in parallel, you would end up with a race condition, with multiple programs trying to update the same rows of the output at the same time. Not so if you do it the other way around. (There are atomic add operators in Triton that could help, but they would likely set us back.)
  3. We can avoid hitting the HBM to read and write the global Ms and Ls. We can keep them on chip, one set per kernel instance.
  4. We also do not have to scale every output update by the running estimate of L. We can accumulate the unnormalized output and, at the end of all the column blocks, simply divide the output by the latest estimate of L, saving some FLOPs again!
  5. Much of the improvement also comes from the backward kernel. I am omitting the backward kernels from this post, but they are a fun exercise to try and implement, although significantly more complex.
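A minimal sketch of that flipped loop order for a single row block of queries (pure PyTorch, my own variable names, no masking or scaling), just to show where the statistics live and why the division happens only once at the end:

import torch

def fa2_forward_rowblock(q_blk, K, V, Bc):
    # One row block of queries: statistics stay "on chip" (here: local variables),
    # the output is accumulated unnormalized and divided by l once at the end.
    Br, D = q_blk.shape
    m = torch.full((Br, 1), float('-inf'))
    l = torch.zeros((Br, 1))
    acc = torch.zeros((Br, D))                       # unnormalized output accumulator
    for j in range(0, K.shape[0], Bc):
        k_blk, v_blk = K[j:j + Bc], V[j:j + Bc]
        s = q_blk @ k_blk.T                          # (Br, Bc) block of scores
        m_new = torch.maximum(m, s.max(dim=1, keepdim=True).values)
        p = torch.exp(s - m_new)
        scale = torch.exp(m - m_new)                 # rescales the old l and acc
        l = scale * l + p.sum(dim=1, keepdim=True)
        acc = scale * acc + p @ v_blk
        m = m_new
    return acc / l                                   # single division at the very end

Each row block is independent of every other row block, which is exactly why they can be run in parallel without any races.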

Here are some benchmarks

The actual implementations of these kernels have to handle various real-world nuances; I have tried to keep things simple here. But do check them out here.

More recently, in V3,

  1. Newer GPUs, especially Hopper and Blackwell, have low-precision modes (FP8 on Hopper and FP4 on Blackwell) that can double or quadruple the throughput for the same power and chip area, as well as more specialized GEMM (General Matrix Multiply) hardware, which the previous versions of the algorithm fail to capitalize on. This is because there are many non-GEMM operations, like softmax, which reduce the utilization of these specialized units.
  2. FA v1 and v2 are essentially synchronous. Recall that in the v2 description I mentioned we are limited when column blocks try to write to the same output pointers, or when we have to go step by step using the output of the previous step. These modern GPUs can use special instructions to break this synchrony.

We overlap the comparatively low-throughput non-GEMM operations involved in softmax, such as floating point multiply-add and exponential, with the asynchronous WGMMA instructions for GEMM. As part of this, we rework the FlashAttention-2 algorithm to circumvent certain sequential dependencies between softmax and the GEMMs. For example, in the 2-stage version of our algorithm, while softmax executes on one block of the scores matrix, WGMMA executes in the asynchronous proxy to compute the next block.

FA v3, Shah, Bikshandi et al. [5]

Some more benchmarks

Conclusion

There's much to admire in the work here. I like to think of these algorithms (and the people working on them) as cowboys, extracting the best performance out of these modern GPUs, which is not easy, to say the least.

The barrier to entry for this kind of work has often seemed high owing to the low-level details, but hopefully tools like Triton can change the game and get more people into this! The future is bright.

Originally published at https://aarunjith.substack.com.
