LLM

Context Window

Understanding LLM Context Window and Working

Vatsal Bajpai
Vatsal Bajpai
11 min read·
Cover Image for Understanding LLM Context Window and Working

How Do LLMs Handle Context Windows?

Context windows are a fundamental aspect of Large Language Models (LLMs), determining how much information they can "see" and utilize when generating responses. Understanding context windows is crucial for developers using LLMs effectively in production. This technical deep dive explores the mechanics, limitations, and optimization strategies for context windows in modern LLMs.

What Is a Context Window?

The context window defines the maximum amount of text (measured in tokens) that an LLM can process at once. This includes:

  1. The prompt/instructions provided by the user
  2. Any additional context/documents supplied
  3. The conversation history (for chatbots)
  4. The model's generated output

For example, if a model has a 32K token context window, the combined length of all the above elements cannot exceed 32,000 tokens (roughly 24,000 words of English text).

Context Window Sizes by Model

Context window sizes have expanded dramatically in recent years:

Model Release Date Context Window
GPT-3 2020 2,048 tokens
GPT-3.5 (ChatGPT) 2022 4,096 tokens
Claude 1 2023 9,000 tokens
GPT-4 2023 8,192 → 32,768 tokens
Claude 2 2023 100,000 tokens
Claude 3 Opus 2024 200,000 tokens
Claude 3.5 Sonnet 2024 180,000 tokens
GPT-4o 2024 128,000 tokens
LLaMA-3 2024 8,192 tokens

The Technical Implementation of Context Windows

Attention Mechanisms and Sequence Length

Context windows are directly tied to the attention mechanism in transformer-based models. The standard attention operation has quadratic computational complexity with respect to sequence length:

Attention Complexity = O(n²d)

Where:

  • n = sequence length (context size in tokens)
  • d = embedding dimension

This creates significant computational challenges as context grows:

Context Length Relative Compute Memory Usage
2K tokens 1x ~0.5 GB
8K tokens 16x ~8 GB
32K tokens 256x ~128 GB
100K tokens 2,500x ~1.2 TB

Memory-Efficient Attention Patterns

To overcome these limitations, several approaches have been developed:

1. Sparse Attention

Instead of attending to all previous tokens, selective patterns reduce computation:

# Simplified implementation of local attention window
def local_attention(query, key, value, window_size=256):
    seq_len = query.shape[1]
    attention_scores = torch.zeros((seq_len, seq_len))
    
    for i in range(seq_len):
        start_idx = max(0, i - window_size)
        attention_scores[i, start_idx:i+1] = torch.matmul(
            query[:, i:i+1, :], 
            key[:, start_idx:i+1, :].transpose(-1, -2)
        )
    
    # Apply softmax and matrix multiplication with values
    attention_weights = F.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    
    return output

2. Linear Attention

Reformulating attention to achieve linear complexity:

def linear_attention(query, key, value):
    # Feature maps that enable linear attention
    q_prime = torch.nn.functional.elu(query) + 1
    k_prime = torch.nn.functional.elu(key) + 1
    
    # Linear attention computation
    kv = torch.matmul(k_prime.transpose(-2, -1), value)
    qkv = torch.matmul(q_prime, kv)
    
    # Normalization factor
    z = torch.matmul(q_prime, k_prime.sum(dim=1, keepdim=True))
    
    return qkv / z

3. Sliding Window Attention

Used in models like LongFormer and BigBird:

