
[Paper Club] Weight Streaming on Wafer-Scale Clusters (w/ Sarah Chieng of Cerebras)

Latent Space Podcast · Science & Technology · 44 min video · Dec 7, 2024
TL;DR

Cerebras's weight streaming decouples model storage from compute for efficient LLM training on wafer-scale clusters.

Key Insights

1. Weight streaming separates parameter storage (Memory X) from compute (the Wafer Scale Engine) to enable efficient and scalable training of giant neural networks.

2. Cerebras's Wafer Scale Engine (WSE) features a massive number of cores and on-chip SRAM, providing direct access to weights and values, unlike traditional GPUs with off-chip memory bottlenecks.

3. The weight streaming system comprises three components: the WSE (compute), Memory X (external storage for weights and optimizer states), and SwarmX (interconnect fabric).

4. Memory X streams weights twice per training pass (forward and backward) and is responsible for computing updated weights using optimizer states.

5. SwarmX acts as a sophisticated interconnect, abstracting the complexity of distributing weights and aggregating gradients between Memory X and multiple WSE compute units.

6. Cerebras leverages weight sparsity by eliminating zero/near-zero values during transmission and having cores skip computations on these values, reducing bandwidth and improving compute efficiency.

INTRODUCTION TO CEREBRAS AND WAFER-SCALE COMPUTING

This presentation introduces Cerebras's weight streaming technique for training large neural networks on their wafer-scale clusters. Cerebras, an AI processor company, builds the Wafer Scale Engine (WSE), its largest and fastest AI processor, now in its third iteration (WSE-3); the paper under discussion focuses on its predecessor, WSE-2. As context, Cerebras's inference benchmarks show speeds far beyond traditional NVIDIA GPUs, with WSE-3 serving Llama 70B at roughly 2,100 tokens per second. This hardware is the foundation for the training techniques discussed.

ARCHITECTURAL ADVANTAGES OF THE WAFER SCALE ENGINE

The Cerebras Wafer Scale Engine (WSE) is a massive chip, featuring approximately 4 trillion transistors and 900,000 cores for WSE-3 (WSE-2 has 850,000 cores), with 44GB of on-chip SRAM. This is in stark contrast to NVIDIA's H100, which has significantly fewer transistors (80 billion) and cores (around 17,000). A key architectural difference is that Cerebras places all necessary weights and values directly on-chip within each core's SRAM. This eliminates the traditional memory bandwidth bottleneck seen in GPUs, where cores constantly fetch data from slower off-chip memory, leading to reduced latency and power consumption.
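A quick back-of-envelope check makes the on-chip claim concrete. The figures below come from the section above; the assumption that parameters are stored in fp16 (2 bytes each) and that GB means 10^9 bytes is mine, for illustration only.

```python
# How many fp16 parameters fit in the WSE's 44 GB of on-chip SRAM?
# Assumptions (not from the source): fp16 storage, GB = 10**9 bytes.
SRAM_BYTES = 44 * 10**9      # WSE on-chip SRAM, per the section above
BYTES_PER_PARAM = 2          # fp16

params_on_chip = SRAM_BYTES // BYTES_PER_PARAM
print(f"~{params_on_chip / 1e9:.0f}B fp16 params fit on-chip")
```

Under these assumptions, roughly 22 billion fp16 parameters fit entirely in SRAM, which is why weights streamed onto the wafer never need a round trip to off-chip memory.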

TRADITIONAL TRAINING PARALLELISM AND ITS LIMITATIONS

Efficient training of large neural networks typically employs data parallelism or model parallelism. Data parallelism splits training batches across compute units, but requires each unit to store the entire model, limiting scalability for massive models. Fully sharded data parallelism attempts to mitigate this by distributing model weights, re-broadcasting them as needed. Model parallelism splits the model itself across units, which can involve pipeline or tensor parallelism. However, both forms of model parallelism introduce significant communication overhead and complexity, especially as cluster sizes grow, creating bottlenecks.
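The scalability limit of plain data parallelism can be seen in a toy sketch: every worker must hold a full copy of the weights, and only the batch is sharded. All names here are illustrative; this is generic data parallelism, not a Cerebras API.

```python
import numpy as np

# Toy data parallelism: each worker holds a FULL copy of the weights
# (the scalability problem), processes one shard of the batch, and the
# gradients are averaged (the all-reduce step) before a shared update.
rng = np.random.default_rng(0)
W0 = rng.standard_normal((4, 4))         # full model, replicated per worker
batch = rng.standard_normal((8, 4))      # global batch
shards = np.split(batch, 2)              # one shard per worker

def grad(w, x):
    # gradient of the quadratic loss 0.5 * ||x @ w||^2 for one shard
    return x.T @ (x @ w) / len(x)

worker_grads = [grad(W0, s) for s in shards]         # every worker stores all of W0
avg_grad = sum(worker_grads) / len(worker_grads)     # all-reduce (average)
W = W0 - 0.1 * avg_grad                              # identical update everywhere
```

With equal shard sizes, the averaged shard gradients equal the full-batch gradient, so every replica stays in sync; the cost is one full model copy per worker, which is exactly what breaks for massive models.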

THE WEIGHT STREAMING SOLUTION: DECOUPLING COMPUTE AND STORAGE

Weight streaming fundamentally separates model weight storage from primary compute. Instead of storing weights directly on the compute units (CSX systems containing WSEs), they are held in an external memory service called Memory X. During training, weights are streamed from Memory X to the compute units for both forward and backward passes. This approach decouples compute capacity from memory constraints, allowing for greater scalability and simpler implementation compared to complex hybrid parallelism strategies.
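The decoupling can be sketched in a few lines: weights live in an external store and are streamed to the compute side one layer at a time, so the compute side never holds the whole model. `memory_x` is an illustrative name, not a Cerebras API, and the "streaming" is just Python iteration.

```python
import numpy as np

# Minimal sketch of compute/storage decoupling: per-layer weights live in
# an external store and are delivered to the compute unit one layer at a
# time during the forward pass.
rng = np.random.default_rng(1)
memory_x = [rng.standard_normal((8, 8)) for _ in range(4)]  # external weight store

def forward(x):
    for w in memory_x:          # "stream" one layer's weights at a time
        x = np.tanh(x @ w)      # compute unit only ever holds one layer
    return x

out = forward(rng.standard_normal((2, 8)))
```

Because the compute side's working set is one layer rather than the full parameter set, model size is bounded by the external store, not by on-chip memory.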

CORE COMPONENTS: WSE, MEMORY X, AND SWARMX

The weight streaming system consists of three main components. First, the compute units are Cerebras CSX systems, powered by the Wafer Scale Engine (WSE) which has processing, on-chip memory (SRAM), and an internal interconnect fabric. Second, Memory X serves as the external, persistent storage for all model parameters and optimizer states, scaling from terabytes to petabytes and supporting models up to 120 trillion parameters. Lastly, SwarmX is the high-speed interconnect fabric that facilitates communication between Memory X and the CSX compute units, handling the streaming of weights and aggregation of gradients.
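A rough capacity check shows why Memory X must scale to petabytes for the quoted 120-trillion-parameter ceiling. The assumption of fp32 weights plus two fp32 Adam optimizer states per parameter is mine, not stated in the source.

```python
# Storage needed for a 120T-parameter model in Memory X.
# Assumption (illustrative): fp32 weights + two fp32 Adam states each.
PARAMS = 120 * 10**12
BYTES_PER_PARAM = 4 * 3               # weight + Adam m and v
total_pb = PARAMS * BYTES_PER_PARAM / 10**15
```

Under these assumptions the total is about 1.44 PB, consistent with the "terabytes to petabytes" range above.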

OPERATIONS AND DATA FLOW IN WEIGHT STREAMING

During the forward pass, weights are streamed from Memory X and distributed by SwarmX to the compute units, where activations are computed. During the backward pass, gradients are computed and stored on-chip, then sent back to Memory X via SwarmX, where the optimizer states are used to compute the updated weights. Memory X streams these updated weights out again on the next pass. This abstraction lets Memory X behave as if it were serving a single compute unit, while each compute unit simply consumes a stream of weights.
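The round trip above can be sketched for a single linear layer. SwarmX is elided, names are illustrative, and the "optimizer" is plain SGD for brevity rather than the stateful optimizer Memory X would actually run.

```python
import numpy as np

# Toy dataflow for one training step: weights stream in for the forward
# pass, gradients stream back, and the weight update happens on the
# Memory X side rather than on the compute unit.
rng = np.random.default_rng(2)
w_in_memory_x = rng.standard_normal((4, 4))       # lives off the compute unit

def training_step(x, target, lr=0.05):
    w = w_in_memory_x                 # forward: weights streamed to the wafer
    y = x @ w                         # activations computed on-chip
    g_y = y - target                  # backward: activation gradient
    g_w = x.T @ g_y / len(x)          # weight gradient, streamed back out
    return w_in_memory_x - lr * g_w   # Memory X computes the updated weights

x = rng.standard_normal((8, 4))
target = np.zeros((8, 4))
loss_before = float(np.mean((x @ w_in_memory_x - target) ** 2))
for _ in range(10):
    w_in_memory_x = training_step(x, target)      # updated weights stream next pass
loss_after = float(np.mean((x @ w_in_memory_x - target) ** 2))
```

Note that the compute side only ever sees weights in and gradients out; all optimizer state stays with Memory X.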

LEVERAGING WEIGHT SPARSITY FOR EFFICIENCY

Cerebras significantly utilizes weight sparsity, where zero or near-zero parameters are eliminated. This is applied during data transmission: Memory X streams only significant weights, pruning up to 90% of data and reducing bandwidth needs. Research indicates this can be done without accuracy loss, as supported by the lottery ticket hypothesis. Furthermore, within the WSE, cores are designed to recognize and skip computations involving zero or near-zero values, optimizing compute efficiency by focusing solely on meaningful operations. This capability to handle unstructured sparsity is a key advantage over traditional GPUs.
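The two halves of the sparsity story, transmit only the nonzeros and skip work on zeros, can be sketched as below. This is an illustrative software analogy; on the WSE the zero-skipping happens in hardware, and the pruning threshold here is arbitrary.

```python
import numpy as np

# Sketch of weight sparsity: Memory X transmits only significant weights
# (index/value pairs), and the receiving side only does work for the
# weights it actually received.
rng = np.random.default_rng(3)
w = rng.standard_normal(1000)
w[np.abs(w) < 1.5] = 0.0                 # prune near-zero weights (toy threshold)

# "Transmit" only the nonzeros: bandwidth falls with sparsity.
idx = np.flatnonzero(w)
vals = w[idx]
compression = 1 - len(idx) / len(w)      # fraction of traffic saved

def sparse_dot(x):
    # cores perform multiplies only for transmitted (nonzero) weights
    return float(x[idx] @ vals)

y = sparse_dot(np.ones(1000))
```

With a standard-normal weight vector and this threshold, well over 80% of entries are pruned, in the same ballpark as the "up to 90%" figure above, and the dot product touches only the surviving entries.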

THE CEREBRAS GRAPH COMPILER AND DATA LAYOUT

The Cerebras Graph Compiler (CGC) integrates with ML frameworks like TensorFlow and PyTorch, compiling user models into executable binaries that natively support weight streaming. The system also addresses how activation tensors are laid out on the wafer. Depending on the relative dimensions of sequence length (S) and hidden features (H) in tensors like those from GPT-3, data can be distributed across the wafer in different configurations. The matrix multiplications performed are optimized for both dense and sparse inputs, supporting activation calculations, activation gradients, and weight gradients efficiently.
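The layout idea can be illustrated by blocking an [S, H] activation tensor across a 2D grid of cores; which axis receives more splits would depend on whether S or H dominates. The grid shape and helper name here are illustrative, not CGC behavior.

```python
import numpy as np

# Toy layout: tile an [S, H] activation tensor across a (rows, cols) grid
# of cores, splitting the sequence axis across rows and the hidden axis
# across columns.
def layout(tensor, grid=(2, 2)):
    s_blocks = np.array_split(tensor, grid[0], axis=0)
    return [np.array_split(b, grid[1], axis=1) for b in s_blocks]

act = np.arange(8 * 4).reshape(8, 4)     # S=8, H=4
tiles = layout(act)                       # 2x2 grid of (4, 2) tiles
```

Stacking the tiles back together recovers the original tensor, which is the invariant any such layout must preserve.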

COMPARISON TO ALTERNATIVE TRAINING STRATEGIES

Weight streaming differentiates itself from other large-scale training approaches. Unlike NVIDIA's Megatron, Microsoft's DeepSpeed, or Meta's FSDP, which employ various forms of data and model parallelism, Cerebras's weight streaming uniquely disaggregates model storage from compute. This fundamental separation is not present in other discussed techniques, offering a distinct architectural solution for overcoming the memory limitations faced when training ever-larger neural networks.

Common Questions

What is weight streaming?

Weight streaming is a training technique in which parameter storage is separated from primary compute: weights are stored externally and streamed to compute units during training, allowing large language models to be trained efficiently at scale.
