Chapter 19: Attention

Why Focusing Beats Remembering

The Bottleneck of Fixed-Size Memory

Recurrent Neural Networks compress entire sequences into a fixed-size hidden state vector. For machine translation, this means encoding a French sentence of arbitrary length into a single vector (typically 256 or 512 dimensions), then decoding it into English. All the meaning—every word, every grammatical relationship, every nuance—must fit in this fixed-size bottleneck.

This works for short sentences but breaks down for longer ones. A sentence with 50 words contains far more information than can be compressed into 512 numbers without loss. Important details get forgotten. Word order becomes muddled. The decoder struggles because it only has access to a compressed summary, not the full source sentence.

The problem manifests as degrading translation quality with sentence length. Short sentences translate well; long sentences produce increasingly garbled outputs. The RNN’s hidden state simply doesn’t have the capacity to preserve all relevant information.

Attention solves this by letting the decoder directly access the full input sequence. Instead of compressing everything into a single vector, attention allows the decoder to selectively focus on relevant parts of the input for each output word. When translating “chat” to “cat,” the decoder attends strongly to “chat” in the source sentence, ignoring irrelevant words.

Attention was initially introduced for machine translation (Bahdanau et al., 2014) but became the foundation of Transformers—the dominant architecture in modern AI. Understanding attention is essential to understanding how modern language models work.

Query-Key-Value: Retrieval as Computation

Attention can be understood as a differentiable database lookup. Imagine a database of key-value pairs: keys are identifiers, values are stored information. To retrieve information, you provide a query and the database returns the value associated with the matching key.

Attention works similarly, but softly:

  • Query: What information am I looking for?
  • Keys: What information does each position in the input represent?
  • Values: The actual information stored at each position

Instead of hard matching (return the single exact key match), attention computes soft matching: assign a weight to every key based on how well it matches the query, then return a weighted average of all values.

Formally, given:

  • Query vector $\mathbf{q}$ (what we’re looking for)
  • Key vectors $\mathbf{k}_1, \mathbf{k}_2, \ldots, \mathbf{k}_n$ (identifiers for each position)
  • Value vectors $\mathbf{v}_1, \mathbf{v}_2, \ldots, \mathbf{v}_n$ (information at each position)

Attention computes:

$$\text{Attention}(\mathbf{q}, \mathbf{K}, \mathbf{V}) = \sum_{i=1}^{n} \alpha_i \mathbf{v}_i$$

Where the attention weights $\alpha_i$ are computed by:

$$\alpha_i = \frac{\exp(\text{score}(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^{n} \exp(\text{score}(\mathbf{q}, \mathbf{k}_j))}$$

The score function measures compatibility between the query and each key. Common choices:

Dot product (most common):

$$\text{score}(\mathbf{q}, \mathbf{k}) = \mathbf{q}^T \mathbf{k}$$

Scaled dot product (used in Transformers):

$$\text{score}(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d_k}}$$

Where $d_k$ is the dimensionality of keys. The scaling prevents dot products from growing too large in high dimensions.
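The formulas above can be sketched directly in plain Python (lists instead of arrays, to stay dependency-free). This is a minimal illustration, not a production implementation: the toy keys, values, and query are made up, and the score function is pluggable so the plain and scaled dot products can be compared.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot_score(q, k):
    """Plain dot product q^T k."""
    return sum(qi * ki for qi, ki in zip(q, k))

def scaled_dot_score(q, k):
    """Dot product divided by sqrt(d_k), with d_k = len(k)."""
    return dot_score(q, k) / math.sqrt(len(k))

def attention(q, keys, values, score=scaled_dot_score):
    """Soft lookup: alpha_i = softmax_i(score(q, k_i)); output = sum_i alpha_i * v_i."""
    alphas = softmax([score(q, k) for k in keys])
    d_v = len(values[0])
    output = [sum(a * v[j] for a, v in zip(alphas, values)) for j in range(d_v)]
    return output, alphas

# Toy "database" with made-up vectors: the query points toward the second key.
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
q = [0.0, 3.0]

output, alphas = attention(q, keys, values)
print("weights:", [round(a, 3) for a in alphas])
print("output:", [round(x, 3) for x in output])
```

Every value contributes to the output; the second value simply dominates. Passing score=dot_score uses the unscaled variant, which with these vectors puts even more weight on the second key.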

Why scaling matters: Without scaling, dot products grow with dimensionality. If the components of $\mathbf{q}$ and $\mathbf{k}$ are independent with zero mean and unit variance, their dot product has variance $d_k$, so its typical magnitude grows like $\sqrt{d_k}$. For $d_k = 512$, that standard deviation is about 22.6, so unscaled dot products routinely reach magnitudes of 20 or higher.

Large dot products cause problems during softmax. When scores are very large, softmax saturates—it assigns probability ~1 to the maximum and ~0 to everything else. Gradients vanish because softmax derivatives are tiny in the saturated regime.

Example: Consider three scores: [2, 10, 3]. Softmax gives weights of approximately [0.0003, 0.9988, 0.0009], nearly a hard selection of the second position. Small changes to the scores barely change the output, so gradients are weak. Training slows or stalls.

Scaling by $\frac{1}{\sqrt{d_k}}$ keeps dot products in a reasonable range (typically -3 to 3), where softmax is sensitive and gradients flow well. For $d_k = 512$, we divide by $\sqrt{512} \approx 23$, bringing large scores back to manageable magnitudes.
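To see saturation numerically, the sketch below computes the softmax weights for the example scores [2, 10, 3] and the largest entry of the softmax Jacobian, $\partial \alpha_i / \partial s_j = \alpha_i(\delta_{ij} - \alpha_j)$, which measures how strongly any score change can move any weight. For comparison it divides the same scores by $\sqrt{d_k}$ with a hypothetical $d_k = 16$; the scores are illustrative, not taken from a trained model.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(p):
    """d p_i / d s_j = p_i * (delta_ij - p_j)."""
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)] for i in range(n)]

def largest_gradient(p):
    return max(abs(x) for row in softmax_jacobian(p) for x in row)

raw = [2.0, 10.0, 3.0]                          # scores from the example
p_raw = softmax(raw)

d_k = 16                                        # hypothetical key dimension
scaled = [s / math.sqrt(d_k) for s in raw]      # divide by sqrt(16) = 4
p_scaled = softmax(scaled)

print("raw weights:      ", [round(x, 4) for x in p_raw])
print("raw max |grad|:   ", round(largest_gradient(p_raw), 4))
print("scaled weights:   ", [round(x, 4) for x in p_scaled])
print("scaled max |grad|:", round(largest_gradient(p_scaled), 4))
```

The scaled scores still rank the second position first but spread weight over the others, and the largest Jacobian entry grows by more than a hundred times, so gradients have something to work with.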

Mathematical intuition: If $\mathbf{q}$ and $\mathbf{k}$ are independent random vectors whose components have zero mean and unit variance, their dot product $\mathbf{q}^T \mathbf{k} = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$ (a sum of $d_k$ independent terms, each with zero mean and variance 1). Dividing by $\sqrt{d_k}$ normalizes the variance to 1, regardless of dimensionality.
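A quick Monte Carlo check of this argument, assuming components drawn independently from a standard normal distribution: the sample variance of the raw dot product tracks $d_k$, while the variance after dividing by $\sqrt{d_k}$ stays near 1. The trial count and random seed are arbitrary choices for the sketch.

```python
import math
import random

def sample_dot_products(d_k, trials, rng):
    """Dot products of q and k whose components are i.i.d. N(0, 1)."""
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(sum(qi * ki for qi, ki in zip(q, k)))
    return samples

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
results = {}
for d_k in (16, 64, 512):
    raw = sample_dot_products(d_k, trials=2000, rng=rng)
    scaled = [s / math.sqrt(d_k) for s in raw]
    results[d_k] = (variance(raw), variance(scaled))
    print(f"d_k={d_k:4d}  var(raw)={results[d_k][0]:8.1f}  var(scaled)={results[d_k][1]:.2f}")
```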

Production tip: Always use scaled attention. Never use raw dot products for attention scoring—you’ll encounter vanishing gradients and slow training.

Example: Translation

When translating “Le chat noir” to “The black cat,” the decoder generates words one at a time. When generating “black,” the decoder:

  1. Creates a query representing “what French word describes color?”
  2. Computes scores against all French words’ keys
  3. Softmax converts scores to weights: high weight on “noir” (black), low on “le” and “chat”
  4. Returns weighted sum of values, emphasizing information from “noir”

The decoder attends to different parts of the input for each output word. This selective focus eliminates the fixed-size bottleneck—the decoder always has access to the full input.
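The first three steps of this walkthrough can be reproduced with hand-made vectors. Everything here is hypothetical: real keys and queries are learned, high-dimensional, and not aligned with human-readable features. The three dimensions are chosen only to make the example legible (roughly "article", "animal", "color"), and the values and weighted sum of step 4 are omitted.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

source = ["Le", "chat", "noir"]
# Hypothetical keys; dimensions loosely mean (article, animal, color).
keys = {"Le": [2.0, 0.0, 0.0], "chat": [0.0, 2.0, 0.0], "noir": [0.0, 0.0, 2.0]}
# Hypothetical decoder queries, one per English output word (step 1).
queries = {"The": [2.0, 0.0, 0.0], "black": [0.0, 0.5, 2.0], "cat": [0.0, 2.0, 0.5]}

def weights_for(target):
    q = queries[target]
    # Step 2: scaled dot-product score against every source key.
    scores = [sum(a * b for a, b in zip(q, keys[w])) / math.sqrt(len(q)) for w in source]
    # Step 3: softmax turns the scores into attention weights.
    return dict(zip(source, softmax(scores)))

for target in ("The", "black", "cat"):
    w = weights_for(target)
    print(target, "->", {word: round(a, 2) for word, a in w.items()})
```

Each output word gets its own distribution over the same source words, which is exactly why no single fixed-size summary is needed.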

Figure: Query-Key-Value: Retrieval as Computation. The diagram shows attention as a three-step process: (1) compute similarity between query and keys, (2) softmax to get weights, (3) weighted sum of values. A high attention weight ($\alpha_2 = 0.6$) means the query strongly matches $\mathbf{k}_2$.

Soft Differentiable Lookup

The power of attention is that it’s differentiable. Unlike hard database lookups (return one exact match), soft attention computes a weighted combination of all entries. This has two critical benefits:

1. Gradients flow through attention

During backpropagation, gradients flow from the output back through the weighted sum, through the softmax, through the similarity scores, to the queries, keys, and values, and from there to the parameters that produced them (such as the projection matrices). The model learns what to attend to by adjusting those parameters to minimize loss.

If attending to “noir” when generating “black” reduces translation loss, gradients strengthen the query-key alignment between these positions. The network learns attention patterns automatically through standard backpropagation.

2. Soft assignments enable learning

Hard lookups (argmax) are non-differentiable—small changes to scores don’t change which entry is selected. Soft attention (weighted sum via softmax) is smooth—small changes to scores smoothly change the output. This smoothness is essential for gradient-based learning.

The softmax function also has a natural probabilistic interpretation: the attention weights form a probability distribution over positions. When the model is uncertain about where to focus, it can hedge by attending to multiple positions with varying confidence.
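The difference between hard and soft lookups shows up directly in finite-difference gradients. In this sketch (scalar values and scores chosen arbitrarily), nudging any score leaves the argmax lookup unchanged, so its estimated gradient is exactly zero, while the softmax-weighted lookup responds to every score.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

values = [1.0, 5.0, 3.0]    # scalar values keep the example small
scores = [0.5, 1.0, 0.2]    # arbitrary query-key scores

def hard_lookup(s):
    """Return the value whose score is largest (argmax)."""
    best = max(range(len(s)), key=lambda i: s[i])
    return values[best]

def soft_lookup(s):
    """Softmax-weighted average of the values."""
    return sum(a * v for a, v in zip(softmax(s), values))

def finite_diff(f, s, i, eps=1e-4):
    """Forward-difference estimate of d f / d s_i."""
    bumped = list(s)
    bumped[i] += eps
    return (f(bumped) - f(s)) / eps

hard_grads = [finite_diff(hard_lookup, scores, i) for i in range(len(scores))]
soft_grads = [finite_diff(soft_lookup, scores, i) for i in range(len(scores))]
print("hard:", hard_grads)
print("soft:", [round(g, 3) for g in soft_grads])
```

Raising the score of the high-valued entry increases the soft output, and raising the others decreases it, which is the signal gradient descent needs.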

Self-Attention vs Cross-Attention

Attention has two primary forms:

Cross-attention: Query comes from one sequence, keys and values from another. Used in encoder-decoder models for machine translation: the decoder queries the encoder’s outputs. This allows the decoder to attend to source sentence positions when generating target words.

Self-attention: Query, keys, and values all come from the same sequence. Each position attends to all positions in the same sequence, allowing the model to capture dependencies within a single sequence.

Self-attention is the foundation of Transformers (Chapter 20). For language modeling, self-attention lets each word attend to previous words, capturing long-range dependencies without RNN’s sequential bottleneck.

Given input sequence $\mathbf{X} = [\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_n]$, self-attention computes:

$$\mathbf{Q} = \mathbf{X} \mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X} \mathbf{W}^K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}^V$$

Where $\mathbf{W}^Q$, $\mathbf{W}^K$, $\mathbf{W}^V$ are learned projection matrices. Then:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V}$$

This computes attention for all positions simultaneously. The result is a new representation where each position has incorporated information from all other positions it attended to.
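The matrix form above, written out as a minimal pure-Python sketch. The projection matrices are random stand-ins for learned weights, and the sizes (4 tokens, 8 input features, 4-dimensional queries, keys, and values) are arbitrary; a real implementation would use an array library and batched GPU kernels.

```python
import math
import random

def matmul(a, b):
    """(rows of a) x (columns of b), using plain nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax_rows(m):
    """Stable softmax applied independently to each row."""
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def self_attention(X, W_q, W_k, W_v):
    """softmax(Q K^T / sqrt(d_k)) V, with Q = X W_q, K = X W_k, V = X W_v."""
    Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
    d_k = len(W_k[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    A = softmax_rows(scores)          # n x n; row i = how much token i attends to each token
    return matmul(A, V), A

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(42)
n, d_model, d_k, d_v = 4, 8, 4, 4
X = random_matrix(n, d_model, rng)   # 4 tokens with 8 features each
W_q, W_k, W_v = (random_matrix(d_model, d, rng) for d in (d_k, d_k, d_v))
out, A = self_attention(X, W_q, W_k, W_v)
print(len(out), "x", len(out[0]))    # one d_v-dimensional output per token
```

All n rows of the attention matrix come out of one matrix product, which is the "all positions simultaneously" property described above.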

Global Context: Seeing Everything at Once

Unlike RNNs, which process sequences step-by-step (each hidden state only sees the past), attention allows every position to directly access every other position. This global connectivity has profound implications:

1. Long-range dependencies are easy

RNNs often struggle with dependencies spanning more than roughly 100 steps because information must propagate through many recurrent connections, causing vanishing gradients. With attention, the path between any two positions is a single operation: one attention layer. Word 1 can directly influence word 100 without gradients flowing through 99 intermediate states.

This makes learning long-range dependencies straightforward. The model learns attention patterns that connect distant related positions automatically.

2. Parallelization

RNNs must process sequences serially: step $t$ depends on step $t-1$. This makes training slow because you can’t parallelize across time steps.

Attention computes all positions in parallel. The attention matrix $\mathbf{Q}\mathbf{K}^T$ is computed with a single matrix multiplication covering all positions simultaneously. This makes training on GPUs dramatically faster, which is one reason Transformers replaced RNNs.

3. Interpretability

Attention weights are interpretable: they show which positions the model focuses on when processing each position. Visualizing attention patterns reveals what the model has learned; for example, in translation, you can see that the decoder attends strongly to “noir” when generating “black.”

This interpretability is limited (attention patterns are complex and multi-layered), but it provides more insight than RNN hidden states.

Multi-Head Attention: Parallel Attention Patterns

A single attention mechanism forces the model to blend different types of relationships into one weighted average. When processing “The cat sat on the mat,” a single attention head must simultaneously capture:

  • Syntactic relationships (“cat” is the subject of “sat”)
  • Semantic relationships (“cat” is related to “mat” as location)
  • Positional relationships (nearby words are often related)

Multi-head attention solves this by running multiple attention mechanisms in parallel, each learning different patterns. The Transformer uses 8-16 heads, allowing different heads to specialize in different relationships.

Architecture:

  1. Split query, key, and value into $h$ heads (typically 8)
  2. Apply attention independently for each head with different learned projections
  3. Concatenate head outputs and project back to original dimension

Formally, for head $i$:

$$\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)$$

Then:

$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \mathbf{W}^O$$

Where $\mathbf{W}_i^Q, \mathbf{W}_i^K, \mathbf{W}_i^V$ are per-head projection matrices, and $\mathbf{W}^O$ combines heads.
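A minimal sketch of the per-head projections, concatenation, and output projection, again in plain Python with random stand-in weights. The sizes (5 tokens, $d_{\text{model}} = 16$, $h = 4$ heads of 4 dimensions each) are arbitrary. Real implementations usually fuse all heads into one large projection and reshape, which computes the same thing.

```python
import math
import random

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    return matmul(softmax_rows(scores), V)

def multi_head(Q, K, V, heads, W_o):
    """heads: list of (W_q_i, W_k_i, W_v_i). Head outputs are concatenated, then projected by W_o."""
    head_outputs = [attention(matmul(Q, wq), matmul(K, wk), matmul(V, wv)) for wq, wk, wv in heads]
    # Concatenate along the feature axis: row i = head_1[i] + head_2[i] + ... (list concatenation).
    concat = [sum((h_out[i] for h_out in head_outputs), []) for i in range(len(Q))]
    return matmul(concat, W_o)

rng = random.Random(0)

def random_matrix(rows, cols):
    # Random stand-ins for learned weights, scaled by 1/sqrt(rows) to keep values moderate.
    return [[rng.gauss(0.0, 1.0) / math.sqrt(rows) for _ in range(cols)] for _ in range(rows)]

n, d_model, h = 5, 16, 4
d_head = d_model // h                      # 4 dims per head
X = random_matrix(n, d_model)              # 5 tokens
heads = [tuple(random_matrix(d_model, d_head) for _ in range(3)) for _ in range(h)]
W_o = random_matrix(h * d_head, d_model)
Y = multi_head(X, X, X, heads, W_o)        # self-attention: Q = K = V = X
print(len(Y), len(Y[0]))                   # n rows, d_model columns
```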

Why heads specialize: Empirical analyses of trained models (notably BERT) find heads with distinct patterns, although not every head is cleanly interpretable:

  • Positional heads: Attend to nearby tokens (capturing local context)
  • Syntactic heads: Attend to syntactic dependencies (subject→verb, adjective→noun)
  • Semantic heads: Attend to semantically related words (synonyms, co-occurrences)
  • Broad heads: Attend broadly across the sequence, serving as a fallback or capturing unusual patterns

Example: Analyses of BERT have found individual heads with narrow, recognizable roles:

  • A head in which direct objects attend to their verbs
  • Heads that attend almost entirely to the next word (capturing sequential structure)
  • Heads that attend almost entirely to the previous word

This specialization is not hand-designed; it emerges from training. The architecture provides several heads with separate projections, and training pushes different heads toward different relationship types because extracting several kinds of relationships in parallel lowers the loss.

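Computational cost: Multi-head attention has roughly the same computational cost as single-head attention with the combined dimension. Eight heads with 64-dimensional projections each perform the same projection and matrix-multiplication work as one head with 512 dimensions. The difference is mainly representational; the main extra cost is that each head computes and stores its own $n \times n$ attention matrix, so softmax work and attention-matrix memory grow with $h$.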
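A rough multiply-add count makes this concrete. The sketch counts only the four projections and the two attention matrix products (ignoring softmax, biases, and memory traffic) for a hypothetical layer with $n = 1024$ tokens and $d_{\text{model}} = 512$: the totals match for 1 and 8 heads, while the number of attention-matrix entries grows with the head count.

```python
def attention_costs(n, d_model, h):
    """Rough multiply-add counts for one multi-head self-attention layer.

    Counts the Q, K, V, and output projections plus the two attention matrix
    products; ignores softmax, biases, and memory traffic.
    """
    d_head = d_model // h
    projections = 4 * n * d_model * d_model   # X W^Q, X W^K, X W^V, concat W^O
    scores = h * n * n * d_head                # Q_i K_i^T for every head
    mixing = h * n * n * d_head                # (attention weights)_i V_i for every head
    matrix_entries = h * n * n                 # attention weights computed and stored
    return projections + scores + mixing, matrix_entries

costs = {h: attention_costs(n=1024, d_model=512, h=h) for h in (1, 8)}
for h, (mult_adds, entries) in costs.items():
    print(f"h={h}: multiply-adds={mult_adds:,}  attention-matrix entries={entries:,}")
```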

Typical values:

  • Smaller models (BERT-base, GPT-2 small): 12 heads, d_model=768 (64 dims per head)
  • Large models (GPT-3 175B): 96 heads, d_model=12288 (128 dims per head)

Connection to ensemble learning: Multiple heads are loosely like an ensemble: each learns a different view of the data, and their combination can be more robust than any single head. This may be one reason Transformers generalize well.

Attention Masking: Controlling Information Flow

Attention computes scores between all positions, but sometimes you want to prevent certain positions from attending to others. Masking achieves this by setting attention scores to $-\infty$ before softmax, causing softmax to assign zero weight to masked positions.

Causal Masking (Autoregressive Models):

For language models like GPT that predict the next token, the model must not see future tokens during training; that would be cheating. Causal masking prevents position $i$ from attending to positions $j > i$ (future positions).

Implementation: Create a mask matrix where entry $(i, j) = -\infty$ if $j > i$, else 0. Add this to the attention scores before softmax:

$$\text{Attention} = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + \mathbf{M}_{\text{causal}}\right) \mathbf{V}$$

Where $\mathbf{M}_{\text{causal}}$ has $-\infty$ strictly above the diagonal and 0 on and below it. Softmax converts $-\infty$ to probability 0, so future positions contribute nothing to the weighted sum.

This ensures the model learns to predict token $t$ using only tokens $1, \ldots, t-1$, matching the generation setting where future tokens aren’t available.
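A minimal sketch of building the causal mask and applying it before a row-wise softmax, using arbitrary example scores. In Python, math.exp(float("-inf")) returns 0.0, which is exactly how masked positions end up with zero weight.

```python
import math

NEG_INF = float("-inf")

def causal_mask(n):
    """M[i][j] = -inf if j > i (a future position), else 0."""
    return [[NEG_INF if j > i else 0.0 for j in range(n)] for i in range(n)]

def masked_softmax_rows(scores, mask):
    """Add the mask to each row of scores, then apply a stable softmax per row."""
    out = []
    for row, mask_row in zip(scores, mask):
        shifted = [s + m for s, m in zip(row, mask_row)]
        mx = max(shifted)
        exps = [math.exp(x - mx) for x in shifted]   # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# Arbitrary scores (already divided by sqrt(d_k)) for a 3-token sequence.
scores = [[1.0, 2.0, 3.0],
          [0.5, 0.1, 4.0],
          [2.0, 2.0, 2.0]]
weights = masked_softmax_rows(scores, causal_mask(3))
for row in weights:
    print([round(w, 3) for w in row])
```

The first token can only attend to itself, and the large score in row 2, column 3 is ignored because that key lies in the future.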

Padding Masking:

Sequences have variable lengths but are batched into fixed-size tensors by padding shorter sequences with special [PAD] tokens. Padding tokens are meaningless and shouldn’t influence attention.

Padding masking sets attention weights to zero for padded positions. Implementation: for every key position that holds a [PAD] token, set the attention scores of all queries against that position to $-\infty$ before softmax.

Custom Masking:

Task-specific masking patterns enable more control:

  • Block-diagonal masking: Attend only within local windows (for efficient long-context attention)
  • Entity masking: Mask out certain entities to test model robustness
  • Prefix masking: For models that process a prefix bidirectionally but generate autoregressively

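Production example: A GPT-style decoder trained or served on batches of variable-length sequences combines causal masking with padding masking. (GPT-3’s pretraining sidestepped padding by packing documents into full 2048-token sequences, but padded batches are common during fine-tuning and batched inference.) For a padded batch: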

  1. Pad shorter sequences to max batch length
  2. Apply padding mask (no token attends to padded positions; outputs at padded positions are ignored in the loss)
  3. Apply causal mask (each token only attends to past)
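The three steps can be sketched for a single hypothetical batch of two right-padded sequences. This sketch masks padding on the key side only: padded positions are never attended to, and the outputs computed at padded query positions are simply discarded (for example, excluded from the loss). Masking keys rather than whole rows keeps at least one finite score in every row, assuming right padding and at least one real token, which avoids the NaN that a softmax over an all $-\infty$ row would produce.

```python
import math

NEG_INF = float("-inf")

def causal_padding_mask(is_pad):
    """Query i may attend to key j only if j <= i (causal) and key j is not [PAD]."""
    n = len(is_pad)
    return [[0.0 if (j <= i and not is_pad[j]) else NEG_INF for j in range(n)]
            for i in range(n)]

def masked_softmax_rows(scores, mask):
    out = []
    for row, mask_row in zip(scores, mask):
        shifted = [s + m for s, m in zip(row, mask_row)]
        mx = max(shifted)
        exps = [math.exp(x - mx) for x in shifted]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# Hypothetical batch, right-padded to the longest sequence (length 4).
batch_is_pad = [
    [False, False, False, False],   # 4 real tokens
    [False, False, True, True],     # 2 real tokens + 2 [PAD]
]
uniform_scores = [[0.0] * 4 for _ in range(4)]   # equal scores make the mask easy to read
batch_weights = [masked_softmax_rows(uniform_scores, causal_padding_mask(p)) for p in batch_is_pad]
for weights in batch_weights:
    for row in weights:
        print([round(w, 2) for w in row])
    print()
```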

Why masking matters: Without proper masking, models “cheat” during training by seeing future information, then fail at test time when future isn’t available. Masking is essential for correctness in autoregressive models.

Production Example: Machine Translation

Consider a production machine translation system translating English to French at scale.

Architecture:

  • Encoder: 6-layer Transformer with self-attention on English sentence
  • Decoder: 6-layer Transformer with causal self-attention on French (generated so far) + cross-attention to English encoder outputs

Example sentence: “The cat sat on the mat” → “Le chat s’est assis sur le tapis”

Attention patterns during generation:

When generating “assis” (sat):

  1. Decoder self-attention: Attends to previous French tokens [“Le”, “chat”, “s’est”] to maintain grammatical coherence
  2. Encoder-decoder cross-attention: Query is “current French generation state,” keys/values are English encoder outputs. The decoder attends strongly to “sat” in the English sentence—the cross-attention shows alignment between source and target.

Visualization: Plot attention weights as a heatmap. Rows are target (French) words, columns are source (English) words. High values show strong alignment. For “assis” → “sat”, you see a bright spot, indicating the decoder correctly identified the corresponding English word.

Debugging with attention: If translation is wrong, check attention patterns:

  • If “assis” attends to “cat” instead of “sat” → model hasn’t learned correct alignment
  • If attention is uniformly spread → model isn’t focusing, may need more training data
  • If attention is sharp but on wrong word → check encoder representations

Inference latency (illustrative figures; real numbers depend on model size, hardware, and batching):

  • Input: English sentence (20 tokens)
  • Encoder: One forward pass (all tokens processed in parallel)
  • Decoder: Autoregressive generation (one token at a time)
  • Per-token latency: ~5ms on GPU (including encoder cross-attention)
  • Full sentence (25 French tokens): ~125ms total
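These figures compose simply: the encoder runs once, and the decoder runs once per generated token, so total time grows linearly with output length. A sketch using the illustrative per-token figure from the list; the encoder_ms value in the second call is a hypothetical addition, since the example does not report a separate encoder time.

```python
def generation_latency_ms(target_tokens, per_token_ms, encoder_ms=0.0):
    """Encoder: one parallel pass. Decoder: one sequential step per generated token."""
    return encoder_ms + target_tokens * per_token_ms

# Illustrative figures from the example (per-token cost already includes cross-attention).
print(generation_latency_ms(25, 5.0))
# Twice as many output tokens, plus a hypothetical 8 ms encoder pass.
print(generation_latency_ms(50, 5.0, encoder_ms=8.0))
```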

Production tips:

  • Cache encoder outputs: Encode each source sentence once and reuse its encoder outputs (and the cross-attention keys and values derived from them) at every decoding step instead of recomputing them
  • Batch inference: Process multiple translation requests simultaneously for better GPU utilization
  • Beam search: Generate multiple candidate translations and rank them by score (this adds latency, often several times the greedy cost, but usually improves quality)

Scaling: Modern translation services (Google Translate, DeepL) handle very large request volumes. Common optimizations include model distillation (smaller models with similar quality), quantization (INT8 inference), and caching common phrase translations.

Efficient Attention Variants

Standard attention has O(n²) complexity in sequence length, becoming prohibitive for long sequences (documents, codebases, entire books). Several techniques reduce this cost.

Sparse Attention:

Instead of attending to all positions, attend only to a subset:

  • Local attention (Longformer): Attend to a sliding window (e.g., 512 tokens around current position)
  • Global + local (Longformer, BigBird): A few tokens attend globally, others attend locally
  • Strided attention (Sparse Transformer): Attend every k-th position

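Tradeoff: A local window of w tokens costs O(n·w), which is linear in n for fixed w, versus O(n²) for full attention; the Sparse Transformer’s strided pattern costs O(n√n). The price is that distant positions interact only indirectly (through stacked layers or global tokens), which can cost some accuracy on tasks that need arbitrary long-range interactions. In exchange, much longer sequences fit in the same compute and memory budget.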
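To see the linear scaling, the sketch below counts how many query-key pairs a local window actually scores, assuming a symmetric window of radius r (each token attends to up to r neighbors on each side, plus itself). The sizes are arbitrary; the point is that doubling n roughly doubles the local count but quadruples the full count. The count is computed without materializing the n × n mask.

```python
def local_pairs(n, r):
    """Number of (query, key) pairs with |i - j| <= r, counted without building the mask."""
    return sum(min(n - 1, i + r) - max(0, i - r) + 1 for i in range(n))

def show_mask(n, r):
    # 'x' marks an attended (query, key) pair, '.' a masked one.
    for i in range(n):
        print("".join("x" if abs(i - j) <= r else "." for j in range(n)))

show_mask(8, 1)
for n in (1024, 2048, 4096):
    print(f"n={n}: full={n * n:,}  local(r=128)={local_pairs(n, 128):,}")
```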

Linear Attention:

Approximate attention using kernel methods (Performer, Linear Transformers):

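$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) \approx \phi(\mathbf{Q}) \left(\phi(\mathbf{K})^T \mathbf{V}\right)$$

Where $\phi$ is a kernel feature map applied row-wise, and each output row is divided by the matching entry of $\phi(\mathbf{Q}) \left(\phi(\mathbf{K})^T \mathbf{1}\right)$ so the weights still sum to 1. Computing $\phi(\mathbf{K})^T \mathbf{V}$ first means the cost is O(n) in sequence length instead of O(n²). Tradeoff: approximation quality. Linear attention typically loses some accuracy compared to full attention; the size of the gap depends on the task and the feature map.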
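A minimal sketch of the reordering, assuming the feature map $\phi(x) = \mathrm{elu}(x) + 1$ used by Linear Transformers (Performer instead uses random features designed to approximate the softmax kernel). It computes the same kernel attention two ways, once by scoring all n × n pairs and once by first summarizing $\phi(\mathbf{K})^T \mathbf{V}$ and $\phi(\mathbf{K})^T \mathbf{1}$, and checks that they agree. This kernel replaces softmax rather than reproducing it, which is where the accuracy tradeoff comes from.

```python
import math
import random

def phi(x):
    """elu(x) + 1, elementwise: x + 1 for x > 0, exp(x) otherwise (always positive)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def quadratic_kernel_attention(Q, K, V):
    """O(n^2): score every (query, key) pair as phi(q) . phi(k), then normalize each row."""
    fK = [phi(k) for k in K]
    out = []
    for q in Q:
        fq = phi(q)
        sims = [sum(a * b for a, b in zip(fq, fk)) for fk in fK]
        total = sum(sims)
        out.append([sum(s * v[c] for s, v in zip(sims, V)) / total for c in range(len(V[0]))])
    return out

def linear_attention(Q, K, V):
    """O(n): summarize keys and values once, then answer each query from the summary."""
    fK = [phi(k) for k in K]
    d, d_v = len(fK[0]), len(V[0])
    kv = [[sum(fk[a] * v[c] for fk, v in zip(fK, V)) for c in range(d_v)] for a in range(d)]  # phi(K)^T V
    k_sum = [sum(fk[a] for fk in fK) for a in range(d)]                                        # phi(K)^T 1
    out = []
    for q in Q:
        fq = phi(q)
        denom = sum(fq[a] * k_sum[a] for a in range(d))
        out.append([sum(fq[a] * kv[a][c] for a in range(d)) / denom for c in range(d_v)])
    return out

rng = random.Random(1)
n, d, d_v = 6, 4, 3
Q = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
K = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
V = [[rng.gauss(0.0, 1.0) for _ in range(d_v)] for _ in range(n)]
slow = quadratic_kernel_attention(Q, K, V)
fast = linear_attention(Q, K, V)
print(max(abs(a - b) for row_s, row_f in zip(slow, fast) for a, b in zip(row_s, row_f)))
```

The summaries kv and k_sum have sizes d × d_v and d, independent of n, so each additional query costs the same no matter how long the sequence is.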

Flash Attention:

Flash Attention doesn’t change what attention computes; it optimizes memory access patterns. Standard attention writes intermediate results (the n × n attention matrix) to GPU high-bandwidth memory, then reads them back. Flash Attention splits queries, keys, and values into tiles that fit in fast on-chip memory (SRAM) and computes softmax incrementally, so the full attention matrix is never written out, dramatically reducing memory traffic.
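The core trick behind this is an “online” softmax: process keys in blocks while keeping only a running maximum, a running normalizer, and a running weighted sum, rescaling them whenever the maximum grows. The sketch below does this for a single query in plain Python and checks it against the standard computation. It illustrates the idea only; the real Flash Attention is a fused GPU kernel that also tiles over queries, and the block size here is an arbitrary choice.

```python
import math
import random

def standard_attention(q, K, V):
    """Computes all n scores up front, then normalizes."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, V)) / total for c in range(len(V[0]))]

def online_softmax_attention(q, K, V, block=4):
    """Visits keys block by block, keeping only running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        k_block, v_block = K[start:start + block], V[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale earlier contributions so everything is relative to the new max.
        correction = math.exp(running_max - new_max)   # exp(-inf) == 0.0 on the first block
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vc for a, vc in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

rng = random.Random(7)
d_k, d_v, n = 8, 3, 10
q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
K = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(n)]
V = [[rng.gauss(0.0, 1.0) for _ in range(d_v)] for _ in range(n)]
exact = standard_attention(q, K, V)
tiled = online_softmax_attention(q, K, V, block=4)
print(max(abs(a - b) for a, b in zip(exact, tiled)))
```

The two results agree up to floating-point rounding, which is why this kind of tiling is exact rather than an approximation.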

Benefits:

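  • Reported speedups of roughly 2-4× for training and inference (exact attention, not an approximation; results match standard attention up to floating-point rounding)
  • Enables longer contexts by reducing the memory bottleneck
  • Often a drop-in improvement: swap in a Flash Attention implementation where your hardware and framework support it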

When to use what:

  • Default: Flash Attention (always use if available—free speedup)
  • Long sequences (> 4k tokens): Sparse attention (Longformer, BigBird)
  • Very long sequences (> 100k tokens): Linear attention or hierarchical methods
  • Mobile/edge: Consider linear attention for efficiency, but test accuracy loss

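Production: Flash Attention and similar fused attention kernels are widely used in modern training and serving stacks; for example, PyTorch’s scaled_dot_product_attention function can dispatch to a FlashAttention kernel on supported GPUs. It has become a standard implementation choice for production Transformers rather than an exotic optimization.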

Engineering Takeaway

Attention revolutionized deep learning by eliminating the sequential bottleneck and fixed-size memory limitations of RNNs. Understanding attention—scaled dot products, multi-head mechanisms, masking patterns, and efficient variants—is essential for building modern AI systems.

Attention replaced recurrence for sequences. Before attention, RNNs were the dominant choice for sequence modeling. Attention showed that recurrence isn’t necessary, and in practice global connectivity through attention has outperformed it on most sequence tasks. Transformers (pure attention, no recurrence) now dominate language modeling, machine translation, and most sequence tasks. The key insight: direct connections between all positions (one attention layer) beat sequential propagation (many recurrent steps) for both learning long-range dependencies and parallelization.

Scaled dot-product attention is non-negotiable. Always scale attention scores by $\frac{1}{\sqrt{d_k}}$ to prevent softmax saturation and vanishing gradients. Unscaled dot products grow with dimensionality, causing training to slow or fail. For $d_k = 512$, scaling divides scores by ~23, keeping them in the sensitive softmax range (-3 to 3). This is a standard practice, not an optimization; never use raw dot products for attention scoring.

Multi-head attention enables parallel specialization. Instead of a single attention mechanism, use 8-16 heads in parallel with different learned projections. Different heads specialize in different patterns (syntax, semantics, position). Empirically, individual BERT heads track relations such as direct objects and sequential (next- or previous-word) structure. Multi-head attention costs roughly the same as single-head attention with the same total dimension but provides better representations. Think of it, loosely, as an ensemble within a single layer.

Attention masking is essential for correctness. Causal masking (prevent future tokens from being seen) is mandatory for autoregressive models like GPT; without it, models cheat during training and fail at inference. Padding masking (ignore padding tokens) is necessary for batching variable-length sequences. Implement masking by setting attention scores to $-\infty$ before softmax. Forgetting masking causes silent failures: the model trains fine but doesn’t generalize.

Attention patterns are interpretable and debuggable. Visualize attention weights as heatmaps to understand what the model learned. In translation, you see word alignments (source→target). In language models, you see syntactic dependencies (subject→verb) and coreference (pronoun→antecedent). When models fail, check attention: if attention is uniform (not focusing), the model may not have learned the task; if it’s sharp but on wrong positions, the representations likely need improvement. Attention visualization is one of the most accessible debugging tools for Transformer models, though attention weights alone are not a complete explanation of model behavior.

Efficient attention enables longer contexts. Standard attention is O(n²) in sequence length, which becomes prohibitive for documents, codebases, or books. Flash Attention (reported 2-4× speedups, no approximation) is the production default; use it wherever it is supported. For sequences > 4k tokens, consider sparse attention (Longformer, BigBird) or linear attention (Performer). Tradeoff: both can lose some accuracy relative to full attention, but both enable much longer sequences. Flash Attention and sparse or local patterns can be combined for long contexts.

Attention generalizes beyond sequences. While introduced for sequences, attention now powers vision (Vision Transformers treat images as sequences of patches), graphs (graph attention networks attend over each node’s neighbors), sets (Set Transformers), and multimodal models (CLIP’s Transformer text encoder and, in its ViT variants, Transformer image encoder produce embeddings aligned through contrastive training). The core idea, a differentiable database lookup with soft weighting, applies wherever you need to selectively aggregate information. Attention is a universal mechanism, not just for text.

The lesson: Attention is a general mechanism for selectively focusing on relevant information. By making retrieval differentiable and enabling global connectivity, attention solved fundamental limitations of recurrence. Modern AI is built on attention—Transformers, GPT, BERT, Vision Transformers all use multi-head attention with proper scaling, masking, and efficient implementations. Understanding attention mechanics is essential for building, debugging, and optimizing production AI systems.


References and Further Reading

Neural Machine Translation by Jointly Learning to Align and Translate – Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio (2014) https://arxiv.org/abs/1409.0473

This is the paper that introduced attention for sequence-to-sequence learning. Bahdanau et al. showed that letting the decoder attend to encoder hidden states dramatically improves translation quality, especially for long sentences. The paper explains the motivation (fixed-size bottleneck), the mechanism (an additive alignment model that scores each encoder state against the current decoder state; in this chapter’s terms, a soft query-key-value lookup), and demonstrates empirically that attention works. Reading this gives you the historical context and foundational intuition for attention.

Attention Is All You Need – Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, et al. (2017) https://arxiv.org/abs/1706.03762

This is the Transformer paper—the most influential paper in modern AI. Vaswani et al. showed that attention alone (without recurrence or convolution) is sufficient for state-of-the-art sequence modeling. They introduced scaled dot-product attention, multi-head attention, and positional encodings. The Transformer architecture now powers GPT, BERT, and virtually all large language models. Understanding this paper is essential for understanding modern AI systems.

Attention and Augmented Recurrent Neural Networks – Chris Olah and Shan Carter (2016) https://distill.pub/2016/augmented-rnns/

This Distill article provides beautiful visualizations and intuitive explanations of attention mechanisms. Olah and Carter show how attention works visually and walk through several attention-augmented RNN architectures (Neural Turing Machines, attentional interfaces, adaptive computation time, and neural programmers). Reading this complements the formal papers with interactive diagrams that make attention intuitive.