def sliding_window_attention(query, key, value, window_size=512):
    batch_size, seq_len, head_dim = query.shape
    
    # Create attention mask for sliding window
    attention_mask = torch.zeros((seq_len, seq_len))
    for i in range(seq_len):
        window_start = max(0, i - window_size // 2)
        window_end = min(seq_len, i + window_size // 2)
        attention_mask[i, window_start:window_end] = 1
    
    # Compute attention scores and mask them
    attention_scores = torch.matmul(query, key.transpose(-1, -2))
    attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e10)
    
    # Apply softmax and compute weighted values
    attention_weights = F.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    
    return output

4. Chunked Attention (Block-Recurrent)

Processing sequences in manageable chunks while maintaining state:

def chunked_attention(sequence, chunk_size=1024, model=None):
    """Process a long sequence by chunks while preserving state."""
    chunks = [sequence[i:i+chunk_size] for i in range(0, len(sequence), chunk_size)]
    
    hidden_state = None
    outputs = []
    
    for chunk in chunks:
        # Process chunk with model, passing previous hidden state
        output, hidden_state = model(chunk, previous_state=hidden_state)
        outputs.append(output)
    
    return torch.cat(outputs, dim=1)

KV Caching for Efficient Inference

Key-Value (KV) caching is crucial for efficient processing of long contexts during inference:

class TransformerWithKVCache:
    def __init__(self, model):
        self.model = model
        self.kv_cache = None
    
    def generate(self, input_ids, max_length=100):
        batch_size = input_ids.shape[0]
        generated = input_ids.clone()
        
        # Initialize empty KV cache
        self.kv_cache = [{
            "k": None, "v": None
        } for _ in range(self.model.num_layers)]
        
        # First forward pass with the prompt
        outputs = self._forward_with_cache(input_ids)
        next_token_logits = outputs[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        
        # Generate remaining tokens one by one, using KV cache
        for _ in range(max_length - 1):
            # Forward pass with only the new token, using cached keys and values
            outputs = self._forward_with_cache(next_token)
            next_token_logits = outputs[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
        
        return generated
    
    def _forward_with_cache(self, input_ids):
        """Forward pass using and updating KV cache."""
        hidden_states = self.model.embed(input_ids)
        
        for i, layer in enumerate(self.model.layers):
            cache = self.kv_cache[i]
            
            # Get current K, V for this layer
            q, k, v = layer.self_attention.qkv_proj(hidden_states).chunk(3, dim=-1)
            
            if cache["k"] is not None:
                # Concatenate with cached keys and values
                k = torch.cat([cache["k"], k], dim=1)
                v = torch.cat([cache["v"], v], dim=1)
            
            # Update cache
            cache["k"] = k
            cache["v"] = v
            
            # Compute attention with full key/value history
            attn_output = layer.self_attention(q, k, v)
            hidden_states = layer.mlp(attn_output) + attn_output
        
        return self.model.lm_head(hidden_states)

KV caching dramatically improves inference speed by:

  • Avoiding redundant computations for already-processed tokens
  • Reducing memory bandwidth requirements
  • Enabling efficient autoregressive generation

However, it also increases memory usage proportionally to context length.

Memory Scaling Challenges

Memory usage increases linearly with context length due to:

  1. Activations: The intermediate outputs of each attention layer
  2. KV Cache: Stored key and value projections for previously processed tokens
  3. Gradient Accumulation: When fine-tuning with long contexts

The memory required for a model with embedding dimension d, and context length n:

Memory = 4 × layers × heads × d × n bytes

For a model like GPT-4 with approximately:

  • 96 layers
  • 96 attention heads
  • Embedding dimension of 12,288
  • 32K context

Memory usage for KV cache alone:

4 × 96 × 96 × 12,288 × 32,000 / (1024³) ≈ 13.5 GB

Context Window Optimization Techniques

1. Content Compression and Chunking

Efficiently select what goes into the context window:

def optimize_context(documents, query, max_tokens=8000):
    """Optimize document selection to fit context window."""
    # Step 1: Chunk documents into smaller segments
    chunks = []
    for doc in documents:
        doc_chunks = chunk_document(doc, chunk_size=512)
        chunks.extend(doc_chunks)
    
    # Step 2: Compute embeddings for query and chunks
    query_embedding = embedding_model.encode(query)
    chunk_embeddings = embedding_model.encode(chunks)
    
    # Step 3: Compute relevance scores
    relevance_scores = cosine_similarity([query_embedding], chunk_embeddings)[0]
    
    # Step 4: Select most relevant chunks that fit in context
    selected_chunks = []
    total_tokens = 0
    
    # Reserve tokens for the query and response
    reserved_tokens = count_tokens(query) + 1000  # 1000 for response
    available_tokens = max_tokens - reserved_tokens
    
    for score, chunk in sorted(zip(relevance_scores, chunks), reverse=True):
        chunk_tokens = count_tokens(chunk)
        if total_tokens + chunk_tokens <= available_tokens:
            selected_chunks.append(chunk)
            total_tokens += chunk_tokens
    
    # Step 5: Construct optimized context with selected chunks
    context = "\n\n".join(selected_chunks)
    
    return context, total_tokens + reserved_tokens

2. Recursive Summarization

For extremely large documents that exceed context limits:

def recursive_summarize(document, max_tokens=32000, chunk_size=4000, model=None):
    """Recursively summarize large documents to fit in context window."""
    # Base case: document already fits in context
    if count_tokens(document) <= max_tokens:
        return document
    
    # Chunk the document
    chunks = chunk_text(document, chunk_size)
    
    # Summarize each chunk
    chunk_summaries = []
    for chunk in chunks:
        prompt = f"Summarize this text concisely while preserving key information:\n\n{chunk}"
        response = model.generate(prompt)
        chunk_summaries.append(response)
    
    # Combine chunk summaries
    intermediate_summary = "\n\n".join(chunk_summaries)
    
    # Recursively summarize if still too large
    if count_tokens(intermediate_summary) > max_tokens:
        return recursive_summarize(intermediate_summary, max_tokens, 
                                  chunk_size*2, model)
    
    return intermediate_summary

3. Hybrid Retrieval-Augmentation

Combining embeddings and context window:

class HybridRAG:
    def __init__(self, retriever, llm, max_context_tokens=8000):
        self.retriever = retriever  # Vector database retriever
        self.llm = llm  # Large language model
        self.max_context_tokens = max_context_tokens
        
    def answer(self, query, k=10):
        # Step 1: Retrieve relevant documents from vector DB
        retrieved_docs = self.retriever.search(query, k=k)
        
        # Step 2: Optimize context window
        context, used_tokens = self.optimize_context(retrieved_docs, query)
        
        # Step 3: Generate answer using optimized context
        system_prompt = "You are a helpful assistant. Answer based on the provided context."
        prompt = f"{system_prompt}\n\nContext:\n{context}\n\nQuestion: {query}"
        
        response = self.llm.generate(prompt)
        
        return {
            "answer": response,
            "context_used": context,
            "tokens_used": used_tokens
        }
    
    def optimize_context(self, docs, query):
        # Similar to optimize_context function above
        # ...

4. Content Routing

Strategic placement of important information within context:

def strategic_prompt_construction(query, documents, system_prompt, max_tokens=32000):
    """Strategically place information in the prompt for optimal attention."""
    # Analyze and prioritize documents
    prioritized_docs = prioritize_documents(documents, query)
    
    # Reserve tokens for system prompt, query and response
    reserved_tokens = count_tokens(system_prompt) + count_tokens(query) + 1000
    available_tokens = max_tokens - reserved_tokens
    
    # Define sections with varying importance
    high_priority_content = []
    medium_priority_content = []
    low_priority_content = []
    
    total_tokens = 0
    
    # Distribute content into priority sections
    for doc in prioritized_docs:
        doc_tokens = count_tokens(doc)
        if total_tokens + doc_tokens <= available_tokens:
            # Determine priority based on relevance score
            if doc.relevance_score > 0.8:
                high_priority_content.append(doc)
            elif doc.relevance_score > 0.5:
                medium_priority_content.append(doc)
            else:
                low_priority_content.append(doc)
                
            total_tokens += doc_tokens
    
    # Structure the prompt to place highest priority content at beginning and end
    # (exploiting primacy and recency effects)
    prompt_parts = [
        system_prompt,
        "\n\nHigh Priority Information:\n" + "\n\n".join(high_priority_content[:len(high_priority_content)//2]),
        "\n\nAdditional Context:\n" + "\n\n".join(medium_priority_content + low_priority_content),
        "\n\nCritical Information:\n" + "\n\n".join(high_priority_content[len(high_priority_content)//2:]),
        f"\n\nQuestion: {query}"
    ]
    
    return "\n".join(prompt_parts)

Window Attention and Position Encoding

Position encoding is critical for giving LLMs a sense of token position within context:

Absolute Position Encoding

The original transformer approach:

def absolute_positional_encoding(seq_length, d_model):
    """Generate absolute positional encodings."""
    position = torch.arange(seq_length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    
    pe = torch.zeros(seq_length, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe

Relative Position Encoding

For models with extended context windows:

def relative_position_encoding(seq_length, d_model, max_distance=128):
    """Generate relative positional encodings."""
    # Create a matrix of relative distances
    positions = torch.arange(seq_length).unsqueeze(1) - torch.arange(seq_length).unsqueeze(0)
    
    # Clip distances to maximum distance
    positions = torch.clamp(positions, -max_distance, max_distance)
    
    # Shift to make all values non-negative
    positions = positions + max_distance
    
    # Create embedding table for relative positions
    rel_pos_embeddings = torch.nn.Embedding(2 * max_distance + 1, d_model)
    
    # Get embeddings for each relative position
    return rel_pos_embeddings(positions + max_distance)

Rotary Position Encoding (RoPE)

Used in modern models like GPT-4 and LLaMA:

def rotary_position_encoding(x, seq_len, dim):
    """Apply Rotary Position Encoding."""
    device = x.device
    
    # Create position indices
    position_ids = torch.arange(0, seq_len, device=device).unsqueeze(-1)
    
    # Create sinusoidal frequencies
    half_dim = dim // 2
    freq_seq = -torch.arange(half_dim, device=device) / half_dim
    inv_freq = 10000 ** freq_seq
    
    # Calculate sinusoidal pattern
    sinusoid = torch.einsum('bi,j->bij', position_ids, inv_freq)
    sin = sinusoid.sin()
    cos = sinusoid.cos()
    
    # Reshape for broadcasting
    sin = sin.repeat_interleave(2, dim=-1)
    cos = cos.repeat_interleave(2, dim=-1)
    
    # Apply rotary encoding
    x1 = x * cos
    x2 = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape_as(x) * sin
    
    return x1 + x2

Measuring Context Utilization

Token Utilization Analysis

def analyze_context_utilization(model, prompt, response):
    """Analyze how effectively the model used context in its response."""
    # Extract context from prompt
    context = extract_context_from_prompt(prompt)
    
    # Compute attention scores from model (requires model instrumentation)
    attention_scores = get_model_attention_scores(model, prompt)
    
    # Analyze which tokens were heavily attended to
    utilized_segments = []
    attention_threshold = 0.05  # Minimum attention to consider token "used"
    
    for segment in chunk_text(context, chunk_size=100):
        segment_attention = compute_segment_attention(segment, attention_scores)
        if segment_attention > attention_threshold:
            utilized_segments.append({
                "segment": segment,
                "attention_score": segment_attention
            })
    
    # Extract information used in response
    cited_info = extract_citations_from_response(response, context)
    
    return {
        "utilized_segments": utilized_segments,
        "cited_information": cited_info,
        "utilization_ratio": len(utilized_segments) / len(chunk_text(context, chunk_size=100))
    }

Retrieval Accuracy Evaluation

def evaluate_retrieval_accuracy(model, test_cases):
    """Evaluate how accurately the model retrieves information from context."""
    results = []
    
    for case in test_cases:
        prompt = f"Context:\n{case['context']}\n\nQuestion: {case['question']}"
        response = model.generate(prompt)
        
        # Calculate correctness score
        correctness = compare_with_ground_truth(response, case["ground_truth"])
        
        # Check for hallucinations
        hallucination_score = detect_hallucinations(response, case["context"])
        
        results.append({
            "question": case["question"],
            "correctness": correctness,
            "hallucination_score": hallucination_score
        })
    
    return {
        "average_correctness": sum(r["correctness"] for r in results) / len(results),
        "hallucination_rate": sum(r["hallucination_score"] > 0.5 for r in results) / len(results),
        "detailed_results": results
    }

Real-World Context Window Performance

Long Document Processing Test

Here's an empirical analysis of model performance using the "needle in a haystack" test:

def needle_in_haystack_test(model, context_size, needle_position="random"):
    """Test how well a model finds information within a large context."""
    # Generate filler text
    filler_text = generate_random_text(context_size - 100)  # Reserve 100 tokens for needle and question
    
    # Generate a unique factoid as the needle
    needle = f"The secret code is: {generate_random_string(8)}"
    
    # Place the needle according to specified position
    if needle_position == "start":
        position = 0
    elif needle_position == "middle":
        position = len(filler_text) // 2
    elif needle_position == "end":
        position = len(filler_text) - len(needle)
    else:  # random
        position = random.randint(0, len(filler_text) - len(needle))
    
    # Insert needle
    context = filler_text[:position] + needle + filler_text[position:]
    
    # Create prompt
    prompt = f"In the following text, find the secret code and return it.\n\n{context}\n\nWhat is the secret code?"
    
    # Measure response time
    start_time = time.time()
    response = model.generate(prompt)
    response_time = time.time() - start_time
    
    # Check if correct code was found
    correct = needle.split(": ")[1] in response
    
    return {
        "needle_position": position,
        "context_size": context_size,
        "success": correct,
        "response_time": response_time,
        "response": response
    }

Results from Testing Various Models

Testing performance across different context positions yields interesting insights:

Model Context Size Start Success Middle Success End Success Avg Time
GPT-4 32K 95% 87% 92% 14.2s
Claude 3 100K 93% 84% 90% 17.5s
LLaMA-3 8K 92% 81% 86% 8.3s

Key findings:

  • Performance degrades in the middle sections of context
  • Recency bias is evident but less pronounced than primacy
  • Retrieval time increases with context window size

Best Practices for Context Window Usage

1. Strategic Information Placement

Place critical information near the beginning and end of the prompt:

def construct_strategic_prompt(critical_info_start, context, critical_info_end, query):
    return f"""
    Important Information: {critical_info_start}
    
    Context:
    {context}
    
    Key Details: {critical_info_end}
    
    Question: {query}
    """

2. Use Markers and Delimiters

Help the model distinguish different parts of context:

def format_context_with_markers(documents):
    formatted_docs = []
    for i, doc in enumerate(documents):
        formatted_docs.append(f"[DOCUMENT {i+1}]\nTitle: {doc['title']}\nSource: {doc['source']}\n\n{doc['content']}\n[END DOCUMENT {i+1}]")
    
    return "\n\n".join(formatted_docs)

3. Hierarchical Summarization

For extremely long contexts:

def hierarchical_context_processing(documents, query):
    """Process documents hierarchically to make best use of context window."""
    # Level 1: Summarize each document
    doc_summaries = []
    for doc in documents:
        summary = summarize_document(doc, max_tokens=500)
        doc_summaries.append({"original": doc, "summary": summary})
    
    # Level 2: Create an overview summary
    overview = create_overview_summary(doc_summaries, query)
    
    # Build hierarchical context
    context = f"""
    OVERVIEW:
    {overview}
    
    DOCUMENT SUMMARIES:
    {format_summaries(doc_summaries)}
    
    FULL DOCUMENTS:
    {format_full_documents(documents)}
    """
    
    return context

4. Adaptive Context Strategies

Dynamically adjust context based on complexity:

def adaptive_context_strategy(query, documents, model):
    """Adaptively choose context strategy based on query complexity."""
    query_complexity = assess_query_complexity(query)
    
    if query_complexity == "simple":
        # For simple queries, use direct retrieval
        return top_k_retrieval(query, documents, k=3)
    elif query_complexity == "moderate":
        # For moderate complexity, use hybrid approach
        return hybrid_retrieval_summarization(query, documents)
    else:  # complex
        # For complex queries, use hierarchical processing
        return hierarchical_context_processing(documents, query)

Cost and Efficiency Considerations

Token Usage Optimization

def optimize_token_usage(prompt, context_docs, max_tokens=8000):
    """Optimize token usage when sending prompts to LLM APIs."""
    # Calculate tokens in fixed parts of prompt
    base_prompt_tokens = count_tokens(prompt)
    reserved_tokens = base_prompt_tokens + 1000  # Reserve for response
    
    # Available tokens for context
    available_tokens = max_tokens - reserved_tokens
    
    # Prioritize documents
    prioritized_docs = sort_by_relevance(context_docs)
    
    # Add documents until we approach limit
    selected_docs = []
    current_tokens = 0
    
    for doc in prioritized_docs:
        doc_tokens = count_tokens(doc["content"])
        if current_tokens + doc_tokens <= available_tokens:
            selected_docs.append(doc)
            current_tokens += doc_tokens
    
    # Format selected documents
    context_text = format_documents(selected_docs)
    
    # Replace {context} placeholder in prompt
    full_prompt = prompt.replace("{context}", context_text)
    
    return {
        "prompt": full_prompt,
        "token_usage": {
            "base_prompt": base_prompt_tokens,
            "context": current_tokens,
            "total": base_prompt_tokens + current_tokens,
            "max_allowed": max_tokens,
        },
        "docs_used": len(selected_docs),
        "docs_available": len(context_docs),
    }

Cost Calculator

def calculate_api_costs(input_tokens, output_tokens, model="gpt-4"):
    """Calculate API costs for different models."""
    pricing = {
        "gpt-4": {"input": 0.03, "output": 0.06},  # per 1K tokens
        "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
        "claude-3-opus": {"input": 0.015, "output": 0.075},
        "claude-3-sonnet": {"input": 0.003, "output": 0.015},
    }
    
    if model not in pricing:
        raise ValueError(f"Unknown model: {model}")
    
    input_cost = (input_tokens / 1000) * pricing[model]["input"]
    output_cost = (output_tokens / 1000) * pricing[model]["output"]
    
    return {
        "input_cost": input_cost,
        "output_cost": output_cost,
        "total_cost": input_cost + output_cost,
    }

Future Directions in Context Handling

Sparse Attention Mechanisms

Research is advancing on more efficient attention mechanisms:

def sparse_attention(query, key, value, sparsity_factor=0.9):
    """Implement sparse attention that only attends to top-k keys."""
    # Calculate attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    
    # Keep only top (1-sparsity_factor) of attention connections
    k = int(scores.size(-1) * (1 - sparsity_factor))
    top_k_scores, _ = scores.topk(k, dim=-1)
    threshold = top_k_scores[..., -1, None]
    
    # Create sparse mask
    mask = scores < threshold
    sparse_scores = scores.masked_fill(mask, -float('inf'))
    
    # Apply softmax and calculate output
    attention_weights = F.softmax(sparse_scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    
    return output

Infinite Context Models

Theoretical approaches for handling unlimited context:

class RecurrentMemoryTransformer:
    def __init__(self, base_model, memory_size=1024):
        self.base_model = base_model
        self.memory = None
        self.memory_size = memory_size
    
    def process(self, input_text):
        # Encode input
        input_embedding = self.base_model.encode(input_text)
        
        # Initialize memory if needed
        if self.memory is None:
            self.memory = torch.zeros((self.memory_size, input_embedding.size(-1)))
        
        # Update memory with new input (with attention mechanism)
        self.memory = self.update_memory(self.memory, input_embedding)
        
        # Process input in context of memory
        output = self.base_model.generate(input_embedding, memory=self.memory)
        
        return output
    
    def update_memory(self, memory, new_input):
        """Update memory representation with new input."""
        # Calculate attention between memory and new input
        attention = self.calculate_memory_attention(memory, new_input)
        
        # Combine existing memory with new input based on attention
        updated_memory = memory * (1 - attention) + attention * self.project_input_to_memory(new_input)
        
        return updated_memory

Learn more on how Matter AI helps improve code quality across multiple languages in Pull Requests: https://docs.matterai.dev/product/code-quality

Are you looking for a way to improve your code review process? Learn more on how Matter AI helps team to solve code review challenges with AI: https://matterai.so

Share this Article: