[FlashAttention] Block-Level Masking with BlockInfo

April 01, 2026

In FlashAttention, causal and local attention masks restrict which Q tokens can attend to which K tokens. The BlockInfo class computes block-level valid ranges to avoid computing invalid Q-K pairs entirely, saving memory bandwidth and computation.

TL;DR

BlockInfo converts element-level causal/local masks into block-level index ranges: for each Q block it computes the half-open range [n_block_min, n_block_max) of K blocks that contain at least one valid position, so the kernel never loads or multiplies tiles that the mask would zero out entirely.
Part 1: The Problem - Element-Level vs Block-Level Masking

1.1 Standard Attention (No Mask)

In standard attention, every Q token can attend to every K token:

Q\K  0  1  2  3  4
0    ✓  ✓  ✓  ✓  ✓
1    ✓  ✓  ✓  ✓  ✓
2    ✓  ✓  ✓  ✓  ✓
3    ✓  ✓  ✓  ✓  ✓
4    ✓  ✓  ✓  ✓  ✓

All 5×5 = 25 positions are valid.

1.2 Causal Attention

In causal attention, Q[i] can only attend to K[0:i+1+shift]:

Q\K  0  1  2  3  4
0    ✓  ✗  ✗  ✗  ✗    # Q[0] attends to K[0]
1    ✓  ✓  ✗  ✗  ✗    # Q[1] attends to K[0:2]
2    ✓  ✓  ✓  ✗  ✗
3    ✓  ✓  ✓  ✓  ✗
4    ✓  ✓  ✓  ✓  ✓

Only 15 positions are valid (lower triangle + diagonal).
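The causal predicate can be checked with a few lines of plain Python. This is a sketch for intuition, not the kernel code; `is_valid_causal` is a name chosen here for illustration:

```python
def is_valid_causal(i, j, seqlen_q, seqlen_k):
    """Element-level causal predicate: may Q[i] attend to K[j]?"""
    shift = seqlen_k - seqlen_q  # aligns the last Q token with the last K token
    return j <= i + shift

# Reproduce the 5x5 grid above (seqlen_q = seqlen_k = 5, so shift = 0)
grid = [[is_valid_causal(i, j, 5, 5) for j in range(5)] for i in range(5)]
print(sum(sum(row) for row in grid))  # 15 valid positions
```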

Why shift? When seqlen_k > seqlen_q (common in decode), the causal diagonal is aligned to the *end* of both sequences: with shift = seqlen_k - seqlen_q, Q[i] attends to K[0 : i + 1 + shift], so the last Q token still sees every K token.

1.3 Local (Sliding Window) Attention

In local attention, Q[i] attends to K[i-window_left+shift : i+window_right+1+shift]:

window_left=1, window_right=1, seqlen_k = seqlen_q (shift=0)

Q\K  0  1  2  3  4
0    ✓  ✓  ✗  ✗  ✗    # Q[0] attends to K[0:2]
1    ✓  ✓  ✓  ✗  ✗    # Q[1] attends to K[0:3]
2    ✗  ✓  ✓  ✓  ✗
3    ✗  ✗  ✓  ✓  ✓
4    ✗  ✗  ✗  ✓  ✓

Each Q position attends to a local window.

Causal + Local: both masks can be enabled at once; a position must satisfy both constraints.
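A single predicate covering both masks can be sketched directly from the two rules above (the function name and signature here are illustrative, not the kernel's API):

```python
def is_valid(i, j, seqlen_q, seqlen_k, causal=False,
             window_left=None, window_right=None):
    """Combined causal + local mask: every enabled constraint must hold."""
    shift = seqlen_k - seqlen_q
    if causal and j > i + shift:
        return False
    if window_left is not None and j < i - window_left + shift:
        return False
    if window_right is not None and j > i + window_right + shift:
        return False
    return True

# Reproduce row Q[2] of the local grid above (window_left=1, window_right=1)
print([is_valid(2, j, 5, 5, window_left=1, window_right=1) for j in range(5)])
# [False, True, True, True, False]
```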

1.4 The Naive Approach: Compute Then Mask

# Load all Q and K tiles
for m_block in range(num_m_blocks):
    sQ = load_Q(m_block)
    for n_block in range(num_n_blocks):
        sK = load_K(n_block)
        
        # Compute attention for ALL positions
        sS = sQ @ sK.T  # (tile_m, tile_n)
        
        # Apply mask element-by-element
        sS = apply_mask(sS, m_block, n_block, causal_mask)
        
        # Many positions are zeroed!
        # Still paid memory bandwidth for K, still computed matmul

Problem: We load K blocks and compute matmul for positions that will be zeroed.

1.5 The Efficient Approach: Skip Invalid Blocks

# Only iterate over valid K blocks
for m_block in range(num_m_blocks):
    sQ = load_Q(m_block)
    
    # Compute VALID K block range
    n_block_min, n_block_max = get_valid_k_blocks(m_block)
    
    for n_block in range(n_block_min, n_block_max):
        sK = load_K(n_block)
        
        # All positions in this block are valid!
        # (or only need minimal masking at boundaries)
        sS = sQ @ sK.T
        
        # No wasted loads or computation

Part 2: BlockInfo Class

2.1 Class Definition

@dataclass(frozen=True)
class BlockInfo:
    tile_m: Constexpr[int]         # Q block size (typically 128)
    tile_n: Constexpr[int]         # K/V block size (typically 128)
    is_causal: Constexpr[bool]     # Causal masking enabled
    is_local: Constexpr[bool]      # Local/sliding window masking
    window_size_left: Int32        # Local window left size
    window_size_right: Int32       # Local window right size
    is_split_kv: Constexpr[bool]   # Split-KV mode
    qhead_per_kvhead_packgqa: int  # GQA: Q heads per KV head

2.2 Key Methods

Method                                           Purpose
------                                           -------
get_n_block_min_max(seqlen_info, m_block)        Which K blocks can this Q block attend to?
get_m_block_min_max(seqlen_info, n_block)        Which Q blocks can attend to this K block?
get_n_block_k_new_min_max(seqlen_info, m_block)  For append-KV: which new-K blocks?
get_n_block_min_causal_local_mask(...)           Where does the causal mask end?
get_n_block_min_before_local_mask(...)           Where does the local mask start?

Part 3: get_n_block_min_max() - Finding Valid K Blocks

3.1 Method Signature

def get_n_block_min_max(
    self,
    seqlen_info: SeqlenInfoQK,
    m_block: Int32,
    split_idx: Int32 = 0,      # For split-KV
    num_splits: Int32 = 1,      # For split-KV
) -> Tuple[Int32, Int32]:
    """Return (n_block_min, n_block_max) for this Q block."""    ...

Inputs: seqlen_info (the per-batch Q/K sequence lengths), m_block (the Q block index), and optionally split_idx / num_splits for split-KV mode.

Returns: the half-open range [n_block_min, n_block_max) of K block indices this Q block may attend to.

3.2 Algorithm

def get_n_block_min_max(self, seqlen_info, m_block, split_idx=0, num_splits=1):
    # Step 1: Start with full K range
    n_block_max = ceil_div(seqlen_info.seqlen_k, self.tile_n)
    
    # Step 2: Apply causal/local UPPER bound
    if self.is_causal or (self.is_local and self.window_size_right is not None):
        # Compute maximum Q position in this block
        m_idx_max = (m_block + 1) * self.tile_m
        
        # GQA: Convert Q position to KV position
        if self.qhead_per_kvhead_packgqa > 1:
            # Multiple Q heads share one KV head
            # Q position q maps to KV position q // qhead_per_kvhead_packgqa
            m_idx_max = ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
        
        # Compute corresponding K index
        # Q[i] attends to K[j] where j <= i + shift
        # shift = seqlen_k - seqlen_q (handles different Q/K lengths)
        n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        
        # For local attention: extend by window_size_right
        n_idx_right = n_idx if self.is_causal else n_idx + self.window_size_right
        
        # Convert to block index
        n_block_max = min(n_block_max, ceil_div(n_idx_right, self.tile_n))
    
    # Step 3: Apply local LOWER bound
    n_block_min = 0
    if self.is_local and self.window_size_left is not None:
        # Compute minimum Q position in this block
        m_idx_min = m_block * self.tile_m
        
        # GQA adjustment
        if self.qhead_per_kvhead_packgqa > 1:
            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
        
        # Compute corresponding K index
        # Q[i] attends to K[j] where j >= i - window_size_left + shift
        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        n_idx_left = n_idx - self.window_size_left
        
        # Convert to block index (clamp to 0)
        n_block_min = max(n_idx_left // self.tile_n, 0)
    
    # Step 4: Apply split-KV if enabled
    if self.is_split_kv:
        num_n_blocks_per_split = (n_block_max - n_block_min + num_splits - 1) // num_splits
        n_block_min = n_block_min + split_idx * num_n_blocks_per_split
        n_block_max = min(n_block_min + num_n_blocks_per_split, n_block_max)
    
    return n_block_min, n_block_max
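Stripped of GQA and split-KV, the method above is short enough to run standalone. The following sketch (a free function instead of a method, with sequence lengths passed directly; not the kernel's actual API) reproduces the worked examples in 3.3 and 3.4:

```python
def ceil_div(a, b):
    return -(-a // b)

def get_n_block_min_max(seqlen_q, seqlen_k, m_block, tile_m, tile_n,
                        is_causal=False, window_left=None, window_right=None):
    """Standalone sketch of BlockInfo.get_n_block_min_max (no GQA/split-KV)."""
    n_block_max = ceil_div(seqlen_k, tile_n)
    shift = seqlen_k - seqlen_q
    # Causal/local upper bound on K
    if is_causal or window_right is not None:
        n_idx = (m_block + 1) * tile_m + shift
        n_idx_right = n_idx if is_causal else n_idx + window_right
        n_block_max = min(n_block_max, ceil_div(n_idx_right, tile_n))
    # Local lower bound on K
    n_block_min = 0
    if window_left is not None:
        n_block_min = max((m_block * tile_m + shift - window_left) // tile_n, 0)
    return n_block_min, n_block_max

# Section 3.3: seqlen_q = seqlen_k = 256, causal
print(get_n_block_min_max(256, 256, 0, 128, 128, is_causal=True))  # (0, 1)
print(get_n_block_min_max(256, 256, 1, 128, 128, is_causal=True))  # (0, 2)
# Section 3.4: decode with seqlen_q = 1, seqlen_k = 1000
print(get_n_block_min_max(1, 1000, 0, 128, 128, is_causal=True))   # (0, 8)
```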

3.3 Visual Example: Causal Attention

seqlen_q = 256, seqlen_k = 256
tile_m = 128, tile_n = 128
is_causal = True
m_block = 0 (Q indices [0: 128])

Step 1: n_block_max = ceil(256 / 128) = 2

Step 2: Apply causal upper bound
  m_idx_max = (0 + 1) * 128 = 128
  shift = seqlen_k - seqlen_q = 0
  
  # Q[i] attends to K[j] where j <= i + shift
  # The largest Q index in this block is 127, so the largest valid
  # K index is 127, i.e. only K blocks [0, ceil(128/128)) = [0, 1)
  n_idx = 128 + 0 = 128
  n_block_max = min(2, ceil(128/128)) = min(2, 1) = 1

Result: n_block_min=0, n_block_max=1
  # Q[0:128] attends to K[0:128] -- only K block 0 is loaded

For m_block = 1 (Q[128:256]):
  m_idx_max = 256
  n_idx = 256
  n_block_max = ceil(256/128) = 2
  
  Result: n_block_min=0, n_block_max=2
  # Q[128:256] attends to K[0:256] -- both K blocks are loaded

Visual:
  K blocks:               |    0    |    1    |
  Q block 0 (Q[0:128]):   [#########]
  Q block 1 (Q[128:256]): [#########|#########]

3.4 Visual Example: Different Q/K Lengths

seqlen_q = 1, seqlen_k = 1000 (decode step)
tile_m = 128, tile_n = 128
is_causal = True
m_block = 0 (Q indices [0:1), only Q[0])

shift = seqlen_k - seqlen_q = 999

# Q[0] attends to K[0:0+1+999] = K[0:1000]
# All K blocks are valid!

m_idx_max = 1
n_idx = 1 + 999 = 1000
n_block_max = min(ceil(1000/128), ceil(1000/128)) = ceil(1000/128) = 8

Result: n_block_min=0, n_block_max=8

Visual:
  K:    [0 ............................ 1000)
  Q[0]: [##############################]  (all 8 K blocks valid)

Part 4: get_m_block_min_max() - Finding Valid Q Blocks

4.1 Method Signature

def get_m_block_min_max(
    self,
    seqlen_info: SeqlenInfoQK,
    n_block: Int32,
) -> Tuple[Int32, Int32]:
    """Return (m_block_min, m_block_max) for this K block.
    
    For K block n_block, which Q blocks [m_block_min, m_block_max) are valid?
    """...

Use case: When iterating over K blocks instead of Q blocks (for some optimizations).

4.2 Algorithm

def get_m_block_min_max(self, seqlen_info, n_block):
    # Start with full Q range
    m_block_max = ceil_div(seqlen_info.seqlen_q, self.tile_m)
    m_block_min = 0
    
    # Apply causal/local lower bound on Q
    if self.is_causal or (self.is_local and self.window_size_right is not None):
        # K[j] can be attended by Q[i] where i >= j - shift
        n_idx_min = n_block * self.tile_n
        
        m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
        m_idx_right = m_idx if self.is_causal else m_idx - self.window_size_right
        
        m_block_min = max(m_block_min, m_idx_right // self.tile_m)
    
    # Apply local upper bound on Q
    if self.is_local and self.window_size_left is not None:
        n_idx_max = (n_block + 1) * self.tile_n
        
        m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k
        m_idx_left = m_idx + self.window_size_left
        
        m_block_max = min(m_block_max, ceil_div(m_idx_left, self.tile_m))
    
    return m_block_min, m_block_max
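As with the previous method, a standalone sketch (free function, no GQA; not the kernel's actual API) reproduces the example in 4.3:

```python
def ceil_div(a, b):
    return -(-a // b)

def get_m_block_min_max(seqlen_q, seqlen_k, n_block, tile_m, tile_n,
                        is_causal=False, window_left=None, window_right=None):
    """Standalone sketch of BlockInfo.get_m_block_min_max (no GQA)."""
    m_block_max = ceil_div(seqlen_q, tile_m)
    m_block_min = 0
    shift = seqlen_q - seqlen_k  # note: reversed sign vs the K-range method
    # Causal/local lower bound on Q
    if is_causal or window_right is not None:
        m_idx = n_block * tile_n + shift
        m_idx_right = m_idx if is_causal else m_idx - window_right
        m_block_min = max(m_block_min, m_idx_right // tile_m)
    # Local upper bound on Q
    if window_left is not None:
        m_idx_left = (n_block + 1) * tile_n + shift + window_left
        m_block_max = min(m_block_max, ceil_div(m_idx_left, tile_m))
    return m_block_min, m_block_max

# Section 4.3: seqlen_q = seqlen_k = 256, causal
print(get_m_block_min_max(256, 256, 0, 128, 128, is_causal=True))  # (0, 2)
print(get_m_block_min_max(256, 256, 1, 128, 128, is_causal=True))  # (1, 2)
```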

4.3 Example

seqlen_q = 256, seqlen_k = 256
tile_m = 128, tile_n = 128
is_causal = True
n_block = 0 (K indices [0:128])

# K[0:128] can be attended by Q[i] where i >= 0
# All Q blocks valid!

Result: m_block_min=0, m_block_max=2

n_block = 1 (K indices [128:256]):
  
  # K[j] with j >= 128 can only be attended by Q[i] with i >= 128
  # shift = 0, so m_idx_right = 128
  
  m_block_min = 128 // 128 = 1
  
  Result: m_block_min=1, m_block_max=2
  # Only Q[128:256] can attend to K[128:256]

Part 5: GQA/MQA Handling

5.1 The Problem

In Multi-Query Attention (MQA) and Grouped Query Attention (GQA), several Q heads share one KV head. With PackGQA, those Q heads are packed along the row dimension, so packed Q row indices must be divided by qhead_per_kvhead_packgqa to recover KV positions.

Example: 8 Q heads, 1 KV head (MQA)
qhead_per_kvhead_packgqa = 8

Q tensor shape: (seqlen, num_q_heads=8, head_dim)
KV tensor shape: (seqlen, num_kv_heads=1, head_dim)

For Q position q:
  KV position = q // qhead_per_kvhead_packgqa
  
Q position 0 → KV position 0//8 = 0
Q position 7 → KV position 7//8 = 0
Q position 8 → KV position 8//8 = 1
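The mapping itself is plain integer division; the only subtlety is that an exclusive upper bound like m_idx_max must be rounded *up* to stay an upper bound in KV space, which is why the method uses ceil_div. A small illustrative check:

```python
def ceil_div(a, b):
    return -(-a // b)

qhead_per_kvhead = 8
# Packed Q row -> KV position is floor division:
print([q // qhead_per_kvhead for q in (0, 7, 8, 15, 16)])  # [0, 0, 1, 1, 2]
# But m_idx_max is exclusive (one past the last row in the block),
# so it rounds UP to remain a valid upper bound in KV space:
print(ceil_div(128, qhead_per_kvhead))  # 16: one past the last KV position
```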

5.2 Implementation

# In get_n_block_min_max:
m_idx_max = (m_block + 1) * self.tile_m

if self.qhead_per_kvhead_packgqa > 1:
    # Convert Q position to KV position
    # Q positions [0:128] with 8 Q heads per KV head
    # map to KV positions [0:16]: ceil(128/8) = 16
    m_idx_max = ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)

n_idx = m_idx_max + seqlen_k - seqlen_q
# Now n_idx is in KV position space

5.3 Visual Example

seqlen_q = 128, seqlen_k = 16 (after GQA packing)
qhead_per_kvhead_packgqa = 8
tile_m = 128, tile_n = 128
is_causal = True
m_block = 0

# Without GQA adjustment (wrong position space):
m_idx_max = 128
n_idx = 128 + 16 - 128 = 16
n_block_max = ceil(16/128) = 1

# With GQA adjustment:
m_idx_max = ceil(128/8) = 16  # Convert to KV position space
n_idx = 16 + 16 - 16 = 16
n_block_max = ceil(16/128) = 1

# Both happen to give n_block_max = 1 here because seqlen_k = 16
# fits in a single K block (tile_n = 128), so every K block is valid
# either way; the adjustment matters once seqlen_k spans many blocks.

Part 6: Split-KV Mode

6.1 The Problem

For very long KV sequences, we split KV across multiple CTAs:

seqlen_k = 10000, num_splits = 4
Each split processes seqlen_k / 4 = 2500 tokens

Split 0: K[0:2500]
Split 1: K[2500:5000]
Split 2: K[5000:7500]
Split 3: K[7500:10000]

Each CTA processes one split, then results are combined.

6.2 Implementation

if self.is_split_kv:
    # Compute blocks per split
    num_n_blocks_per_split = (n_block_max - n_block_min + num_splits - 1) // num_splits
    
    # Adjust for this split
    n_block_min = n_block_min + split_idx * num_n_blocks_per_split
    n_block_max = min(n_block_min + num_n_blocks_per_split, n_block_max)

6.3 Example

seqlen_k = 2560, tile_n = 128
num_n_blocks = ceil(2560/128) = 20
num_splits = 4

Without split: n_block_min=0, n_block_max=20

With split_kv:  num_n_blocks_per_split = (20 + 3) // 4 = 5

split_idx=0: n_block_min=0, n_block_max=5   # K[0:640]
split_idx=1: n_block_min=5, n_block_max=10  # K[640:1280]
split_idx=2: n_block_min=10, n_block_max=15 # K[1280:1920]
split_idx=3: n_block_min=15, n_block_max=20 # K[1920:2560]
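The split arithmetic above is just a ceil-divided partition of the block range. A runnable sketch (`split_kv_range` is an illustrative helper name, not the kernel's API) reproduces the table:

```python
def split_kv_range(n_block_min, n_block_max, split_idx, num_splits):
    """Sketch of the split-KV adjustment: partition [min, max) into
    num_splits contiguous chunks of near-equal size."""
    per_split = -(-(n_block_max - n_block_min) // num_splits)  # ceil_div
    lo = n_block_min + split_idx * per_split
    hi = min(lo + per_split, n_block_max)
    return lo, hi

# Reproduce the table above: 20 K blocks, 4 splits
print([split_kv_range(0, 20, s, 4) for s in range(4)])
# [(0, 5), (5, 10), (10, 15), (15, 20)]
```

Note that when the block count does not divide evenly, the last split simply comes up short (hi is clamped to n_block_max).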

Part 7: Append-KV (SeqlenInfoQKNewK)

7.1 Use Case

For incremental decode, we have an existing KV cache of length seqlen_k_og plus seqlen_k_new freshly generated tokens to append.

Need to compute which NEW K blocks a Q block needs.

7.2 get_n_block_k_new_min_max()

def get_n_block_k_new_min_max(self, seqlen_info, m_block, split_idx=0, num_splits=1):
    # First get the full K block range (over original + new K)
    n_block_min, n_block_max = self.get_n_block_min_max(
        seqlen_info, m_block, split_idx, num_splits
    )
    
    # Map to NEW K index space (subtract original K length)
    idx_k_new_min = max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0)
    idx_k_new_max = min(n_block_max * self.tile_n - seqlen_info.seqlen_k_og,
                        seqlen_info.seqlen_k_new)
    
    # Convert back to block indices
    n_block_new_min = idx_k_new_min // self.tile_n
    n_block_new_max = ceil_div(idx_k_new_max, self.tile_n)
    
    return n_block_new_min, n_block_new_max

7.3 Example

seqlen_k_og = 1000 (original KV cache)
seqlen_k_new = 10 (new tokens to append)
seqlen_q = 1 (decode, single Q token)
tile_n = 128
is_causal = True

# Q[0] attends to K[0:1010]
# n_block_min=0, n_block_max=ceil(1010/128)=8

# Map to new K space:
# Original K occupies blocks [0: ceil(1000/128)] = [0:8)
# New K occupies blocks [0: ceil(10/128)] = [0:1)

idx_k_new_min = max(0 * 128 - 1000, 0) = 0
idx_k_new_max = min(8 * 128 - 1000, 10) = min(24, 10) = 10

n_block_new_min = 0 // 128 = 0
n_block_new_max = ceil(10/128) = 1

Result: Load NEW K block [0:1) for append
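The remapping step can be run on its own. This sketch takes the already-computed block range plus the two K lengths as plain arguments (an illustrative helper, not the kernel's API) and reproduces the example above:

```python
def ceil_div(a, b):
    return -(-a // b)

def n_block_k_new_min_max(n_block_min, n_block_max, tile_n,
                          seqlen_k_og, seqlen_k_new):
    """Sketch of the new-K remapping: shift the valid K range into the
    index space of the appended tokens, clamp, then re-tile."""
    idx_new_min = max(n_block_min * tile_n - seqlen_k_og, 0)
    idx_new_max = min(n_block_max * tile_n - seqlen_k_og, seqlen_k_new)
    return idx_new_min // tile_n, ceil_div(idx_new_max, tile_n)

# Section 7.3: cache of 1000 tokens, 10 appended, causal decode
print(n_block_k_new_min_max(0, 8, 128, 1000, 10))  # (0, 1)
```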

Part 8: Complete Usage Example

# Setup
block_info = BlockInfo(
    tile_m=128,
    tile_n=128,
    is_causal=True,
    is_local=False,
    window_size_left=None,
    window_size_right=None,
    is_split_kv=False,
    qhead_per_kvhead_packgqa=1,
)

seqlen_info = SeqlenInfoQK.create(
    batch_idx=0,
    seqlen_q_static=256,
    seqlen_k_static=256,
    mCuSeqlensQ=None,
    mCuSeqlensK=None,
)
# seqlen_q = 256, seqlen_k = 256

# In kernel:
for m_block in range(ceil_div(seqlen_q, tile_m)):
    # Load Q tile
    sQ = load_Q_tile(mQ, m_block * tile_m, (m_block+1) * tile_m)
    
    # Get valid K blocks for this Q block
    n_block_min, n_block_max = block_info.get_n_block_min_max(
        seqlen_info, m_block
    )
    
    # Only iterate valid K blocks!
    for n_block in range(n_block_min, n_block_max):
        # Load K and V tiles
        sK = load_K_tile(mK, n_block * tile_n, (n_block+1) * tile_n)
        sV = load_V_tile(mV, n_block * tile_n, (n_block+1) * tile_n)
        
        # Compute attention (online softmax omitted for clarity)
        sS = sQ @ sK.T
        sP = softmax(sS / sqrt(d))
        sO = sP @ sV
        
        # Accumulate output
        accumulate(sO, sAccum)

Summary

BlockInfo Responsibilities

Responsibility                        Method
--------------                        ------
Which K blocks to load for a Q block  get_n_block_min_max()
Which Q blocks can use a K block      get_m_block_min_max()
Which NEW K blocks for append-KV      get_n_block_k_new_min_max()
Causal mask boundary                  get_n_block_min_causal_local_mask()
Local mask boundary                   get_n_block_min_before_local_mask()

Key Formulas

# Causal: Q[i] attends to K[j] where j <= i + shift
# shift = seqlen_k - seqlen_q

# Local: Q[i] attends to K[j] where
#        i - window_left + shift <= j <= i + window_right + shift

# GQA: Q position q → KV position q // qhead_per_kvhead_packgqa
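These formulas can be sanity-checked by brute force: every (i, j) pair that the element-level mask allows must land inside the block range the formulas produce. A small self-contained check (a standalone re-implementation over toy sizes, not the kernel code):

```python
from itertools import product

def cdiv(a, b):
    return -(-a // b)

def valid(i, j, sq, sk, causal, wl, wr):
    """Element-level mask built from the formulas above."""
    shift = sk - sq
    if causal and j > i + shift:
        return False
    if wl is not None and j < i - wl + shift:
        return False
    if wr is not None and j > i + wr + shift:
        return False
    return True

def n_block_range(sq, sk, m_block, tm, tn, causal, wl, wr):
    """Block-level K range from the same formulas (no GQA / split-KV)."""
    n_max = cdiv(sk, tn)
    shift = sk - sq
    if causal or wr is not None:
        n_idx = (m_block + 1) * tm + shift
        n_max = min(n_max, cdiv(n_idx if causal else n_idx + wr, tn))
    n_min = 0
    if wl is not None:
        n_min = max((m_block * tm + shift - wl) // tn, 0)
    return n_min, n_max

# Every element-wise valid (i, j) must fall inside the block-level range.
for sq, sk, causal, wl, wr in [(48, 48, True, None, None),
                               (16, 80, True, None, None),
                               (48, 48, False, 10, 10),
                               (48, 64, True, 20, 0)]:
    tm = tn = 16
    for m_block in range(cdiv(sq, tm)):
        lo, hi = n_block_range(sq, sk, m_block, tm, tn, causal, wl, wr)
        for i, j in product(range(m_block * tm, min((m_block + 1) * tm, sq)),
                            range(sk)):
            if valid(i, j, sq, sk, causal, wl, wr):
                assert lo * tn <= j < hi * tn, (sq, sk, m_block, i, j)
print("block ranges cover all valid pairs")
```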

Performance Impact

Without block-level masking: every Q block iterates over every K block. K tiles are loaded and the full QK^T matmul is computed even for positions the mask will zero out; for causal attention roughly half of that work is discarded.

With block-level masking: each Q block visits only K blocks in [n_block_min, n_block_max). Fully-masked blocks are never loaded or computed, and only the boundary blocks need element-level masking.
