Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 5: GPUs, TPUs

Stanford Online
Education · 8 min read · 79 min video
Apr 20, 2026
TL;DR

GPUs offer massive parallel computation, but optimizing for them requires deep understanding of their memory hierarchy and unique programming model. Ignoring these can lead to performance bottlenecks and inefficient resource utilization.

Key Insights

1. GPUs achieve high throughput via massive parallelism with hundreds of lightweight cores, contrasting with CPUs' focus on low-latency serial execution.

2. A GPU's memory hierarchy includes registers, L1/L2 cache, and global memory, with significant latency differences (e.g., L1/shared memory ~20-30 cycles vs. global memory ~10x slower).

3. The SIMT (Single Instruction, Multiple Threads) model on GPUs requires all threads in a warp to execute the same instruction, making control divergence due to if-statements inefficient.

4. Low-precision computation (e.g., FP16, BF16, INT8, FP8) is a major driver of GPU performance gains by reducing memory bandwidth requirements.

5. Tiling is a crucial technique for optimizing GPU performance: data is loaded into faster shared memory and reused repeatedly before results are written back to global memory.

6. FlashAttention significantly speeds up attention by combining tiled matrix multiplies with an online softmax computation, reducing memory transfers and using shared memory effectively.

The shift from serial to parallel computing with GPUs

The lecture highlights a fundamental shift in computing paradigms, moving from the serial execution model of CPUs, where clock speed was paramount, to the massively parallel architecture of GPUs. For decades, Dennard scaling drove improvements in CPU performance by shrinking transistors and increasing clock speeds. However, this scaling eventually hit physical limits in the 2000s. Consequently, the focus shifted to parallel scaling. GPUs, with their hundreds of lightweight cores, enable horizontal scaling, executing numerous instructions concurrently, which is crucial for the ever-increasing computational demands of modern language models. This shift is evident in the dramatic increase in floating-point operations per second (FLOPS) seen with GPUs like the V100, driven by hardware innovations such as tensor cores and support for lower-precision formats.

GPU architecture: The streaming multiprocessor and memory hierarchy

At the heart of a GPU are Streaming Multiprocessors (SMs), which act as independent compute units. Each SM contains streaming processors capable of executing threads in parallel. An NVIDIA A100 GPU, for instance, might have 128 SMs, each capable of independent execution. However, compute is only half the story; memory is critical. GPUs feature a hierarchical memory system: registers are the fastest and most limited, followed by L1 cache and shared memory (both physically close to the SMs and very fast, with latencies around 20-30 cycles), then L2 cache, and finally global or high-bandwidth memory (DRAM), which is significantly slower (about 10 times slower than L1/shared memory). This hierarchy is fundamental to GPU performance, as accessing global memory is a major bottleneck. The design principle is to keep data in faster, closer memory tiers as much as possible.
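
As a rough illustration of why this hierarchy matters, the latency gap can be modeled in a few lines of Python. The cycle counts below follow the lecture's ballpark figures (shared/L1 ~25 cycles, global ~10x that); they are illustrative, not specs for any particular chip.

```python
# Back-of-envelope model of GPU memory-tier access costs.
# Latencies are the lecture's rough figures, not hardware specs.
LATENCY_CYCLES = {"register": 1, "shared/L1": 25, "global": 250}

def access_cost(accesses: dict) -> int:
    """Total cycles for a given number of accesses per memory tier."""
    return sum(n * LATENCY_CYCLES[tier] for tier, n in accesses.items())

# Reading one value 8 times straight from global memory, versus staging
# it once in shared memory and re-reading it there:
naive = access_cost({"global": 8})                    # 8 * 250 = 2000 cycles
staged = access_cost({"global": 1, "shared/L1": 8})   # 250 + 8 * 25 = 450 cycles
print(naive, staged)
```

Even this crude model shows why the design principle is to keep hot data in the faster, closer tiers.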

GPU programming model: Threads, blocks, warps, and SIMT

The GPU programming model revolves around threads, blocks, and warps, operating under the SIMT (Single Instruction, Multiple Threads) paradigm. Threads are the basic units of computation, and in SIMT, all threads within a warp (a group of 32 threads) execute the same instruction simultaneously, though they can operate on different data. This uniformity simplifies programming but introduces inefficiencies if threads diverge in their execution paths (e.g., due to if-else statements), leading to 'control divergence' where some threads idle. A block is a collection of threads that are guaranteed to run on a single SM and share access to that SM's local memory (shared memory). Warps are the scheduling units, managed by the GPU's scheduler. Understanding these concepts is vital for writing efficient GPU code, as it dictates how work is distributed and managed.
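
The cost of control divergence can be sketched with a toy model: when threads in a 32-wide warp take different branches, the warp steps through both paths serially, masking off inactive threads. The cycle counts here are illustrative only.

```python
# Toy model of SIMT control divergence within one warp.
WARP_SIZE = 32

def warp_cycles(predicate, then_cycles, else_cycles):
    """Cycles a warp spends on an if/else, given each thread's branch choice."""
    takes_then = [predicate(t) for t in range(WARP_SIZE)]
    cost = 0
    if any(takes_then):
        cost += then_cycles   # warp executes the 'then' path
    if not all(takes_then):
        cost += else_cycles   # ...and then the 'else' path, serially
    return cost

uniform  = warp_cycles(lambda t: True,       10, 10)  # no divergence: 10 cycles
diverged = warp_cycles(lambda t: t % 2 == 0, 10, 10)  # divergence:    20 cycles
print(uniform, diverged)
```

When every thread agrees, only one path is executed; as soon as even one thread disagrees, the warp pays for both.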

Optimizing for GPUs: Six key tricks for speed

To achieve high performance on GPUs, several optimization techniques are essential, broadly aimed at minimizing memory movement and maximizing compute utilization:

1. Avoiding control divergence: Minimize `if-else` statements in code that executes on GPUs, as divergent branches within a warp leave some threads idle.

2. Low-precision computation: Use reduced-precision formats like BF16, INT8, or FP8. Fewer bits per number cuts memory bandwidth requirements and can offer near-linear speedups in computation; moving from FP32 to 16- or 8-bit formats halves or quarters the memory footprint per value.

3. Operator fusion: Combine multiple sequential operations into a single kernel. Instead of reading and writing global memory between each operation, fusion keeps intermediate results in fast on-chip memory (registers or shared memory), drastically reducing memory traffic.

4. Recomputation: When memory capacity is the constraint and compute is abundant, recompute intermediate activations during the backward pass instead of storing them, trading compute for memory. In the lecture's example, a simple computation requiring 8 memory accesses drops to 5 with recomputation.

5. Coalesced memory accesses: Ensure that threads within a warp access contiguous memory locations. DRAM accesses occur in bursts, so coalescing lets a single burst serve multiple threads at once, especially when they access data within the same burst window.

6. Tiling: Divide large matrices into smaller 'tiles' that are loaded into fast shared memory, computed on repeatedly, and only then written back to slower global memory. This drastically reduces the number of global memory accesses; loading a tile into shared memory and reusing it can cut global memory reads by a factor of the tile size.
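
The payoff of operator fusion (trick 3) can be shown with a simple traffic-counting model. The sketch below assumes each unfused elementwise kernel reads its full input from global memory and writes its full output back, while a fused kernel touches global memory only once at each end; real kernels differ in the details.

```python
# Counting global-memory traffic (in element transfers) for a chain of
# elementwise ops on a vector of n elements, unfused vs. fused.
def unfused_traffic(n: int, num_ops: int) -> int:
    # each unfused kernel: one global read + one global write of n elements
    return num_ops * 2 * n

def fused_traffic(n: int, num_ops: int) -> int:
    # one read of the input, one write of the final result; all
    # intermediates stay in registers/shared memory
    return 2 * n

n = 1_000_000
print(unfused_traffic(n, 3))  # 3 ops unfused: 6,000,000 transfers
print(fused_traffic(n, 3))    # 3 ops fused:   2,000,000 transfers
```

For a chain of k elementwise ops, fusion cuts global traffic by a factor of k, which is exactly why frameworks fuse activation functions into the surrounding matmuls.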

The role of low-precision computation in compute scaling

Lowering numerical precision (e.g., from FP32 to FP16, BF16, INT8, or FP8) is a primary driver of the remarkable compute scaling observed in GPUs. This is because it directly addresses the memory bandwidth bottleneck. With fewer bits per number, less data needs to be transferred from global memory, and computations involving these numbers can be faster. Modern tensor cores are designed to accelerate these low-precision operations. However, implementing low-precision arithmetic effectively, especially in training, is complex. It involves careful downcasting of values before computation, summing partial results in higher precision (like FP32) for stability, and deciding which operations can safely be performed at lower precision. Advanced formats like MXFP8 introduce multiple scaling factors within a matrix to handle varying magnitudes more effectively, though they add complexity, particularly for operations like matrix transposition, which might require storing multiple versions for efficiency.
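
A minimal sketch of the basic ingredient here, symmetric per-tensor INT8 quantization with a high-precision accumulator, follows below. This is an illustrative model in pure Python, not a tensor-core kernel: Python's arbitrary-precision ints stand in for the FP32/INT32 accumulator the lecture describes.

```python
# Sketch of symmetric INT8 quantization with one per-tensor scale factor.
def quantize_int8(xs):
    scale = max(abs(x) for x in xs) / 127.0
    q = [max(-127, min(127, round(x / scale))) for x in xs]
    return q, scale

def int8_dot(a, b):
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    acc = sum(x * y for x, y in zip(qa, qb))  # accumulate in high precision
    return acc * sa * sb                      # rescale back to real units

a = [0.5, -1.25, 2.0, 0.1]
b = [1.0, 0.75, -0.5, 3.0]
exact = sum(x * y for x, y in zip(a, b))
approx = int8_dot(a, b)
print(exact, approx)  # close, but not bit-identical
```

Formats like MXFP8 refine exactly this scheme: instead of one scale per tensor, they keep a scale per small block of values, so outliers in one block do not destroy the resolution of the rest.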

Tiling and its impact on memory access and alignment

Tiling is perhaps the most impactful technique for optimizing GPU performance by fully exploiting the memory hierarchy. The core idea is to break down large computations (like matrix multiplication) into smaller, manageable blocks or 'tiles'. These tiles are then loaded into the GPU's fast on-chip shared memory. Once loaded, computations involving these tiles can be performed repeatedly using the much faster shared memory, minimizing the need to access slower global memory. For a tiled matrix multiplication, an element might be read from global memory once, but then accessed many times from shared memory within its tile. This can reduce global memory access by a factor proportional to the tile size. However, optimal tiling requires careful consideration of matrix dimensions, available shared memory size, and memory burst properties to ensure efficient alignment and avoid what's known as 'wave quantization': when the number of tile-blocks slightly exceeds a multiple of the available SMs, the final 'wave' of tiles runs with most SMs sitting idle, hurting utilization. Choosing appropriate tile sizes, often powers of two and divisible by 32, is critical for performance.
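
The "factor proportional to the tile size" claim can be verified with a counting model. The functions below count global-memory element reads for an N x N matrix multiply, naive versus tiled with T x T tiles staged in shared memory; this is an accounting sketch, not a kernel.

```python
# Counting global-memory reads for an N x N matmul, naive vs. tiled.
def naive_global_reads(n: int) -> int:
    # each of the n*n outputs reads a full row of A and a full column of B
    return n * n * 2 * n

def tiled_global_reads(n: int, t: int) -> int:
    # (n/t)^2 output tiles, each looping over n/t tile-pair loads of A and B;
    # within a tile, every element is then reused t times from shared memory
    tiles = (n // t) ** 2
    return tiles * (n // t) * 2 * t * t

n, t = 1024, 32
print(naive_global_reads(n))                            # 2,147,483,648
print(tiled_global_reads(n, t))                         # 67,108,864
print(naive_global_reads(n) // tiled_global_reads(n, t))  # 32 = the tile size
```

Global reads drop from 2N^3 to 2N^3/T, i.e. exactly by the tile dimension T, which is why larger tiles are better right up until they no longer fit in shared memory.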

FlashAttention: Applied systems optimization for attention mechanisms

FlashAttention is a prime example of how systems-level optimizations can dramatically improve the performance of core deep learning components, specifically the attention mechanism. It tackles the memory-intensive nature of attention by applying tiling and recomputation strategies. The attention computation, involving matrix multiplies (QK^T, then the softmax-weighted product with V) and a softmax, is broken down using tiled matrix multiplies that operate on data within shared memory. A key challenge is the global softmax, which requires careful handling. FlashAttention uses an 'online' softmax computation that can proceed tile by tile, maintaining running sums of exponents and maximum values to normalize results progressively. This avoids materializing the entire intermediate attention matrix (N x N), which is memory-prohibitive for long sequences. Furthermore, by recomputing necessary activations during the backward pass instead of storing large intermediate tensors, it significantly reduces memory footprint and transfers, leading to substantial latency improvements and enabling longer context windows.
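
The online softmax trick at the heart of this can be sketched in a few lines: process the scores tile by tile, maintaining a running maximum m and a running sum of exponents s, rescaling s whenever a new maximum appears. This toy version computes only the softmax normalizer over a 1-D score vector; the full algorithm interleaves this with the tiled matmuls against V.

```python
import math

# Online softmax over a sequence of score tiles: exact result, one pass,
# never materializing all scores at once.
def online_softmax(score_tiles):
    m, s = float("-inf"), 0.0
    for tile in score_tiles:
        new_m = max(m, max(tile))
        # rescale the old running sum into the new reference frame,
        # then fold in this tile's exponents
        s = s * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in tile)
        m = new_m
    return m, s  # softmax(x_i) = exp(x_i - m) / s

scores = [2.0, -1.0, 0.5, 3.0, 1.5, -0.5]
m, s = online_softmax([scores[0:2], scores[2:4], scores[4:6]])
reference = sum(math.exp(x - max(scores)) for x in scores)
print(s, reference)  # agree to floating-point precision
```

Subtracting the running maximum also gives numerical stability for free, since no exponent ever overflows.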

Hardware trends and future considerations: Compute vs. Memory

The lecture concludes by emphasizing the widening gap between compute speed growth and memory bandwidth growth. GPUs are becoming tremendously faster in terms of FLOPS, but memory bandwidth is increasing at a much slower rate. This disparity means that utilizing the full potential of modern GPUs increasingly relies on optimizing data movement and memory access patterns. Future hardware designs and system optimizations will continue to focus on mitigating these memory bottlenecks. This can involve specialized accelerators, clever memory-tiering strategies, and hardware-aware algorithms like FlashAttention. The trade-offs between different memory types (SRAM vs. DRAM), precision formats, and parallelism strategies are critical considerations for maximizing efficiency in large-scale AI systems. Understanding these low-level system details is no longer an option but a necessity for anyone building or deploying performant AI models.
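
This compute-versus-bandwidth tension is exactly what the roofline model captures: a kernel is memory-bound whenever its arithmetic intensity (FLOPs per byte moved) falls below the machine's balance point. The peak numbers below are assumed placeholders for illustration, not specs for any particular GPU.

```python
# Roofline-style estimate of attainable throughput.
PEAK_FLOPS = 300e12       # 300 TFLOP/s (assumed, illustrative)
PEAK_BANDWIDTH = 2e12     # 2 TB/s      (assumed, illustrative)

def attainable_flops(intensity_flops_per_byte: float) -> float:
    """Attainable FLOP/s: bandwidth-limited below the balance point,
    compute-limited above it."""
    return min(PEAK_FLOPS, intensity_flops_per_byte * PEAK_BANDWIDTH)

balance = PEAK_FLOPS / PEAK_BANDWIDTH  # 150 FLOPs/byte to saturate compute
low  = attainable_flops(10)    # memory-bound:  only 20 TFLOP/s achievable
high = attainable_flops(500)   # compute-bound: capped at 300 TFLOP/s
print(balance, low, high)
```

As peak FLOPS grow faster than bandwidth, that balance point keeps rising, which is why tricks that raise arithmetic intensity (tiling, fusion, low precision) matter more every hardware generation.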

Optimizing GPU Performance: Key Strategies

Practical takeaways from this episode

Do This

Understand the GPU hardware model (massively parallel, throughput-oriented).
Prioritize memory hierarchy: utilize shared memory and caches heavily.
Maximize compute intensity to stay in the 'compute-bound' region of the roofline model.
Coalesce memory accesses to leverage DRAM burst capabilities.
Use tiling to reuse data loaded into shared memory.
Consider quantization (FP8, FP4) for reduced memory footprint and faster computation.
Leverage operator fusion to combine operations and minimize memory transfers.
Employ recomputation to trade compute for memory savings, especially in backpropagation.
Pad matrices and adjust sizes to align with burst windows and avoid wave quantization.

Avoid This

Avoid divergent 'if' statements within a warp in GPU code; control divergence forces both branches to execute serially and leaves compute units idle.
Do not assume naive implementations of operations (like attention or matrix multiplies) are efficient.
Do not ignore the memory hierarchy; global memory access is a significant bottleneck.
Avoid non-coalesced memory accesses, which waste DRAM burst efficiency.
Do not use arbitrary matrix sizes that lead to inefficient tiling or alignment issues.
Do not solely focus on compute power without considering memory bandwidth limitations.

Common Questions

How do GPUs differ from CPUs?
CPUs are designed for fast, serial execution with complex logic and low latency. GPUs, on the other hand, are optimized for high throughput with many lightweight cores executing tasks in parallel, even if individual tasks take longer to complete.
