Understanding Abstract Syntax Trees

Vatsal Bajpai

Understanding Abstract Syntax Trees: A Deep Dive for Engineers

Introduction

If you've ever wondered how compilers understand your code, how linters spot bugs without running your program, or how tools like Prettier can reformat thousands of lines of code in milliseconds, you've stumbled upon one of the most powerful concepts in computer science: the Abstract Syntax Tree (AST).

While most developers interact with ASTs indirectly through their tools, understanding how they work can transform you from a tool user into a tool creator. This post will take you from AST basics to building your own code analysis tools.

What Exactly is an AST?

An Abstract Syntax Tree is a tree representation of the syntactic structure of source code. Each node in the tree represents a construct occurring in the source code. The "abstract" part means it doesn't include every detail that appears in the real syntax (like parentheses, semicolons, or whitespace)—just the structural content that matters.

Consider this simple JavaScript function:

function add(a, b) {
  return a + b;
}

The AST representation would look something like:

FunctionDeclaration
├── Identifier (name: "add")
├── Params
│   ├── Identifier (name: "a")
│   └── Identifier (name: "b")
└── BlockStatement
    └── ReturnStatement
        └── BinaryExpression (operator: "+")
            ├── Identifier (name: "a")
            └── Identifier (name: "b")
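The same shape can be inspected from Python with the built-in ast module. This is a minimal sketch that uses a Python translation of the add function (an assumption for illustration, not the JavaScript tree above); it shows that the parentheses around the return value leave no node behind.

```python
import ast

source = """
def add(a, b):
    return (a + b)
"""

tree = ast.parse(source)
func = tree.body[0]                                 # FunctionDef node
param_names = [arg.arg for arg in func.args.args]   # parameter identifiers
return_stmt = func.body[0]                          # Return node
expr = return_stmt.value                            # BinOp; the parentheses are gone

print(type(func).__name__, func.name, param_names)
# FunctionDef add ['a', 'b']
print(type(return_stmt).__name__, type(expr).__name__, type(expr.op).__name__)
# Return BinOp Add
```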

The Compilation Pipeline: Where ASTs Fit

Understanding where ASTs fit in the compilation process helps explain their importance:

  1. Lexical Analysis (Tokenization): Source code → Tokens
  2. Parsing: Tokens → AST
  3. Semantic Analysis: AST validation and type checking
  4. Transformation: AST → Modified AST
  5. Code Generation: AST → Target code

The AST serves as the central data structure that most compilation phases operate on. It's much easier to analyze and transform a tree structure than raw text.
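Python's standard library exposes several of these stages directly, so the pipeline can be walked by hand. The sketch below covers tokenization, parsing, and code generation only; it skips the explicit semantic analysis and transformation steps, and the variable names are illustrative.

```python
import ast
import io
import tokenize

source = "total = 2 + 3 * 4\n"

# 1. Lexical analysis: source text -> tokens
tokens = [
    (tokenize.tok_name[tok.type], tok.string)
    for tok in tokenize.generate_tokens(io.StringIO(source).readline)
    if tok.type not in (tokenize.NEWLINE, tokenize.ENDMARKER)
]

# 2. Parsing: source text -> AST (ast.parse runs its own tokenizer internally)
tree = ast.parse(source)

# 5. Code generation: AST -> bytecode, which the interpreter can execute
code_obj = compile(tree, "<example>", "exec")
namespace = {}
exec(code_obj, namespace)

print(tokens[:3])          # [('NAME', 'total'), ('OP', '='), ('NUMBER', '2')]
print(namespace["total"])  # 14
```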

Parsing: From Text to Tree

Let's build a simple expression parser to understand how text becomes an AST. We'll parse mathematical expressions like 2 + 3 * 4:

import re
from dataclasses import dataclass
from typing import Union, Optional

# AST Node definitions
@dataclass
class NumberNode:
    value: int

@dataclass
class BinaryOpNode:
    op: str
    left: 'ASTNode'
    right: 'ASTNode'

ASTNode = Union[NumberNode, BinaryOpNode]

class SimpleParser:
    def __init__(self, text):
        self.tokens = re.findall(r'\d+|[+\-*/()]', text)
        self.pos = 0
    
    def parse(self):
        return self.expr()
    
    def expr(self):
        """Parse addition and subtraction"""
        left = self.term()
        
        while self.current_token() in ['+', '-']:
            op = self.current_token()
            self.advance()
            right = self.term()
            left = BinaryOpNode(op, left, right)
        
        return left
    
    def term(self):
        """Parse multiplication and division"""
        left = self.factor()
        
        while self.current_token() in ['*', '/']:
            op = self.current_token()
            self.advance()
            right = self.factor()
            left = BinaryOpNode(op, left, right)
        
        return left
    
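    def factor(self):
        """Parse numbers and parenthesized expressions"""
        token = self.current_token()
        
        if token is not None and token.isdigit():
            self.advance()
            return NumberNode(int(token))
        elif token == '(':
            self.advance()
            node = self.expr()
            # Fail loudly on an unbalanced parenthesis instead of skipping blindly
            if self.current_token() != ')':
                raise SyntaxError("expected ')'")
            self.advance()  # skip ')'
            return node
        
        # Anything else (including end of input) is a syntax error, not None
        raise SyntaxError(f"unexpected token: {token!r}")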
    
    def current_token(self):
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None
    
    def advance(self):
        self.pos += 1

# Example usage
parser = SimpleParser("2 + 3 * 4")
ast = parser.parse()
print(ast)
# Output: BinaryOpNode(op='+', left=NumberNode(value=2), 
#         right=BinaryOpNode(op='*', left=NumberNode(value=3), 
#         right=NumberNode(value=4)))
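Once the tree exists, consuming it is a recursive walk. The sketch below repeats the two node classes so it runs on its own and adds a hypothetical evaluate function (not part of the parser above) that computes the value of a hand-built tree for 2 + 3 * 4.

```python
import operator
from dataclasses import dataclass
from typing import Union

@dataclass
class NumberNode:
    value: int

@dataclass
class BinaryOpNode:
    op: str
    left: "ASTNode"
    right: "ASTNode"

ASTNode = Union[NumberNode, BinaryOpNode]

# Map each operator token to the function that implements it
OPERATORS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
}

def evaluate(node: ASTNode) -> float:
    """Evaluate the tree bottom-up: children first, then the operator."""
    if isinstance(node, NumberNode):
        return node.value
    if isinstance(node, BinaryOpNode):
        return OPERATORS[node.op](evaluate(node.left), evaluate(node.right))
    raise TypeError(f"unknown node type: {type(node).__name__}")

# The tree the parser produces for "2 + 3 * 4", built by hand here
tree = BinaryOpNode("+", NumberNode(2), BinaryOpNode("*", NumberNode(3), NumberNode(4)))
result = evaluate(tree)
print(result)  # 14
```

Because precedence is already encoded in the tree's shape, the evaluator needs no precedence rules of its own.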

Working with Real ASTs

Let's explore how to work with ASTs in different languages:

JavaScript with Babel

const parser = require('@babel/parser');
const traverse = require('@babel/traverse').default;
const generate = require('@babel/generator').default;

const code = `
function greet(name) {
  console.log("Hello, " + name);
}
`;

// Parse code to AST
const ast = parser.parse(code);

// Transform: Convert string concatenation to template literals
const t = require('@babel/types');

traverse(ast, {
  BinaryExpression(path) {
    const { operator, left, right } = path.node;

    // Only handle the simple case: "literal" + expression
    if (operator === '+' && t.isStringLiteral(left) && !t.isStringLiteral(right)) {
      // Escape characters that are special inside a template literal
      const raw = left.value.replace(/\\|`|\$\{/g, '\\$&');

      // Build the node with @babel/types so required fields are validated
      path.replaceWith(
        t.templateLiteral(
          [
            t.templateElement({ raw, cooked: left.value }, false),
            t.templateElement({ raw: '', cooked: '' }, true)
          ],
          [right]
        )
      );
    }
  }
});

// Generate code from modified AST
const output = generate(ast);
console.log(output.code);
// Output:
// function greet(name) {
//   console.log(`Hello, ${name}`);
// }

Python's Built-in AST Module

import ast

class FunctionCallCounter(ast.NodeVisitor):
    """Count function calls in Python code"""
    
    def __init__(self):
        self.function_calls = {}
    
    def visit_Call(self, node):
        # Get function name if it's a simple call
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            self.function_calls[func_name] = self.function_calls.get(func_name, 0) + 1
        elif isinstance(node.func, ast.Attribute):
            # Dotted call such as result.append; ast.unparse needs Python 3.9+
            func_name = ast.unparse(node.func)
            self.function_calls[func_name] = self.function_calls.get(func_name, 0) + 1
        
        self.generic_visit(node)

code = """
def process_data(data):
    result = []
    for item in data:
        if validate(item):
            processed = transform(item)
            result.append(processed)
    
    print(f"Processed {len(result)} items")
    return result

def main():
    data = load_data()
    processed = process_data(data)
    save_data(processed)
"""

tree = ast.parse(code)
counter = FunctionCallCounter()
counter.visit(tree)

print("Function call frequency:")
for func, count in sorted(counter.function_calls.items(), key=lambda x: x[1], reverse=True):
    print(f"  {func}: {count}")

Building Practical Tools with ASTs

1. Custom Linter: Detecting Common Mistakes

Let's build a linter that catches a common mistake—using == instead of === in JavaScript:

const parser = require('@babel/parser');
const traverse = require('@babel/traverse').default;

function lintCode(code) {
  const warnings = [];
  const ast = parser.parse(code, { sourceType: 'module' });
  
  traverse(ast, {
    BinaryExpression(path) {
      if (path.node.operator === '==') {
        warnings.push({
          line: path.node.loc.start.line,
          column: path.node.loc.start.column,
          message: 'Use === instead of == for strict equality'
        });
      }
    }
  });
  
  return warnings;
}

const testCode = `
if (userInput == "admin") {
  grantAccess();
}
`;

console.log(lintCode(testCode));
// Output: [{ line: 2, column: 4, message: 'Use === instead of == for strict equality' }]
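The same linting idea carries over to Python's ast module. As a rough analogue (an assumption for illustration, not a port of the Babel code), the sketch below flags comparisons such as x == None, where the idiomatic form is x is None. The function name lint_none_comparisons is hypothetical.

```python
import ast

def lint_none_comparisons(source):
    """Report == / != comparisons against None."""
    warnings = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Compare):
            continue
        # A chained comparison a == b == c has operands [a, b, c] and two ops
        operands = [node.left, *node.comparators]
        for i, op in enumerate(node.ops):
            if not isinstance(op, (ast.Eq, ast.NotEq)):
                continue
            pair = (operands[i], operands[i + 1])
            if any(isinstance(o, ast.Constant) and o.value is None for o in pair):
                warnings.append({
                    "line": node.lineno,
                    "column": node.col_offset,
                    "message": "Use 'is None' / 'is not None' instead of == / !=",
                })
    return warnings

test_code = """
if user == None:
    grant_access()
"""

warnings = lint_none_comparisons(test_code)
print(warnings)
```

As in the Babel version, the check never runs the code under inspection; it only matches on node types and operators.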

2. Code Metrics Analyzer

Build a complexity analyzer using Python's AST:

import ast

class ComplexityAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.complexity = 0
        self.current_function = None
        self.function_complexities = {}
    
    def visit_FunctionDef(self, node):
        # Save previous context
        prev_complexity = self.complexity
        prev_function = self.current_function
        
        # Set new context
        self.current_function = node.name
        self.complexity = 1  # Base complexity
        
        # Visit function body
        self.generic_visit(node)
        
        # Store result
        self.function_complexities[self.current_function] = self.complexity
        
        # Restore context
        self.complexity = prev_complexity
        self.current_function = prev_function
    
    def visit_If(self, node):
        self.complexity += 1
        self.generic_visit(node)
    
    def visit_While(self, node):
        self.complexity += 1
        self.generic_visit(node)
    
    def visit_For(self, node):
        self.complexity += 1
        self.generic_visit(node)
    
    def visit_ExceptHandler(self, node):
        self.complexity += 1
        self.generic_visit(node)
    
    def visit_With(self, node):
        self.complexity += 1
        self.generic_visit(node)
    
    def visit_BoolOp(self, node):
        # Add complexity for each additional boolean operator
        self.complexity += len(node.values) - 1
        self.generic_visit(node)

# Example usage
code = """
def process_items(items, options):
    if not items:
        return []
    
    results = []
    for item in items:
        if item.is_valid() and (item.priority > 5 or options.force):
            try:
                processed = item.process()
                if processed:
                    results.append(processed)
            except ProcessingError:
                if options.strict:
                    raise
                continue
    
    return results
"""

tree = ast.parse(code)
analyzer = ComplexityAnalyzer()
analyzer.visit(tree)

for func, complexity in analyzer.function_complexities.items():
    print(f"{func}: Cyclomatic Complexity = {complexity}")

3. Auto-generating Boilerplate Code

Use ASTs to generate repetitive code patterns:

import ast
import astor

def generate_dataclass(class_name, fields):
    """Generate a Python dataclass with validation"""
    
    # Create class definition
    class_def = ast.ClassDef(
        name=class_name,
        bases=[],
        keywords=[],
        body=[],
        decorator_list=[ast.Name(id='dataclass', ctx=ast.Load())]
    )
    
    # Add __post_init__ method for validation
    post_init_body = []
    
    for field_name, field_type in fields.items():
        # Add the annotated field (e.g. `name: str`) to the class body
        annotation = ast.AnnAssign(
            target=ast.Name(id=field_name, ctx=ast.Store()),
            annotation=ast.Name(id=field_type, ctx=ast.Load()),
            value=None,
            simple=1
        )
        class_def.body.append(annotation)
        
        # Add validation in __post_init__
        if field_type == 'str':
            validation = ast.If(
                test=ast.UnaryOp(
                    op=ast.Not(),
                    operand=ast.Attribute(
                        value=ast.Name(id='self', ctx=ast.Load()),
                        attr=field_name,
                        ctx=ast.Load()
                    )
                ),
                body=[
                    ast.Raise(
                        exc=ast.Call(
                            func=ast.Name(id='ValueError', ctx=ast.Load()),
                            args=[ast.Constant(value=f"{field_name} cannot be empty")],
                            keywords=[]
                        )
                    )
                ],
                orelse=[]
            )
            post_init_body.append(validation)
    
    # Add __post_init__ method
    if post_init_body:
        post_init = ast.FunctionDef(
            name='__post_init__',
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg='self', annotation=None)],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]
            ),
            body=post_init_body,
            decorator_list=[]
        )
        class_def.body.append(post_init)
    
    # Wrap in module
    module = ast.Module(
        body=[
            ast.ImportFrom(module='dataclasses', names=[ast.alias(name='dataclass', asname=None)], level=0),
            class_def
        ],
        type_ignores=[]  # required field on Python 3.8+
    )
    
    # Fix missing fields and generate code
    ast.fix_missing_locations(module)
    return astor.to_source(module)

# Generate a User dataclass
code = generate_dataclass('User', {
    'name': 'str',
    'email': 'str',
    'age': 'int'
})

print(code)

Performance Considerations

When working with ASTs in production:

  1. Parse Once, Transform Many: Parsing is expensive. Cache ASTs when possible.

  2. Visitor vs. Transformer Pattern: Use visitors for analysis (read-only), transformers for modifications.

  3. Incremental Parsing: For large codebases, consider incremental parsing strategies that only re-parse changed portions.

  4. Memory Usage: ASTs can be memory-intensive. For large codebases, process files one at a time and release each tree once its analysis is done instead of holding every AST in memory.
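The "parse once" advice in the list above can be sketched with functools.lru_cache. This is a minimal illustration rather than a production cache: the helper names parse_cached and count_nodes are hypothetical, and the cached tree is shared, so it suits read-only visitors and should not be mutated in place.

```python
import ast
from functools import lru_cache

@lru_cache(maxsize=128)
def parse_cached(source: str) -> ast.Module:
    """Parse source once; repeated calls with the same text reuse the tree."""
    return ast.parse(source)

def count_nodes(source: str, node_type) -> int:
    """Read-only analysis that counts nodes of one type."""
    return sum(isinstance(node, node_type) for node in ast.walk(parse_cached(source)))

source = "def f(x):\n    return g(x) + h(x)\n"

calls = count_nodes(source, ast.Call)             # first analysis triggers the parse
functions = count_nodes(source, ast.FunctionDef)  # second analysis hits the cache
info = parse_cached.cache_info()

print(calls, functions, info.hits, info.misses)   # 2 1 1 1
```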

Common Pitfalls and Best Practices

Pitfall 1: Modifying While Traversing

# Wrong: Modifying list while iterating
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        tree.body.remove(node)  # Don't do this!

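# Right: Collect first, then modify
# Only top-level functions live in tree.body; ast.walk would also yield
# nested functions, and tree.body.remove() raises ValueError for those.
functions_to_remove = [
    node for node in tree.body
    if isinstance(node, ast.FunctionDef)
]

for func in functions_to_remove:
    tree.body.remove(func)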
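When the functions to remove may be nested at any depth, ast.NodeTransformer is usually the safer route: returning None from a visit method drops the node from whichever body contains it. A minimal sketch follows; the FunctionRemover class and the sample source are hypothetical.

```python
import ast

class FunctionRemover(ast.NodeTransformer):
    """Remove every function definition whose name is in `names`."""

    def __init__(self, names):
        self.names = set(names)

    def visit_FunctionDef(self, node):
        if node.name in self.names:
            return None               # returning None deletes the node
        self.generic_visit(node)      # keep looking inside functions we retain
        return node

source = """
def keep():
    def debug_helper():
        pass
    return 1

def debug_helper():
    pass
"""

tree = FunctionRemover({"debug_helper"}).visit(ast.parse(source))
remaining = [n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)]
print(remaining)  # ['keep']
```

One caveat: deleting the only statement in a body leaves that body empty, which is not valid Python, so a real tool would insert a pass statement in that case.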

Pitfall 2: Losing Source Information

Always preserve location information when transforming ASTs if you need to provide meaningful error messages:

new_node = ast.Name(id='new_var', ctx=ast.Load())
ast.copy_location(new_node, old_node)  # Preserve line/column info
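In context, that call typically sits inside a transformer. The sketch below uses a hypothetical RenameVariable class to replace a Name node and then checks that the replacement reports the line and column of the node it replaced.

```python
import ast

class RenameVariable(ast.NodeTransformer):
    """Replace every use of one variable name with another."""

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def visit_Name(self, node):
        if node.id == self.old:
            new_node = ast.Name(id=self.new, ctx=node.ctx)
            # copy_location copies lineno/col_offset (and end positions) and returns new_node
            return ast.copy_location(new_node, node)
        return node

source = "x = 1\ny = old_var + x\n"
tree = RenameVariable("old_var", "new_var").visit(ast.parse(source))

renamed = next(
    n for n in ast.walk(tree)
    if isinstance(n, ast.Name) and n.id == "new_var"
)
print(renamed.lineno, renamed.col_offset)  # 2 4
```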

Pitfall 3: Language-Specific Quirks

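Each language's AST has quirks. JavaScript's ESTree-style ASTs wrap an expression used as a statement in an ExpressionStatement node, Python's ast module requires calling ast.fix_missing_locations() on nodes you construct yourself before compiling them, and Go's go/ast keeps comments in separate comment groups rather than discarding them.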

Advanced Applications

1. Cross-Language Transpilation

ASTs enable transpilation between languages with similar semantics. Transcrypt (Python to JavaScript), for example, parses Python source into an AST and generates JavaScript from that tree.

2. Program Synthesis

Generate code from specifications by constructing ASTs programmatically, then converting to source code.
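As a small illustration of that idea, the sketch below treats a single integer as the "specification" and builds the expression x * factor node by node. The function name build_multiplier is hypothetical, and ast.unparse requires Python 3.9 or later.

```python
import ast

def build_multiplier(factor: int) -> ast.Expression:
    """Construct the AST for the expression `x * factor`."""
    expr = ast.Expression(
        body=ast.BinOp(
            left=ast.Name(id="x", ctx=ast.Load()),
            op=ast.Mult(),
            right=ast.Constant(value=factor),
        )
    )
    # Hand-built nodes have no line/column info; fill it in before compiling
    return ast.fix_missing_locations(expr)

tree = build_multiplier(2)

source = ast.unparse(tree.body)                  # AST -> source text
code_obj = compile(tree, "<generated>", "eval")  # AST -> bytecode
result = eval(code_obj, {"x": 21})

print(source)  # x * 2
print(result)  # 42
```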

3. Differential Testing

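Compare the ASTs of code before and after a refactoring to confirm that the structure is unchanged. Identical trees show that only formatting or comments differ; differing trees need tests or deeper analysis to establish semantic equivalence.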
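A structural comparison of this kind takes only a few lines with Python's ast module. In this sketch the helper name same_structure is hypothetical; it relies on ast.dump omitting line and column attributes by default, so whitespace, comments, and redundant parentheses do not affect the result.

```python
import ast

def same_structure(source_a: str, source_b: str) -> bool:
    """True when both sources parse to structurally identical ASTs."""
    return ast.dump(ast.parse(source_a)) == ast.dump(ast.parse(source_b))

before = "def area(w,h):\n    return (w*h)  # width times height\n"
after = "def area(w, h):\n    return w * h\n"
changed = "def area(w, h):\n    return w + h\n"

formatting_only = same_structure(before, after)
behaviour_changed = not same_structure(before, changed)

print(formatting_only)    # True
print(behaviour_changed)  # True
```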

4. Security Analysis

Detect security vulnerabilities by pattern matching against known vulnerable AST patterns (taint analysis, SQL injection detection).
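A very rough version of such a pattern check is sketched below. It flags calls to any method named execute whose first argument is an f-string or a string built with + or %, which is a common shape for SQL injection. This is a syntactic heuristic under stated assumptions, not taint analysis, and the function name find_unsafe_sql is hypothetical.

```python
import ast

def find_unsafe_sql(source: str):
    """Return line numbers of .execute() calls whose query is built dynamically."""
    findings = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "execute"
            and node.args
        ):
            query = node.args[0]
            # JoinedStr is an f-string; BinOp covers "..." + x and "..." % x
            if isinstance(query, (ast.JoinedStr, ast.BinOp)):
                findings.append(node.lineno)
    return sorted(findings)

source = '''
def lookup(cursor, user_id):
    cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
    cursor.execute("SELECT * FROM users WHERE id = " + user_id)
    cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
'''

findings = find_unsafe_sql(source)
print(findings)  # [3, 4]
```

The parameterized query on the last line is not flagged because its first argument is a plain string constant.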

Tools and Resources

AST Explorers

  • AST Explorer (astexplorer.net): Interactive AST visualization for multiple languages
  • Python AST Visualizer: Built-in ast.dump(); the indent parameter (Python 3.9+) pretty-prints the tree
  • Babel AST Explorer: Specific to JavaScript/TypeScript
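For the ast.dump() option in the list above, a quick way to explore a tree from the terminal looks like this. The indent parameter needs Python 3.9 or later, and the exact text printed varies slightly between Python versions, so no output is shown.

```python
import ast

tree = ast.parse("total = price * quantity")

# indent=2 prints one field per line instead of a single long line
dumped = ast.dump(tree, indent=2)
print(dumped)
```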

Libraries

  • JavaScript/TypeScript: Babel, Acorn, Esprima, TypeScript Compiler API
  • Python: ast (built-in), astroid, RedBaron
  • Go: go/ast, go/parser packages
  • Rust: syn, quote crates
  • Java: JavaParser, Eclipse JDT

Learning Resources

  • "Writing An Interpreter In Go" by Thorsten Ball
  • "Crafting Interpreters" by Robert Nystrom
  • Dragon Book: "Compilers: Principles, Techniques, and Tools"

Conclusion

Abstract Syntax Trees are the backbone of modern development tooling. Understanding them transforms mysterious "magic" into comprehensible engineering. Whether you're building a simple linter, a complex refactoring tool, or even your own programming language, ASTs provide the foundation you need.

Start small—pick a simple problem in your codebase that could benefit from automated analysis or transformation. Parse it into an AST, traverse it, and see what patterns emerge. You'll be surprised how many "impossible" problems become tractable when you're working with structured data instead of text.

The next time you use a development tool that seems to understand your code, remember: it's not magic, it's trees all the way down.


Learn more about how Matter AI helps improve code quality across multiple languages in Pull Requests: https://docs.matterai.so/product/code-quality

Are you looking for a way to improve your code review process? Learn more about how Matter AI helps teams solve code review challenges with AI: https://matterai.so
