Understanding Attention: Coherency in LLMs

Vatsal Bajpai

Attention in LLMs: The Core Mechanism Behind Modern AI

Attention mechanisms form the backbone of modern Large Language Models (LLMs), enabling them to process and generate coherent text across long contexts. This blog post breaks down the attention mechanism for engineers who want to understand how these models work under the hood.

What is Attention?

At its core, attention allows a model to focus on relevant parts of the input sequence when producing an output. Unlike traditional RNNs that process sequences linearly, attention gives the model the ability to "look back" at any part of the input, weighing the importance of each token when generating the next one.

The fundamental question attention answers is: "Which parts of the input should I focus on to generate the current output?"

Self-Attention: The Basic Building Block

Self-attention, specifically the "Scaled Dot-Product Attention" introduced in the "Attention Is All You Need" paper, is the fundamental operation in transformer-based LLMs.

Here's how it works:

  1. Each input token is transformed into three vectors:

    • Query (Q): What the token is "looking for"
    • Key (K): What the token "offers" to others
    • Value (V): The actual information the token contains
  2. The attention score between tokens is calculated using the dot product of queries and keys

  3. These scores are scaled, softmaxed, and used to create weighted sums of values

The Math (Simplified)

import math
import torch
import torch.nn.functional as F

def self_attention(sequence, W_q, W_k, W_v):
    # Project the input into queries, keys, and values
    Q = sequence @ W_q  # shape: [seq_len, d_k]
    K = sequence @ W_k  # shape: [seq_len, d_k]
    V = sequence @ W_v  # shape: [seq_len, d_v]
    
    # Calculate raw attention scores between every pair of tokens
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]
    
    # Scale the scores by sqrt(d_k) to keep the softmax well-behaved
    d_k = K.shape[-1]
    scores = scores / math.sqrt(d_k)
    
    # Apply softmax to get attention weights (each row sums to 1)
    weights = F.softmax(scores, dim=-1)
    
    # Get weighted sum of values
    output = weights @ V  # shape: [seq_len, d_v]
    
    return output
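
As a quick sanity check, here's a minimal usage sketch for the function above, with randomly initialized projection matrices standing in for learned weights:

import torch

torch.manual_seed(0)
seq_len, d_model, d_k = 4, 8, 8

# Random input and projections (stand-ins for learned parameters)
sequence = torch.randn(seq_len, d_model)
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_k)

output = self_attention(sequence, W_q, W_k, W_v)
print(output.shape)  # torch.Size([4, 8])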

A Concrete Example

Let's see self-attention in action with a tiny example:

Input sequence: ["The", "cat", "sat"]

  1. Convert tokens to embeddings (simplified)
"The" → [0.2, 0.3, 0.1]
"cat" → [0.5, 0.2, 0.4]
"sat" → [0.1, 0.7, 0.2]
  2. Project to Q, K, V (simplified with identity projection)
Q = K = V = [[0.2, 0.3, 0.1],
             [0.5, 0.2, 0.4],
             [0.1, 0.7, 0.2]]
  3. Calculate attention scores
scores = Q @ K.T =
[[0.14, 0.20, 0.25],
 [0.20, 0.45, 0.27],
 [0.25, 0.27, 0.54]]
  4. Scale and softmax to get weights
weights = softmax(scores / sqrt(3)) =
[[0.32, 0.33, 0.34],
 [0.31, 0.36, 0.33],
 [0.31, 0.32, 0.37]]
  5. Compute weighted sum of values
output = weights @ V =
[[0.27, 0.40, 0.23],
 [0.28, 0.39, 0.24],
 [0.26, 0.42, 0.23]]

The output shows how each token now incorporates information from all other tokens, weighted by relevance.
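
You can reproduce these numbers yourself in a few lines of PyTorch, using the identity projection from step 2:

import math
import torch
import torch.nn.functional as F

X = torch.tensor([[0.2, 0.3, 0.1],   # "The"
                  [0.5, 0.2, 0.4],   # "cat"
                  [0.1, 0.7, 0.2]])  # "sat"

Q = K = V = X  # identity projection for this toy example
scores = Q @ K.T                                    # [3, 3] raw scores
weights = F.softmax(scores / math.sqrt(3), dim=-1)  # attention weights
output = weights @ V                                # weighted sum of values

print(torch.round(weights, decimals=2))
print(torch.round(output, decimals=2))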

Multi-Head Attention

In practice, LLMs use multi-head attention, which runs multiple attention operations in parallel and concatenates the results:

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads=8):
    # W_q, W_k, W_v each hold one [d_model, head_dim] projection matrix per head
    outputs = []
    
    for h in range(num_heads):
        # Different projection matrices for each head
        Q = X @ W_q[h]  # shape: [seq_len, head_dim]
        K = X @ W_k[h]  # shape: [seq_len, head_dim]
        V = X @ W_v[h]  # shape: [seq_len, head_dim]
        
        # Compute scaled dot-product attention as before
        head_dim = K.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        
        outputs.append(head_output)
    
    # Concatenate outputs from all heads and project back to d_model
    return torch.cat(outputs, dim=-1) @ W_o

Multi-head attention allows the model to jointly attend to information from different representation subspaces, capturing different types of relationships between tokens.
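
In practice you rarely hand-roll this loop: PyTorch's built-in torch.nn.MultiheadAttention handles the per-head projections, the parallel attention computations, and the output projection for you. A minimal sketch (the dimensions are illustrative):

import torch
import torch.nn as nn

d_model, num_heads, seq_len, batch = 512, 8, 16, 2

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
x = torch.randn(batch, seq_len, d_model)

# Self-attention: the same tensor serves as query, key, and value
out, attn_weights = mha(x, x, x)
print(out.shape)           # torch.Size([2, 16, 512])
print(attn_weights.shape)  # torch.Size([2, 16, 16]), averaged over heads by default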

Causal (Masked) Attention

In language generation, we use causal attention to ensure the model only attends to previous tokens:

def causal_self_attention(sequence, W_q, W_k, W_v):
    seq_len = sequence.shape[0]
    
    # Create Q, K, V from input sequence
    Q = sequence @ W_q
    K = sequence @ W_k
    V = sequence @ W_v
    
    # Calculate attention scores and scale them
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    
    # Create causal mask (lower triangular)
    mask = torch.tril(torch.ones(seq_len, seq_len))
    
    # Apply mask by setting future positions to -infinity
    scores = scores.masked_fill(mask == 0, float("-inf"))
    
    # Softmax and weighted sum as before
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    
    return output

The causal mask looks like this for a sequence of length 4:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

This ensures that token 2 can only attend to tokens 0, 1, and 2 (itself), but not to token 3, maintaining the autoregressive property during generation.
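
Here's a small sketch of how that mask is built with torch.tril and what it does to one row of scores:

import torch
import torch.nn.functional as F

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))
print(mask)

# Masked positions become -inf, so softmax assigns them exactly zero weight
scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(mask == 0, float("-inf"))
print(F.softmax(scores, dim=-1)[1])  # row 1 attends only to positions 0 and 1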

Attention Visualization

Let's visualize attention to understand how it works in practice:

Input: "The quick brown fox jumps over the lazy dog"

When the model processes the token "lazy" (with causal masking), its attention weights over the sequence might look like:

Token    Attention Weight
The      0.05
quick    0.02
brown    0.01
fox      0.03
jumps    0.10
over     0.15
the      0.60
lazy     0.04
dog      0.00

The token attends heavily to "the", the determiner that introduces "lazy dog", showing how attention captures grammatical relationships. Note that "dog", which comes after "lazy", receives zero weight under the causal mask.
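
If you want to inspect real attention maps rather than hypothetical ones, most Hugging Face transformer models can return them when output_attentions is enabled. A sketch (the checkpoint and the layer/head indices are just illustrative):

import torch
from transformers import AutoModel, AutoTokenizer

name = "gpt2"  # any transformer checkpoint works similarly
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, output_attentions=True)

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions: one [batch, heads, seq, seq] tensor per layer
attn = outputs.attentions[-1][0, 0]  # last layer, first head
print(attn.shape)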

Optimized Attention Implementations

Flash Attention

Standard attention requires O(n²) memory for sequence length n, which becomes problematic for long sequences. Flash Attention addresses this with:

  1. Block-wise computation to leverage GPU memory hierarchy
  2. Recomputation of attention during the backward pass to save memory

# Pseudocode for block-wise Flash Attention
def flash_attention(Q, K, V, block_size=256):
    seq_len = Q.shape[0]
    output = torch.zeros_like(V)
    
    for i in range(0, seq_len, block_size):
        q_block = Q[i:i+block_size]
        
        # Initialize block outputs
        block_output = torch.zeros_like(q_block)
        block_weights_sum = torch.zeros(q_block.shape[0], 1)
        
        for j in range(0, seq_len, block_size):
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            
            # Compute scores for this block
            scores = q_block @ k_block.T / math.sqrt(q_block.shape[1])
            
            # Apply softmax (simplified - in real impl we handle normalization carefully)
            block_weights = torch.exp(scores)
            
            # Update block output
            block_output += block_weights @ v_block
            block_weights_sum += block_weights.sum(dim=1, keepdim=True)
        
        # Normalize block output
        output[i:i+block_size] = block_output / block_weights_sum
    
    return output

Flash Attention can reduce memory usage from O(n²) to O(n), enabling much longer context processing.
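
In PyTorch 2.x you typically don't need to implement this yourself: torch.nn.functional.scaled_dot_product_attention dispatches to fused kernels (including FlashAttention-style ones) when the hardware, dtypes, and mask allow it. A minimal sketch:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Fused scaled dot-product attention with a causal mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])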

KV Caching

When generating text autoregressively, we can cache previously computed key and value projections:

def generate_with_kv_cache(model, prompt, max_tokens=100):
    # Initial forward pass on prompt
    tokens = tokenize(prompt)
    states = model.initial_forward(tokens)
    
    # Initialize KV cache
    kv_cache = states['kv_cache']
    
    generated = list(tokens)
    
    for _ in range(max_tokens):
        # Forward pass with existing KV cache (only compute for the last token)
        logits, new_kv = model.forward_with_cache(
            tokens=generated[-1:],  # Only the last token
            kv_cache=kv_cache
        )
        
        # Update KV cache
        kv_cache = new_kv
        
        # Sample next token
        next_token = sample_token(logits)
        generated.append(next_token)
        
        if next_token == EOS_TOKEN:
            break
    
    return decode(generated)

KV caching dramatically reduces computation during text generation, as we don't need to recompute keys and values for previously processed tokens.
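
Under the hood, the per-layer cache is just the key and value tensors of the tokens processed so far; each decoding step appends one new row. A minimal sketch of that update, assuming 2-D [seq_len, head_dim] tensors and a hypothetical append_to_kv_cache helper:

import torch

def append_to_kv_cache(kv_cache, new_k, new_v):
    """Append the newest token's key/value rows to the cache."""
    if kv_cache is None:
        return new_k, new_v
    cached_k, cached_v = kv_cache
    return (torch.cat([cached_k, new_k], dim=0),
            torch.cat([cached_v, new_v], dim=0))

# At each step, only the new token's projections are computed;
# attention then runs against the full cached K and V.
kv_cache = None
for step in range(3):
    new_k = torch.randn(1, 64)
    new_v = torch.randn(1, 64)
    kv_cache = append_to_kv_cache(kv_cache, new_k, new_v)

print(kv_cache[0].shape)  # torch.Size([3, 64])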

Attention Variants in Modern LLMs

Grouped-Query Attention (GQA)

Grouped-Query Attention reduces computation by sharing key and value heads:

def grouped_query_attention(X, W_q, W_k, W_v, W_o, num_q_heads=8, num_kv_heads=2):
    # Each KV head is shared across multiple Q heads
    q_outputs = []
    
    # Create KV heads (fewer than Q heads)
    K_heads = [X @ W_k[h] for h in range(num_kv_heads)]
    V_heads = [X @ W_v[h] for h in range(num_kv_heads)]
    
    for q_head in range(num_q_heads):
        # Map each Q head to a KV head
        kv_head_idx = q_head % num_kv_heads
        
        Q = X @ W_q[q_head]
        K = K_heads[kv_head_idx]
        V = V_heads[kv_head_idx]
        
        # Standard attention calculation
        head_dim = K.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        
        q_outputs.append(head_output)
    
    return torch.cat(q_outputs, dim=-1) @ W_o

GQA offers a good trade-off between computation cost and model quality and is used in models such as Llama 2 (70B) and Mistral.

Multi-Query Attention (MQA)

Multi-Query Attention takes GQA to the extreme with only one KV head:

def multi_query_attention(X, W_q, W_k, W_v, W_o, num_q_heads=8):
    # Single K/V pair shared across all query heads
    K = X @ W_k  # shape: [seq_len, d_k]
    V = X @ W_v  # shape: [seq_len, d_v]
    
    q_outputs = []
    for h in range(num_q_heads):
        Q = X @ W_q[h]
        
        # Compute attention using the shared K, V
        head_dim = K.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        
        q_outputs.append(head_output)
    
    return torch.cat(q_outputs, dim=-1) @ W_o

MQA further shrinks the KV cache and memory bandwidth requirements, but may sacrifice some quality compared to full multi-head attention.
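
The main practical win of MQA and GQA is a smaller KV cache during inference. A back-of-the-envelope sketch (the layer counts and dimensions below are illustrative, not taken from any particular model):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_param=2):
    # 2x for keys and values, stored per layer in fp16/bf16
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_param

full_mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=4096)
gqa      = kv_cache_bytes(num_layers=32, num_kv_heads=8,  head_dim=128, seq_len=4096)
mqa      = kv_cache_bytes(num_layers=32, num_kv_heads=1,  head_dim=128, seq_len=4096)

print(full_mha / 2**30, gqa / 2**30, mqa / 2**30)  # cache size in GiB per sequence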

Sliding Window Attention

For very long contexts, sliding window attention restricts each token to attend only to its neighborhood:

def sliding_window_attention(X, W_q, W_k, W_v, window_size=1024):
    seq_len = X.shape[0]
    
    # Create Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    
    # Calculate attention scores and scale them
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    
    # Create window mask: token i may only attend to tokens within window_size // 2 of it
    mask = torch.ones(seq_len, seq_len)
    for i in range(seq_len):
        for j in range(seq_len):
            if j < i - window_size // 2 or j > i + window_size // 2:
                mask[i, j] = 0
    
    # Apply mask
    scores = scores.masked_fill(mask == 0, float("-inf"))
    
    # Softmax and weighted sum
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    
    return output

This approach scales linearly with sequence length, enabling much longer context processing.
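
The double loop above is written for clarity; the same band mask can be built in a single vectorized step:

import torch

def sliding_window_mask(seq_len, window_size):
    # mask[i, j] is True when token j lies within window_size // 2 of token i
    idx = torch.arange(seq_len)
    return (idx[None, :] - idx[:, None]).abs() <= window_size // 2

mask = sliding_window_mask(seq_len=8, window_size=4)
print(mask.int())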

When Attention Breaks: Understanding Hallucinations and Failure Modes

While attention mechanisms have revolutionized language modeling, they are not infallible. Understanding how and why attention fails is crucial for building more reliable AI systems. Research has identified several key failure modes that lead to hallucinations and inconsistent outputs.

The Attention Collapse Problem

One of the most documented issues is attention collapse, where the model's attention patterns become overly concentrated on specific tokens, leading to repetitive or nonsensical outputs.

Softmax Saturation

The softmax function in attention can saturate, causing the model to focus almost exclusively on one or two tokens:

# Example of attention saturation
import torch
import torch.nn.functional as F

# Simulated attention scores
scores = torch.tensor([10.0, 2.0, 1.0, 0.5])  # One very high score
weights = F.softmax(scores, dim=0)
print(weights)  # ≈ [0.9995, 0.0003, 0.0001, 0.0001]

# The distribution is almost one-hot: the model effectively ignores everything but one token

This saturation effect can cause:

  1. Repetitive generation: The model gets stuck attending to the same patterns
  2. Loss of context: Important information gets ignored due to over-concentration
  3. Factual inconsistencies: The model may confidently generate false information

Positional Bias and Distance Decay

Research by Khandelwal et al. (2018) in "Sharp Nearby, Fuzzy Far Away" showed that neural language models use nearby context far more sharply than distant context, and transformer attention exhibits a similar positional bias:

def analyze_positional_bias(attention_weights, sequence_length):
    """
    Analyze how attention weights decay with distance
    Based on findings from "Sharp Nearby, Fuzzy Far Away"
    """
    distances = []
    weights = []
    
    for i in range(sequence_length):
        for j in range(sequence_length):
            distance = abs(i - j)
            weight = attention_weights[i, j]
            distances.append(distance)
            weights.append(weight)
    
    # Typically shows exponential decay: closer tokens get more attention
    return distances, weights

This bias leads to several problems:

  1. Long-range dependency failures: Information from early in the context gets progressively ignored
  2. Recency bias: The model over-weights recent tokens, potentially missing crucial earlier context
  3. Inconsistent reasoning: The model may contradict information provided earlier in the conversation
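
One way to make this bias visible is to bucket the (distance, weight) pairs returned by analyze_positional_bias and look at the mean weight per distance. A small sketch with a hypothetical mean_weight_by_distance helper:

from collections import defaultdict

def mean_weight_by_distance(distances, weights):
    """Average attention weight at each query-key distance."""
    buckets = defaultdict(list)
    for d, w in zip(distances, weights):
        buckets[d].append(float(w))
    return {d: sum(ws) / len(ws) for d, ws in sorted(buckets.items())}

# Usage (attention_weights is a [seq_len, seq_len] matrix):
# distances, weights = analyze_positional_bias(attention_weights, seq_len)
# print(mean_weight_by_distance(distances, weights))  # typically decays with distance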

The Hallucination Mechanism

Hallucinations in LLMs often stem from attention mechanism failures. Research by Ji et al. (2023) in "Survey of Hallucination in Natural Language Generation" identifies key attention-related causes:

1. Attention Leakage

When attention weights "leak" to irrelevant tokens, the model can generate content not grounded in the input:

def detect_attention_leakage(attention_weights, relevant_token_indices):
    """
    Detect when attention leaks to irrelevant tokens
    High leakage correlates with hallucination risk
    """
    total_attention = attention_weights.sum()
    relevant_attention = attention_weights[relevant_token_indices].sum()
    leakage_ratio = 1 - (relevant_attention / total_attention)
    
    # Leakage > 0.3 often indicates hallucination risk
    return leakage_ratio

2. Attention Entropy Collapse

Low attention entropy indicates the model is being overly confident about uncertain information:

def attention_entropy(attention_weights):
    """
    Calculate attention entropy - low entropy indicates overconfidence
    Based on "Quantifying Attention Flow in Transformers" (Abnar & Zuidema, 2020)
    """
    # Add small epsilon to avoid log(0)
    epsilon = 1e-10
    attention_weights = attention_weights + epsilon
    
    entropy = -torch.sum(attention_weights * torch.log(attention_weights), dim=-1)
    return entropy

# Low entropy (< 2.0) often correlates with hallucinations

Empirical Evidence from Research

Several studies have documented attention failures:

1. The "Attention is Not Explanation" Debate

Jain & Wallace (2019) showed that attention weights don't always correlate with model reasoning, revealing that:

  • High attention weights don't guarantee the token influenced the output
  • Models can attend to irrelevant tokens while ignoring crucial ones
  • Attention patterns can be manipulated without changing outputs

Link: https://aclanthology.org/N19-1357.pdf

2. Context Length Degradation

Research by Liu et al. (2023) in "Lost in the Middle" demonstrated that:

  • Information in the middle of long contexts is poorly utilized
  • Models exhibit a "U-shaped" attention pattern, focusing on beginnings and ends
  • This leads to hallucinations when crucial information is buried in the middle

def context_utilization_analysis(model_outputs, context_positions):
    """
    Illustrative summary of how well models use information at different context positions
    Based on "Lost in the Middle" findings (the scores below are indicative, not measured)
    """
    utilization_scores = {}
    
    for position in ['beginning', 'middle', 'end']:
        # Models show U-shaped utilization: strong at beginning/end, weak in middle
        if position == 'beginning':
            score = 0.85  # High utilization
        elif position == 'middle':
            score = 0.45  # Poor utilization - "lost in the middle"
        else:  # end
            score = 0.80  # High utilization
            
        utilization_scores[position] = score
    
    return utilization_scores

Link: https://cs.stanford.edu/~nfliu/papers/lost-in-the-middle.tacl2023.pdf

3. Factual Inconsistency Patterns

Mallen et al. (2022), in "When Not to Trust Language Models", studied when models' parametric knowledge cannot be trusted for factual queries. Attention signals can be combined into a rough confidence heuristic:

def factual_confidence_analysis(attention_patterns, known_facts):
    """
    Heuristic link between attention patterns and factual accuracy
    (the thresholds below are illustrative)
    """
    # Signals worth monitoring:
    #   - high-entropy attention: indicates uncertainty
    #   - consistent attention to supporting facts: good sign
    #   - attention to contradictory information: bad sign
    
    # High entropy + attention to contradictory info = low factual confidence
    if attention_patterns['entropy'] > 3.0 and attention_patterns['contradiction_score'] > 0.2:
        return 'high_hallucination_risk'
    
    return 'normal_confidence'

Mitigation Strategies

Recent research has proposed several approaches to address attention failures:

1. Attention Regularization

Regularizing attention to prevent collapse:

def attention_entropy_loss(attention_weights, target_entropy=3.0):
    """
    Regularization loss to maintain healthy attention entropy
    Based on "Regularizing Attention Networks for Anomaly Detection" (Pang et al., 2021)
    """
    current_entropy = attention_entropy(attention_weights)
    entropy_loss = F.mse_loss(current_entropy, torch.full_like(current_entropy, target_entropy))
    return entropy_loss
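
During training, this term is simply added to the task loss with a small coefficient. A sketch with dummy tensors standing in for a real forward pass (the 0.01 weight is illustrative):

import torch

# Dummy stand-ins for a real forward pass
task_loss = torch.tensor(1.25, requires_grad=True)
attention_weights = torch.softmax(torch.randn(4, 16, 16, requires_grad=True), dim=-1)

reg_weight = 0.01  # illustrative coefficient
loss = task_loss + reg_weight * attention_entropy_loss(attention_weights, target_entropy=3.0)
loss.backward()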

2. Attention Supervision

Training attention to focus on relevant information:

def supervised_attention_loss(predicted_attention, ground_truth_attention):
    """
    Supervise attention patterns using human annotations
    Based on "Attention Supervision for Fair and Explainable Models" (Barrett et al., 2021)
    """
    return F.kl_div(
        F.log_softmax(predicted_attention, dim=-1),
        F.softmax(ground_truth_attention, dim=-1),
        reduction='batchmean'
    )

3. Multi-Scale Attention

Combining different attention scales to improve robustness:

def multi_scale_attention(x, scales=[1, 2, 4]):
    """
    Combine attention at different scales to improve robustness
    Based on "Multi-Scale Attention for Neural Machine Translation" (Xu et al., 2020)
    """
    outputs = []
    
    for scale in scales:
        # Apply attention at different granularities
        # (scaled_self_attention is a placeholder for an attention function
        #  that operates on the sequence at granularity `scale`)
        scaled_attention = scaled_self_attention(x, scale=scale)
        outputs.append(scaled_attention)
    
    # Combine multi-scale outputs
    return torch.cat(outputs, dim=-1)

Detection and Monitoring

Practitioners can monitor attention health using several metrics:

def attention_health_metrics(attention_weights):
    """
    Comprehensive attention health monitoring
    """
    metrics = {
        'entropy': attention_entropy(attention_weights).mean(),
        'max_weight': attention_weights.max(),
        'weight_variance': attention_weights.var(),
        'sparsity': (attention_weights < 0.01).float().mean(),
    }
    
    # Health thresholds based on empirical research
    health_score = 1.0
    if metrics['entropy'] < 2.0:  # Too focused
        health_score -= 0.3
    if metrics['max_weight'] > 0.8:  # Over-concentration
        health_score -= 0.3
    if metrics['sparsity'] > 0.9:  # Too sparse
        health_score -= 0.2
        
    return metrics, health_score

Understanding these failure modes is crucial for building more reliable AI systems. While attention mechanisms are powerful, they require careful monitoring and mitigation strategies to prevent hallucinations and ensure consistent, factual outputs.

Practical Implementation Tips

Memory Optimization

  1. Gradient checkpointing: Trade computation for memory by recomputing activations during backpropagation

    # Using PyTorch's gradient checkpointing
    output = torch.utils.checkpoint.checkpoint(attention_fn, query, key, value)
    
  2. Mixed precision: Using FP16 or BF16 drastically reduces memory footprint

    # Using PyTorch's automatic mixed precision
    with torch.cuda.amp.autocast():
        output = self_attention(x)
    
  3. Attention chunking: Process attention in chunks when sequence length is large

    def chunked_attention(q, k, v, chunk_size=1024):
        # q, k, v: [batch, seq_len, head_dim]
        outputs = []
        for i in range(0, q.size(1), chunk_size):
            chunk_q = q[:, i:i+chunk_size]
            
            # Initialize block outputs
            block_output = torch.zeros_like(chunk_q)
            block_weights_sum = torch.zeros(chunk_q.shape[0], chunk_q.shape[1], 1)
            
            for j in range(0, q.size(1), chunk_size):
                k_block = k[:, j:j+chunk_size]
                v_block = v[:, j:j+chunk_size]
                
                # Compute scores for this block, scaled by the head dimension
                scores = chunk_q @ k_block.transpose(-2, -1) / math.sqrt(chunk_q.shape[-1])
                
                # Apply softmax (simplified - a real implementation tracks a running max for stability)
                block_weights = torch.exp(scores)
                
                # Update block output
                block_output += block_weights @ v_block
                block_weights_sum += block_weights.sum(dim=-1, keepdim=True)
            
            # Normalize block output
            outputs.append(block_output / block_weights_sum)
        
        # Combine chunked outputs
        return torch.cat(outputs, dim=1)
    

Performance Tuning

  1. Fused kernels: Use optimized CUDA kernels for attention

    # Using xformers' memory-efficient attention
    from xformers.ops import memory_efficient_attention
    output = memory_efficient_attention(q, k, v, attn_bias=None)
    
  2. Optimize for inference speed with techniques like KV caching and batch processing

  3. Flash Attention 2: The latest iteration of Flash Attention makes attention even faster

    from flash_attn import flash_attn_qkvpacked_func
    # Pack QKV into shape [batch, seqlen, 3, nheads, headdim] and use the optimized kernel
    qkv = torch.stack([q, k, v], dim=2)
    output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0)
    

Learn more about how Matter AI helps improve code quality across multiple languages in Pull Requests: https://docs.matterai.so/product/code-quality

Are you looking for a way to improve your code review process? Learn more about how Matter AI helps teams solve code review challenges with AI: https://matterai.so
