Understanding Distributed Training in Deep Learning
Introduction
Since last year, the quest for large "X" models has been nonstop, with people exploring the possibility of building ever more universal and robust models. While some still doubt whether models with more parameters will keep getting more effective, most have faith in the scaling laws proposed by DeepMind and OpenAI researchers. The progress within a single year is promising, and it seems we are steadily moving toward the era of AGI. Education, however, barely keeps up: colleges and universities are still bound by their budgets when it comes to letting students get hands-on experience with large-model training, especially multi-GPU / multi-node distributed training. In light of this, I would love to share what I understand about distributed training and how we can get started in this domain to catch up with recent industrial progress.
1. Definition
- Leverages multiple compute resources—often across multiple nodes or GPUs—simultaneously, accelerating the model training process.
- Mainly a form of parallelism; it requires some understanding of low-level operating-system concepts (memory, communication) and GPU architecture.
- For those interested, I recommend CMU 15-418 (Parallel Computer Architecture and Programming) for an in-depth understanding.
2. Parallelism in Training
Two primary forms of parallelism: model parallelism and data parallelism
Model Parallelism:
- Used when a model doesn’t fit into the memory of a single device.
- Different parts of the model are placed on different devices, enabling the training process to occur across multiple GPUs or nodes. This approach is particularly useful for exceptionally large models.
Data Parallelism:
- Splits the dataset across devices, with each device processing a unique subset of the data.
- The model’s parameters are then updated based on the collective gradients computed from these subsets (with different strategies).
3. Strategies in detail
[Note]: I'll mainly use PyTorch in this blog as it is the most popular and convenient choice; everything is built on the torch.distributed
package. In the meantime, Lightning AI provides convenient wrappers in its own libraries, and I'll show some code using them for readers who just want a shortcut and would rather skip the details behind distributed training.
Data Parallelism
- How DistributedDataParallel works:
  - NCCL provides the multi-GPU, multi-node communication primitives: all-gather, all-reduce, broadcast, reduce-scatter, reduce, and point-to-point send/receive, with high bandwidth and low latency over PCIe and NVLink interconnects.
  - All GPUs start from the same initial weights; the gradients computed on different GPUs are aggregated and the weights are updated collectively.
  - Optimizer state and weights need to be updated after the AllReduce.
- DDP Implementation
```python
### DDP - PyTorch Version
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def main():
    # Initialize the distributed environment (env variables are set by torchrun)
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Create the model and wrap it with DDP
    model = YourModel().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters())

    # Load data and shard it across processes with a DistributedSampler
    dataset = YourDataset()
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    # Training loop
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(local_rank), labels.cuda(local_rank)
            outputs = model(inputs)
            loss = YourLoss(outputs, labels)

            # Backward and optimize (DDP all-reduces the gradients)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

For more advanced details like RPC-Based Distributed Training (RPC) and Collective Communication (c10d), refer to the original torch.distributed docs.

Fully Sharded Data Parallel (FSDP)
- What is in GPU memory for a model with x parameters trained in FP16:
  - Params: 2x bytes (FP16 uses 2 bytes per parameter)
  - Gradients: 2x bytes
  - Optimizer state (AdamW):
    - FP32 parameter copy: 4x bytes
    - Momentum: 4x bytes
    - Variance: 4x bytes
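Summing these up gives roughly 16 bytes per parameter before activations and buffers. As a quick back-of-the-envelope check (the 7B-parameter model below is just a hypothetical illustration):

```python
# Rough memory estimate for FP16 training with AdamW (activations excluded)
params = 7e9                          # hypothetical 7B-parameter model
bytes_per_param = 2 + 2 + 4 + 4 + 4   # params + grads + fp32 copy + momentum + variance
print(f"{params * bytes_per_param / 1024**3:.0f} GB")  # ~104 GB, well beyond a single GPU
```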
- How FSDP works
  - An FSDP unit (vertical splitting) can be:
    - a single layer
    - a stage
    - a group of layers
  - Sharding
    - The parameters of an FSDP unit are stored in a FlatParameter
    - The FlatParameter is split across multiple nodes (after zero padding so that it divides evenly)
  - All-Gather
    - performed by NCCL
    - gathers all shards and syncs the full parameters across all nodes
    - done before both the forward and the backward pass
    - peer shards are discarded after the forward/backward pass
  - Reduce-Scatter
    - performed via NCCL
    - each node receives only its part of the reduced gradients (backward only)
    - note that All-Reduce is not used because it would broadcast the same full result to every node
    - e.g. node i holds gradients G_i1, G_i2, ..., G_in; after reduce-scatter the gradients are redistributed so that node i holds the sum of G_ki for k from 1 to n
- Reasons to use / not to use FSDP
  - When to use
    - The model (not the data) is too large to fit on a single GPU
    - FSDP adds communication between GPUs: per-GPU memory drops because parameters, gradients and optimizer state are sharded, while the extra communication overhead is kept manageable by NCCL acceleration
    - If you want to trade compute speed for memory instead, see activation checkpointing
  - When not to use
    - For models under ~100 million parameters, consider activation checkpointing and reversible layers instead
    - Prefer BFloat16 over Float16 (Float16 requires ShardedGradScaler); see the mixed-precision sketch after the implementation below
    - Watch out for mixed-precision training compatibility across packages
- FSDP Implementation
```python
### FSDP Version
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    default_auto_wrap_policy,
    enable_wrap,
    wrap,
)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 4)
        self.layer2 = nn.Linear(4, 16)
        self.layer3 = nn.Linear(16, 4)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))


# For comparison, plain DDP would be: ddp_model = DistributedDataParallel(Model())

# Auto-wrap submodules into FSDP units and offload parameters to CPU
# (newer PyTorch versions rename the argument to `auto_wrap_policy`)
fsdp_model = FullyShardedDataParallel(
    Model(),
    fsdp_auto_wrap_policy=default_auto_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
)

# Custom wrap: every module constructed inside the context becomes an FSDP unit
wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_model = wrap(Model())
```
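Since BFloat16 is recommended above, here is a minimal sketch of enabling BF16 mixed precision in FSDP. It assumes a PyTorch version that ships the MixedPrecision config (1.12+), BF16-capable hardware, and reuses the toy Model from the snippet above.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

# Keep parameters, gradient reduction, and buffers in BF16 (sketch only;
# assumes the process group has already been initialized as in the DDP example)
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

fsdp_model = FullyShardedDataParallel(
    Model(),                      # toy model from the FSDP example above
    mixed_precision=bf16_policy,
)
```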
Model Parallelism
- Split horizontally: different groups of layers live on different devices.
- Implementation
```python
class model_parallel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Sequential(...)
        self.layer_2 = nn.Sequential(...)
        # Place the two halves of the model on different GPUs
        self.layer_1.cuda(0)
        self.layer_2.cuda(1)

    def forward(self, x):
        x = x.cuda(0)
        x = self.layer_1(x)
        # Move the intermediate activations over to the second GPU
        x = x.cuda(1)
        x = self.layer_2(x)
        x = ...
        return x
```

- Sometimes inefficient: in the code above, one GPU sits idle while the other is running its layers during training.
- Does not work well if the model architecture does not naturally lend itself to being divided into parallelizable segments.
Pipeline Parallelism
- Mixed data and model parallelism, involves scheduling of data flow
- Split into multiple stages, and each stage is assigned to a different device
- The output of one stage is fed as input to the next stage.
- Sometimes inefficient: the pipeline suffers from idle time when machines wait for other machines to finish their stages, in both the forward and the backward pass; the periods when some machines sit idle are referred to as bubbles. A common mitigation is to split each batch into micro-batches, as in the sketch below.
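To make the bubble concrete, below is a rough, self-contained sketch of pipelining via micro-batches on two GPUs, loosely following the pattern from the PyTorch model-parallel tutorial; the stage sizes and split_size are illustrative, and two visible GPUs are assumed.

```python
import torch
import torch.nn as nn

class PipelinedModel(nn.Module):
    def __init__(self, split_size=4):
        super().__init__()
        self.split_size = split_size
        # Two toy stages placed on two GPUs
        self.stage_1 = nn.Sequential(nn.Linear(8, 32), nn.ReLU()).cuda(0)
        self.stage_2 = nn.Sequential(nn.Linear(32, 4)).cuda(1)

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        # Prime the pipeline with the first micro-batch
        s_prev = self.stage_1(s_next.cuda(0)).cuda(1)
        outputs = []
        for s_next in splits:
            # GPU 1 finishes the previous chunk while GPU 0 starts the next one
            outputs.append(self.stage_2(s_prev))
            s_prev = self.stage_1(s_next.cuda(0)).cuda(1)
        outputs.append(self.stage_2(s_prev))
        return torch.cat(outputs)

# usage: out = PipelinedModel()(torch.randn(16, 8))
```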
Tensor parallelism
- Split vertically + horizontally (in units of a tensor)
- Can be more effective, as it leverages efficiencies within matrix multiplication by splitting a tensor up into smaller fractions and expediting the computation.
- The details could fill another blog post, so I will refer you to this excellent blog instead of reinventing the wheel myself.
- Might require models specifically designed to take advantage of this form of parallelism. It may not be as universally applicable as data or model parallelism.
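As a toy illustration of the idea (not Megatron-style code), the sketch below splits a linear layer's weight column-wise across two hypothetical GPUs and recovers the full output by concatenation; in a real multi-process setup the concatenation would be an all-gather across ranks.

```python
import torch

in_features, out_features, batch = 16, 8, 4
x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features)

# Split the weight's output columns across two devices
w0 = w[:, : out_features // 2].to("cuda:0")
w1 = w[:, out_features // 2 :].to("cuda:1")

# Each device computes its slice of the output in parallel
y0 = x.to("cuda:0") @ w0
y1 = x.to("cuda:1") @ w1

# Concatenate the partial results (an all-gather across ranks in practice)
y = torch.cat([y0.cpu(), y1.cpu()], dim=-1)
print(torch.allclose(y, x @ w, atol=1e-3))  # True: same result as the unsplit layer
```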
torchrun
- An elegant way to run distributed training with the torch.distributed package; please refer to the details here.
- Makes use of a rendezvous backend to achieve high availability and failure recovery.
- A few major advantages include:
  - Single-node multi-worker
  - Multi-node
  - Multi-GPU
  - Fault tolerance
  - Elasticity
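As a minimal sketch, a training script launched by torchrun can simply read the environment variables (RANK, LOCAL_RANK, WORLD_SIZE) that torchrun sets for every worker; the script name and launch flags in the comment are illustrative.

```python
# train.py, launched e.g. with: torchrun --nnodes=1 --nproc_per_node=4 train.py
import os
import torch
import torch.distributed as dist

def main():
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun per worker
    dist.init_process_group(backend="nccl")     # RANK / WORLD_SIZE read from env
    torch.cuda.set_device(local_rank)
    # ... build the DDP model and run the training loop here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```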
Distributed Training on the Cloud
- Since most compute resources are available from the cloud on demand, it is common practice to migrate local code to run on remote servers. You can spin up GPU resources (usually more capable than your local setup) yourself and manage the dependencies and monitoring independently, or you can resort to integrated solutions like AWS SageMaker, Azure ML, or Google AI Studio, as they often provide convenient API endpoints for interacting with those GPU instances. In many scenarios, their managed offerings cover inter-GPU/inter-node communication as well, which is a big plus.
- As an example, you can set up AWS accordingly and run your distributed training with SageMaker via this tutorial.
- A sample script is as follows:
```python
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    ...,
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    # Activate distributed training with SMDDP
    distribution={ "pytorchddp": { "enabled": True } }  # mpirun, activates SMDDP AllReduce OR AllGather
    # distribution={ "torch_distributed": { "enabled": True } }  # torchrun, activates SMDDP AllGather
    # distribution={ "smdistributed": { "dataparallel": { "enabled": True } } }  # mpirun, activates SMDDP AllReduce OR AllGather
)
```

Other packages
PyTorch Lightning - a lightweight PyTorch wrapper that provides a high-level interface for researchers and practitioners to streamline the training of deep learning models. It abstracts away much of the boilerplate code traditionally required for training models, making the code cleaner, more modular, and more readable. It requires very little setup; you just need to pass a few extra parameters to the trainer.
- Example
```python
trainer = L.Trainer(
    max_epochs=3,
    callbacks=callbacks,
    accelerator="gpu",
    devices=4,         # <-- NEW
    strategy="ddp",    # <-- NEW
    precision="16",
    logger=logger,
    log_every_n_steps=10,
    deterministic=True,
)
```
Hugging Face Accelerate: a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code. It is still built on top of torch_xla and torch.distributed, but it frees users from writing custom code to adapt to those platforms.
- Benefits include easy use of the ZeRO optimizer from DeepSpeed, as well as FSDP and mixed-precision training.
- Example
```python
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device  # device managed by Accelerate

model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)

for batch in training_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
    scheduler.step()
```

- In the terminal, run `accelerate launch {my_script.py}`
4. Challenges and Solutions
- Communication Overhead:
In distributed training, the exchange of information between devices becomes a potential bottleneck. As the number of devices increases, coordinating updates and sharing gradients become more complex.
Solutions:
Optimized Communication Protocols: Leveraging optimized communication protocols, such as NVIDIA NCCL for GPU communication, helps minimize the latency associated with inter-device communication.
Gradient Accumulation: By accumulating gradients locally on each device before synchronization, communication frequency is reduced. This strategy can be beneficial in scenarios where frequent synchronization is not necessary.
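As a minimal sketch of gradient accumulation with DDP (reusing the model, optimizer, train_loader and loss from the DDP example above; the accumulation step count is arbitrary and device placement is omitted for brevity), no_sync() skips the all-reduce on intermediate steps:

```python
from contextlib import nullcontext

accumulation_steps = 4  # synchronize gradients only every 4 mini-batches

for step, (inputs, labels) in enumerate(train_loader):
    # Skip DDP's all-reduce on all but the last micro-step
    sync_now = (step + 1) % accumulation_steps == 0
    with (nullcontext() if sync_now else model.no_sync()):
        loss = YourLoss(model(inputs), labels) / accumulation_steps
        loss.backward()
    if sync_now:
        optimizer.step()
        optimizer.zero_grad()
```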
- Fault Tolerance:
In distributed environments, hardware failures or network issues are inevitable. Ensuring fault tolerance is essential to maintain the integrity of the training process.
Solutions
Checkpointing: Regularly saving model checkpoints allows training to resume from the most recent checkpoint in case of a failure. This practice minimizes data loss and ensures continuity (a minimal sketch follows this list).
Redundancy: Introducing redundancy by running multiple instances of the training job across different nodes adds a layer of resilience. Load balancing techniques can be employed to distribute tasks effectively.
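Below is a minimal sketch of periodic checkpointing in a DDP job: only rank 0 writes the file and every rank can resume from it (the path and dictionary keys are illustrative, and the path is assumed to be visible to all ranks).

```python
import torch
import torch.distributed as dist

def save_checkpoint(model, optimizer, epoch, path="checkpoint.pt"):
    # Only rank 0 writes; other ranks wait so the file exists before resuming
    if dist.get_rank() == 0:
        torch.save({
            "epoch": epoch,
            "model": model.module.state_dict(),   # unwrap the DDP wrapper
            "optimizer": optimizer.state_dict(),
        }, path)
    dist.barrier()

def load_checkpoint(model, optimizer, path="checkpoint.pt"):
    ckpt = torch.load(path, map_location="cpu")
    model.module.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1  # epoch to resume from
```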
- Scaling Issues:
- Scaling distributed training to a large number of nodes presents challenges in terms of efficiency and resource management.
- Strategies
- Dynamic Resource Allocation: Implementing dynamic resource allocation ensures that resources are allocated efficiently based on the current load. Kubernetes and other orchestration tools can facilitate dynamic scaling.
- Parameter Servers: Utilizing parameter servers, which are dedicated servers responsible for storing and distributing model parameters, can enhance the scalability of distributed training.
5. Common Mistakes
- Not pipelining
- Pipeline parallelism is always something to consider; notice that ZeRO-3 also builds on pipelining ideas (overlapping communication of shards with computation).
- Not balancing pipeline stages
- There will be some brief periods where either a machine is idle and waiting on the next minibatch from the previous machine or takes longer than other machines to execute its computation, thus slowing down the pipeline.
- You should ideally construct your pipeline such that each machine does as close to the same amount of computation as possible. This means timing how long it takes data to get through different layers in the model, timing how long forward and backward propagation takes for each model partition, and ensuring roughly equivalent data sizes across mini-batches. This is critical for optimizing pipeline efficiency.
- To achieve this, setting up a profiler such as the PyTorch Profiler is critical for evaluating the computation done during model training.
- Weight staleness
- When model training is pipelined across multiple machines, there is a delay between when the forward computation on a minibatch occurs and when the gradients based on that computation are backpropagated to update the model weights. As a result, forward passes are computed with weights that have not been updated with the latest gradients.
- Solution: weight stashing
A system “maintains multiple versions of a model’s weights, one for each minibatch.” After the completion of each forward pass, the system can store a model’s weights as part of the state associated with that minibatch. When the time comes for backpropagation, the weights associated with that minibatch are retrieved from the stash and used for the gradient computation. This ensures that the same version of the weights is used for the forward and backward pass over a single minibatch of data within a pipelined stage, and statistical convergence is improved.
- Driver and library inconsistencies between machines
- Containerization / Virtualization using tools like Docker solves the problem
- Wrong type of Optimizer Update
- Example: Synchronous vs Asynchronous SGD
- Asynchronous SGD (Hogwild! being a popular choice) showed that SGD can be run in parallel, without locks, and without much impact on convergence. It allows weight updates to proceed without each machine waiting for the others to send their gradients.
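A minimal Hogwild-style sketch following the pattern from the PyTorch multiprocessing docs is shown below; MyModel, get_loader and loss_fn are placeholders, and the model is kept on CPU (CUDA would require the 'spawn' start method).

```python
import torch
import torch.multiprocessing as mp

def train(model):
    # Each worker updates the shared weights without any locking
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for inputs, labels in get_loader():
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

if __name__ == "__main__":
    model = MyModel()
    model.share_memory()  # place weights in shared memory for all workers
    processes = [mp.Process(target=train, args=(model,)) for _ in range(4)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```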
- Network issues, firewalls, ports, and communication errors
- Solutions:
- Relying less on network for communication
- If communication is necessary, a process must specify the IP address and port number over which to transmit the information
- Backup Frequently
- Better logging
- Slow data transmission
- Solutions:
- Avoid making RPC calls
- Try higher-bandwidth interconnects such as NVLink and InfiniBand
- Use FP16 / mixed precision for communication instead of FP32 (see the sketch below)
- Transmit subsets of gradients as soon as they are computed (e.g. the gradients of a single layer) while backpropagation continues on the remaining layers
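For the FP16 point above, recent PyTorch versions ship a built-in DDP communication hook that compresses gradients to FP16 before the all-reduce. A minimal sketch, assuming `model` is the DistributedDataParallel instance from the DDP example earlier:

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Compress gradients to FP16 for communication, then decompress after all-reduce
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
```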
6. A complete Distributed DL pipeline
Distributed Training Setup:
- Set up a distributed computing environment, typically using a cluster or cloud infrastructure like AWS, Google Cloud, or Azure.
- Ensure that all nodes in the cluster have the necessary libraries (TensorFlow, PyTorch, etc.) and dependencies installed.
- Split the training dataset across nodes to distribute the workload.
Synchronization and Communication:
- Implement a synchronization mechanism to ensure that the model’s weights are updated consistently across all nodes.
- Choose a communication protocol (e.g., Parameter Server, AllReduce) for aggregating gradients and exchanging model updates.
Model Initialization:
- Initialize the same model architecture with random weights on all nodes.
- Set up the model to follow data parallelism.
Training Loop (The main discussion we had in the blog):
- Start the training loop on each node with its batch of data.
- Compute gradients for the batch, update local weights, and synchronize with other nodes.
- Repeat this process for a predefined number of epochs.
- Implement early stopping to prevent overfitting and save the best-performing model checkpoint during training.
- Periodically evaluate the model’s performance on the validation dataset to ensure it’s learning effectively.
- Save model checkpoints at regular intervals during training to resume from a specific point in case of interruptions.
- If necessary, scale up the distributed training environment by adding more nodes to further accelerate training or handle larger datasets.
Monitoring and Logging:
- Implement monitoring and logging to track training progress, including loss, accuracy, and other relevant metrics.
- Use tools like TensorBoard or custom logging solutions to visualize training statistics.
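A minimal sketch of rank-0-only TensorBoard logging in a distributed job (the run directory and metric names are illustrative, and the process group is assumed to be initialized):

```python
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

# Only rank 0 writes event files to avoid duplicated or clashing logs
writer = SummaryWriter("runs/ddp_experiment") if dist.get_rank() == 0 else None

def log_scalar(name, value, step):
    if writer is not None:
        writer.add_scalar(name, value, step)
```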
Hyperparameter Tuning:
- Perform hyperparameter tuning, which may include learning rate adjustments, batch sizes, and other parameters, to optimize the model’s performance.
- Note: you should set a budget alert before this, as running multiple experiments (on a large model) in a distributed setting can be very COSTLY!!!
Post-training Analysis:
- This can go before/after/hand-in-hand with step 6, as part of model tuning
- Analyze the trained model’s performance on the test dataset to assess its generalization capabilities.
Deployment:
- Deploy the trained model for inference in your production environment, whether it’s on the cloud or at the edge.
- Sometimes this requires distributing model weights across servers as well.
Additional Fine-tuning (Optional):
- Fine-tune the model as needed based on deployment feedback or new data.
- Checkout Hugging Face’s TRL library & its tutorials to understand more.
Documentation:
- Document the entire distributed training process, including configuration settings, data preprocessing steps, and model architecture, for future reference.
Maintenance and Updates:
- Regularly update and maintain the distributed training system, including libraries, dependencies, and data pipelines, to ensure its reliability and performance.
For basic scripts without distributed training and with basic DDP, you may refer to the tutorial here. If you want a one-off solution, please refer to the code below.
7. A more challenging code using native PyTorch
If you are interested in building it from scratch with PyTorch directly, check out the following code (if you don't understand the syntax, please work through it yourself).
```python
"""A demo on how to set up a custom trainer with efficient training"""
```