Understanding Attention: Coherency in LLMs

Attention in LLMs: The Core Mechanism Behind Modern AI
Attention mechanisms form the backbone of modern Large Language Models (LLMs), enabling them to process and generate coherent text across long contexts. This blog post breaks down the attention mechanism for engineers who want to understand how these models work under the hood.
What is Attention?
At its core, attention allows a model to focus on relevant parts of the input sequence when producing an output. Unlike traditional RNNs that process sequences linearly, attention gives the model the ability to "look back" at any part of the input, weighing the importance of each token when generating the next one.
The fundamental question attention answers is: "Which parts of the input should I focus on to generate the current output?"
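At a glance, the whole mechanism is just a softmax-weighted average. The toy snippet below (with made-up relevance scores and value vectors) shows that core computation in isolation:

import torch

# Hypothetical relevance scores of three input tokens for the token being generated
relevance = torch.tensor([2.0, 0.5, 1.0])
weights = torch.softmax(relevance, dim=0)   # ≈ [0.63, 0.14, 0.23], sums to 1

# Value vectors carried by those three tokens (illustrative numbers)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [0.5, 0.5]])

# The output is a weighted average of the values, dominated by the most relevant token
output = weights @ values
print(output)   # ≈ [0.74, 0.26]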
Self-Attention: The Basic Building Block
Self-attention, specifically the "Scaled Dot-Product Attention" introduced in the "Attention Is All You Need" paper, is the fundamental operation in transformer-based LLMs.
Here's how it works:
1. Each input token is transformed into three vectors:
   - Query (Q): What the token is "looking for"
   - Key (K): What the token "offers" to others
   - Value (V): The actual information the token contains
2. The attention score between two tokens is computed as the dot product of one token's query with the other token's key
3. These scores are scaled, passed through a softmax, and used to create weighted sums of the values
The Math (Simplified)
import math
import torch
import torch.nn.functional as F

def self_attention(sequence):
    # W_q, W_k, W_v are learned projection matrices (assumed defined elsewhere)
    # Create Q, K, V from the input sequence
    Q = sequence @ W_q  # shape: [seq_len, d_k]
    K = sequence @ W_k  # shape: [seq_len, d_k]
    V = sequence @ W_v  # shape: [seq_len, d_v]

    # Calculate raw attention scores between every pair of positions
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]

    # Scale by sqrt(d_k) to keep the dot products in a numerically friendly range
    d_k = Q.shape[-1]
    scores = scores / math.sqrt(d_k)

    # Softmax turns each row of scores into attention weights that sum to 1
    weights = F.softmax(scores, dim=-1)

    # Each output position is a weighted sum of the value vectors
    output = weights @ V  # shape: [seq_len, d_v]
    return output
A Concrete Example
Let's see self-attention in action with a tiny example:
Input sequence: ["The", "cat", "sat"]
1. Convert tokens to embeddings (simplified):

"The" → [0.2, 0.3, 0.1]
"cat" → [0.5, 0.2, 0.4]
"sat" → [0.1, 0.7, 0.2]

2. Project to Q, K, V (simplified with an identity projection):

Q = K = V = [[0.2, 0.3, 0.1],
             [0.5, 0.2, 0.4],
             [0.1, 0.7, 0.2]]

3. Calculate attention scores:

scores = Q @ K.T =
[[0.14, 0.20, 0.25],
 [0.20, 0.45, 0.27],
 [0.25, 0.27, 0.54]]

4. Scale and softmax to get weights:

weights = softmax(scores / sqrt(3)) =
[[0.32, 0.33, 0.34],
 [0.31, 0.36, 0.33],
 [0.31, 0.32, 0.37]]

5. Compute the weighted sum of values:

output = weights @ V =
[[0.27, 0.40, 0.23],
 [0.28, 0.39, 0.24],
 [0.26, 0.42, 0.23]]

(All values rounded to two decimal places.)
The output shows how each token now incorporates information from all other tokens, weighted by relevance.
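You can verify these numbers yourself; the whole example fits in a few lines of PyTorch (outputs are rounded to two decimal places, matching the matrices above):

import math
import torch
import torch.nn.functional as F

# Embeddings for ["The", "cat", "sat"]; with an identity projection, Q = K = V = E
E = torch.tensor([[0.2, 0.3, 0.1],
                  [0.5, 0.2, 0.4],
                  [0.1, 0.7, 0.2]])

scores = E @ E.T
weights = F.softmax(scores / math.sqrt(3), dim=-1)
output = weights @ E

print(weights.round(decimals=2))   # ≈ [[0.32, 0.33, 0.34], [0.31, 0.36, 0.33], [0.31, 0.32, 0.37]]
print(output.round(decimals=2))    # ≈ [[0.27, 0.40, 0.23], [0.28, 0.39, 0.24], [0.26, 0.42, 0.23]]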
Multi-Head Attention
In practice, LLMs use multi-head attention, which runs multiple attention operations in parallel and concatenates the results:
def multi_head_attention(X, num_heads=8):
    # W_q, W_k, W_v are lists of per-head projection matrices, W_o is the output
    # projection, and d_model is the model width (all assumed defined elsewhere)
    head_dim = d_model // num_heads
    outputs = []

    for h in range(num_heads):
        # Each head has its own projection matrices
        Q = X @ W_q[h]  # shape: [seq_len, head_dim]
        K = X @ W_k[h]  # shape: [seq_len, head_dim]
        V = X @ W_v[h]  # shape: [seq_len, head_dim]

        # Scaled dot-product attention, as before
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        outputs.append(head_output)

    # Concatenate the heads and project back to d_model
    return torch.cat(outputs, dim=-1) @ W_o
Multi-head attention allows the model to jointly attend to information from different representation subspaces, capturing different types of relationships between tokens.
Causal (Masked) Attention
In language generation, we use causal attention to ensure the model only attends to previous tokens:
def causal_self_attention(sequence):
    # W_q, W_k, W_v are learned projection matrices and d_k is the key dimension
    # (assumed defined elsewhere)
    seq_len = sequence.shape[0]

    # Create Q, K, V from the input sequence
    Q = sequence @ W_q
    K = sequence @ W_k
    V = sequence @ W_v

    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]

    # Scale the scores
    scores = scores / math.sqrt(d_k)

    # Create the causal mask (lower triangular) and hide future positions
    # by setting their scores to a large negative value
    mask = torch.tril(torch.ones(seq_len, seq_len))
    scores = scores.masked_fill(mask == 0, -1e10)

    # Softmax and weighted sum as before
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output
The causal mask looks like this for a sequence of length 4:
[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
This ensures that token 2 can only attend to tokens 0, 1, and 2 (itself), but not to token 3, maintaining the autoregressive property during generation.
Attention Visualization
Let's visualize attention to understand how it works in practice:
Input: "The quick brown fox jumps over the lazy dog"
When the model processes the token "lazy" (with a causal mask, so the later token "dog" receives no weight), the attention weights for that position might look like:
| Token | Attention Weight |
|-------|------------------|
| The   | 0.05 |
| quick | 0.02 |
| brown | 0.01 |
| fox   | 0.03 |
| jumps | 0.10 |
| over  | 0.15 |
| the   | 0.60 |
| lazy  | 0.04 |
| dog   | 0.00 |
The model attends heavily to "the" as it's the determiner for "lazy dog", showing how attention captures grammatical relationships.
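Attention maps like this can be extracted from any open model. With the Hugging Face transformers library, passing output_attentions=True returns the per-layer weights; a minimal sketch (the choice of model and the decision to average over heads are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # small enough to run on CPU; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions: one [batch, num_heads, seq_len, seq_len] tensor per layer
last_layer = outputs.attentions[-1][0]    # [num_heads, seq_len, seq_len]
avg_over_heads = last_layer.mean(dim=0)   # [seq_len, seq_len]

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, weight in zip(tokens, avg_over_heads[-1]):  # attention from the final position
    print(f"{token:>10}  {weight.item():.3f}")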
Optimized Attention Implementations
Flash Attention
Standard attention requires O(n²) memory for sequence length n, which becomes problematic for long sequences. Flash Attention addresses this with:
- Block-wise computation to leverage GPU memory hierarchy
- Recomputation of attention during the backward pass to save memory
# Pseudocode for block-wise Flash Attention
def flash_attention(Q, K, V, block_size=256):
    seq_len, d_k = Q.shape
    output = torch.zeros_like(V)  # attention output has the same shape as V

    for i in range(0, seq_len, block_size):
        q_block = Q[i:i+block_size]

        # Running (unnormalized) output and normalizer for this block of queries
        block_output = torch.zeros(q_block.shape[0], V.shape[1])
        block_weights_sum = torch.zeros(q_block.shape[0], 1)

        for j in range(0, seq_len, block_size):
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]

            # Scores between this query block and key block
            scores = q_block @ k_block.T / math.sqrt(d_k)

            # Simplified softmax: the real algorithm keeps a running row-wise max
            # (online softmax) so that exp() stays numerically stable
            block_weights = torch.exp(scores)

            # Accumulate unnormalized weighted values and the softmax denominator
            block_output += block_weights @ v_block
            block_weights_sum += block_weights.sum(dim=1, keepdim=True)

        # Normalize once every key/value block has been processed
        output[i:i+block_size] = block_output / block_weights_sum

    return output
Flash Attention can reduce memory usage from O(n²) to O(n), enabling much longer context processing.
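In practice you rarely hand-roll this: PyTorch 2.x ships fused kernels (including Flash Attention on supported GPUs) behind torch.nn.functional.scaled_dot_product_attention, which picks the fastest available backend automatically. A minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 8, 4096, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Dispatches to a Flash Attention / memory-efficient kernel when available,
# otherwise falls back to the standard math implementation
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 4096, 64])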
KV Caching
When generating text autoregressively, we can cache previously computed key and value projections:
def generate_with_kv_cache(model, prompt, max_tokens=100):
    # Initial forward pass over the full prompt
    tokens = tokenize(prompt)
    states = model.initial_forward(tokens)

    # Initialize the KV cache from that pass
    kv_cache = states['kv_cache']
    generated = list(tokens)

    for _ in range(max_tokens):
        # Forward pass with the existing KV cache (only compute for the last token)
        logits, new_kv = model.forward_with_cache(
            tokens=generated[-1:],  # only the newest token
            kv_cache=kv_cache
        )
        # Update the KV cache
        kv_cache = new_kv

        # Sample the next token
        next_token = sample_token(logits)
        generated.append(next_token)
        if next_token == EOS_TOKEN:
            break

    return decode(generated)
KV caching dramatically reduces computation during text generation, as we don't need to recompute keys and values for previously processed tokens.
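Conceptually, the cache for each layer is just the keys and values seen so far, and each decoding step appends one more row. A minimal per-layer sketch (names and shapes are illustrative):

import torch

def append_to_kv_cache(kv_cache, new_k, new_v):
    """kv_cache is a (K, V) pair of [cached_len, d_k] tensors, or None on the first step."""
    if kv_cache is None:
        return new_k, new_v
    cached_k, cached_v = kv_cache
    # Append along the sequence dimension; nothing already cached is ever recomputed
    return (torch.cat([cached_k, new_k], dim=0),
            torch.cat([cached_v, new_v], dim=0))

# At each step, only the newest token's hidden state is projected to new_k / new_v;
# the query for that token then attends over the full cached K and V.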
Attention Variants in Modern LLMs
Grouped-Query Attention (GQA)
Grouped-Query Attention reduces computation by sharing key and value heads:
def grouped_query_attention(X, num_q_heads=8, num_kv_heads=2):
    # W_q, W_k, W_v are lists of per-head projection matrices and W_o is the output
    # projection (assumed defined elsewhere); each KV head is shared by several Q heads
    q_outputs = []

    # Create the KV heads (fewer of them than Q heads)
    K_heads = [X @ W_k[h] for h in range(num_kv_heads)]
    V_heads = [X @ W_v[h] for h in range(num_kv_heads)]

    for q_head in range(num_q_heads):
        # Map each Q head to one of the shared KV heads
        kv_head_idx = q_head % num_kv_heads
        Q = X @ W_q[q_head]
        K = K_heads[kv_head_idx]
        V = V_heads[kv_head_idx]

        # Standard scaled dot-product attention
        head_dim = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        q_outputs.append(head_output)

    return torch.cat(q_outputs, dim=-1) @ W_o
GQA offers a good trade-off between inference cost and model quality, and is used in models such as Llama 2 70B and Mistral 7B.
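Vectorized implementations usually expand the shared KV heads to match the number of query heads (for example with repeat_interleave) so that a single batched attention call covers all heads. A minimal sketch with illustrative shapes:

import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 1, 16, 64
num_q_heads, num_kv_heads = 8, 2
group_size = num_q_heads // num_kv_heads  # query heads sharing each KV head

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Expand each KV head so every query head has a matching K and V
# (no new parameters; only the small KV tensors need to be stored and cached)
k_expanded = k.repeat_interleave(group_size, dim=1)  # [batch, num_q_heads, seq_len, head_dim]
v_expanded = v.repeat_interleave(group_size, dim=1)

out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])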
Multi-Query Attention (MQA)
Multi-Query Attention takes GQA to the extreme with only one KV head:
def multi_query_attention(X, num_q_heads=8):
    # A single K/V pair is shared across all query heads; W_q is a list of per-head
    # matrices, while W_k, W_v, W_o are single matrices (assumed defined elsewhere)
    K = X @ W_k  # shape: [seq_len, d_k]
    V = X @ W_v  # shape: [seq_len, d_v]

    q_outputs = []
    for h in range(num_q_heads):
        Q = X @ W_q[h]

        # Compute attention against the shared K, V
        head_dim = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        q_outputs.append(head_output)

    return torch.cat(q_outputs, dim=-1) @ W_o
MQA further reduces computation and memory but may sacrifice some performance compared to full multi-head attention.
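The main practical payoff of GQA and MQA is a smaller KV cache at generation time, since the cache scales with the number of KV heads rather than query heads. A quick back-of-the-envelope comparison (the model configuration is hypothetical):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    # Keys and values (hence the factor of 2), stored per layer for every position, in fp16
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

# Hypothetical 32-layer model with 32 query heads of dimension 128 at an 8K context
for name, kv_heads in [("MHA", 32), ("GQA, 8 KV heads", 8), ("MQA", 1)]:
    gigabytes = kv_cache_bytes(num_layers=32, num_kv_heads=kv_heads, head_dim=128, seq_len=8192) / 1e9
    print(f"{name:>16}: {gigabytes:.2f} GB per sequence")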
Sliding Window Attention
For very long contexts, sliding window attention restricts each token to attend only to its neighborhood:
def sliding_window_attention(X, window_size=1024):
    # W_q, W_k, W_v are learned projection matrices and d_k is the key dimension
    # (assumed defined elsewhere)
    seq_len = X.shape[0]

    # Create Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v

    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)

    # Band mask: position i may only attend to positions within window_size // 2 of i
    positions = torch.arange(seq_len)
    distance = (positions[:, None] - positions[None, :]).abs()
    mask = distance <= window_size // 2

    # Mask out everything outside the window
    scores = scores.masked_fill(~mask, -1e10)

    # Scale, softmax, and weighted sum
    scores = scores / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output
Because each position attends to a fixed-size window, the cost grows linearly with sequence length (when implemented with a banded kernel rather than the full mask shown above), enabling much longer context processing; Mistral 7B, for example, uses sliding window attention.
When Attention Breaks: Understanding Hallucinations and Failure Modes
While attention mechanisms have revolutionized language modeling, they are not infallible. Understanding how and why attention fails is crucial for building more reliable AI systems. Research has identified several key failure modes that lead to hallucinations and inconsistent outputs.
The Attention Collapse Problem
One of the most documented issues is attention collapse, where the model's attention patterns become overly concentrated on specific tokens, leading to repetitive or nonsensical outputs.
Softmax Saturation
The softmax function in attention can saturate, causing the model to focus almost exclusively on one or two tokens:
# Example of attention saturation
import torch
import torch.nn.functional as F

# Simulated attention scores with one dominant score
scores = torch.tensor([10.0, 2.0, 1.0, 0.5])
weights = F.softmax(scores, dim=0)
print(weights)  # ≈ [0.9995, 0.0003, 0.0001, 0.0001]
# The attention distribution is almost one-hot, so the model loses diversity
This is a basic property of the softmax: once one score dominates, the resulting distribution is nearly one-hot. In a language model, that kind of saturation can cause:
- Repetitive generation: The model gets stuck attending to the same patterns
- Loss of context: Important information gets ignored due to over-concentration
- Factual inconsistencies: The model may confidently generate false information
Positional Bias and Distance Decay
Khandelwal et al. (2018), in "Sharp Nearby, Fuzzy Far Away", showed that neural language models use nearby context far more effectively than distant context; attention-based models exhibit a similar positional bias:
def analyze_positional_bias(attention_weights, sequence_length):
    """
    Analyze how attention weights decay with distance.
    Motivated by the findings of "Sharp Nearby, Fuzzy Far Away".
    """
    distances = []
    weights = []
    for i in range(sequence_length):
        for j in range(sequence_length):
            distances.append(abs(i - j))
            weights.append(attention_weights[i, j])
    # Plotting weight against distance typically shows decay: closer tokens get more attention
    return distances, weights
This bias leads to several problems:
- Long-range dependency failures: Information from early in the context gets progressively ignored
- Recency bias: The model over-weights recent tokens, potentially missing crucial earlier context
- Inconsistent reasoning: The model may contradict information provided earlier in the conversation
The Hallucination Mechanism
Hallucinations in LLMs often stem from failures of the attention mechanism. Ji et al. (2023) survey the broader causes in "Survey of Hallucination in Natural Language Generation"; two attention-related patterns are worth highlighting:
1. Attention Leakage
When attention weights "leak" to irrelevant tokens, the model can generate content not grounded in the input:
def detect_attention_leakage(attention_weights, relevant_token_indices):
    """
    Detect when attention "leaks" to irrelevant tokens.
    High leakage correlates with hallucination risk.
    """
    total_attention = attention_weights.sum()
    relevant_attention = attention_weights[relevant_token_indices].sum()
    leakage_ratio = 1 - (relevant_attention / total_attention)
    # Rule of thumb: leakage above ~0.3 often indicates elevated hallucination risk
    return leakage_ratio
2. Attention Entropy Collapse
Low attention entropy indicates the model is being overly confident about uncertain information:
def attention_entropy(attention_weights):
    """
    Calculate attention entropy; low entropy indicates an overly confident, peaked distribution.
    See "Quantifying Attention Flow in Transformers" (Abnar & Zuidema, 2020) for related analysis.
    """
    # Add a small epsilon to avoid log(0)
    epsilon = 1e-10
    attention_weights = attention_weights + epsilon
    entropy = -torch.sum(attention_weights * torch.log(attention_weights), dim=-1)
    return entropy

# Low entropy (< 2.0) often correlates with hallucinations
Empirical Evidence from Research
Several studies have documented attention failures:
1. The "Attention is Not Explanation" Debate
Jain & Wallace (2019) showed that attention weights don't always correlate with model reasoning, revealing that:
- High attention weights don't guarantee the token influenced the output
- Models can attend to irrelevant tokens while ignoring crucial ones
- Attention patterns can be manipulated without changing outputs
Link: https://aclanthology.org/N19-1357.pdf
2. Context Length Degradation
Research by Liu et al. (2023) in "Lost in the Middle" demonstrated that models use long contexts unevenly:
def context_utilization_analysis(model_outputs, context_positions):
    """
    Analyze how well models use information at different context positions.
    The scores below are illustrative of the U-shaped pattern reported in
    "Lost in the Middle", not values taken from the paper.
    """
    utilization_scores = {}
    for position in ['beginning', 'middle', 'end']:
        # Models show U-shaped utilization: strong at the beginning and end, weak in the middle
        if position == 'beginning':
            score = 0.85  # high utilization
        elif position == 'middle':
            score = 0.45  # poor utilization: "lost in the middle"
        else:  # end
            score = 0.80  # high utilization
        utilization_scores[position] = score
    return utilization_scores
- Information in the middle of long contexts is poorly utilized
- Models exhibit a "U-shaped" attention pattern, focusing on beginnings and ends
- This leads to hallucinations when crucial information is buried in the middle
Link: https://cs.stanford.edu/~nfliu/papers/lost-in-the-middle.tacl2023.pdf
3. Factual Inconsistency Patterns
Mallen et al. (2022), in "When Not to Trust Language Models", studied when models' parametric knowledge becomes unreliable; the sketch below shows how attention signals could feed into that kind of confidence estimate:
def factual_confidence_analysis(attention_patterns, known_facts):
    """
    Sketch of combining attention signals into a factual-confidence flag.
    The indicator names and thresholds are illustrative heuristics.
    """
    confidence_indicators = {
        'high_entropy_attention': False,      # indicates uncertainty
        'consistent_fact_attention': False,   # attends to supporting facts
        'contradiction_attention': False,     # attends to contradictory info
    }
    # High entropy combined with attention to contradictory info suggests low factual confidence
    if attention_patterns['entropy'] > 3.0 and attention_patterns['contradiction_score'] > 0.2:
        return 'high_hallucination_risk'
    return 'normal_confidence'
Mitigation Strategies
Recent research has proposed several approaches to address attention failures:
1. Attention Regularization
Regularizing attention to prevent collapse:
def attention_entropy_loss(attention_weights, target_entropy=3.0):
    """
    Regularization loss to maintain healthy attention entropy.
    Based on "Regularizing Attention Networks for Anomaly Detection" (Pang et al., 2021)
    """
    current_entropy = attention_entropy(attention_weights)
    entropy_loss = F.mse_loss(current_entropy, torch.full_like(current_entropy, target_entropy))
    return entropy_loss
2. Attention Supervision
Training attention to focus on relevant information:
def supervised_attention_loss(predicted_attention, ground_truth_attention):
    """
    Supervise attention patterns using human annotations.
    Based on "Attention Supervision for Fair and Explainable Models" (Barrett et al., 2021)
    """
    return F.kl_div(
        F.log_softmax(predicted_attention, dim=-1),
        F.softmax(ground_truth_attention, dim=-1),
        reduction='batchmean'
    )
3. Multi-Scale Attention
Combining different attention scales to improve robustness:
def multi_scale_attention(x, scales=[1, 2, 4]):
    """
    Combine attention at different scales to improve robustness.
    Based on "Multi-Scale Attention for Neural Machine Translation" (Xu et al., 2020).
    scaled_self_attention is assumed to be a helper that applies self-attention
    at the given granularity (e.g. after pooling or striding by `scale`).
    """
    outputs = []
    for scale in scales:
        # Apply attention at a coarser granularity for larger scales
        outputs.append(scaled_self_attention(x, scale=scale))
    # Combine the multi-scale outputs along the feature dimension
    return torch.cat(outputs, dim=-1)
Detection and Monitoring
Practitioners can monitor attention health using several metrics:
def attention_health_metrics(attention_weights):
    """
    Comprehensive attention health monitoring.
    """
    metrics = {
        'entropy': attention_entropy(attention_weights).mean(),
        'max_weight': attention_weights.max(),
        'weight_variance': attention_weights.var(),
        'sparsity': (attention_weights < 0.01).float().mean(),
    }

    # Illustrative thresholds (rules of thumb rather than values from a specific paper)
    health_score = 1.0
    if metrics['entropy'] < 2.0:      # too focused
        health_score -= 0.3
    if metrics['max_weight'] > 0.8:   # over-concentration on a single token
        health_score -= 0.3
    if metrics['sparsity'] > 0.9:     # too sparse
        health_score -= 0.2
    return metrics, health_score
Understanding these failure modes is crucial for building more reliable AI systems. While attention mechanisms are powerful, they require careful monitoring and mitigation strategies to prevent hallucinations and ensure consistent, factual outputs.
Practical Implementation Tips
Memory Optimization
- Gradient checkpointing: Trade computation for memory by recomputing activations during backpropagation

  # Using PyTorch's gradient checkpointing
  output = torch.utils.checkpoint.checkpoint(attention_fn, query, key, value)

- Mixed precision: Using FP16 or BF16 drastically reduces the memory footprint

  # Using PyTorch's automatic mixed precision
  with torch.cuda.amp.autocast():
      output = self_attention(x)

- Attention chunking: Process attention in chunks when the sequence length is large

  def chunked_attention(q, k, v, chunk_size=1024):
      # q, k, v have shape [batch, seq_len, dim]
      outputs = []
      for i in range(0, q.size(1), chunk_size):
          chunk_q = q[:, i:i+chunk_size]

          # Running (unnormalized) output and normalizer for this query chunk
          block_output = torch.zeros_like(chunk_q)
          block_weights_sum = torch.zeros(q.size(0), chunk_q.size(1), 1)

          for j in range(0, q.size(1), chunk_size):
              k_block = k[:, j:j+chunk_size]
              v_block = v[:, j:j+chunk_size]

              # Scores between this query chunk and key chunk
              scores = chunk_q @ k_block.transpose(-2, -1) / math.sqrt(q.size(-1))

              # Simplified softmax (a real implementation tracks a running max for stability)
              block_weights = torch.exp(scores)

              block_output += block_weights @ v_block
              block_weights_sum += block_weights.sum(dim=-1, keepdim=True)

          # Normalize and collect this chunk's output
          outputs.append(block_output / block_weights_sum)

      # Combine the chunked outputs along the sequence dimension
      return torch.cat(outputs, dim=1)
Performance Tuning
- Fused kernels: Use optimized CUDA kernels for attention

  # Using xformers' memory-efficient attention
  from xformers.ops import memory_efficient_attention
  output = memory_efficient_attention(q, k, v, attn_bias=None)

- Inference speed: Optimize generation with techniques like KV caching and batch processing

- Flash Attention 2: The latest kernels make attention even faster

  from flash_attn import flash_attn_qkvpacked_func

  # Pack Q, K, V along a new dimension and call the fused kernel
  # (qkv shape: [batch, seq_len, 3, num_heads, head_dim])
  qkv = torch.stack([q, k, v], dim=2)
  output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0)
Learn more about how Matter AI helps improve code quality across multiple languages in Pull Requests: https://docs.matterai.so/product/code-quality
Are you looking for a way to improve your code review process? Learn more about how Matter AI helps teams solve code review challenges with AI: https://matterai.so