Accelerated Training of Transformer Models
Kaarthik Sivashanmugam – Principal Engineering Manager
Sherlock Huang – Principal Engineer
Azure AI - Frameworks

Agenda
▪ ONNX Runtime for Training
  ▪ Introduction
  ▪ Integration with training frameworks
▪ Acceleration & Native Capabilities
  ▪ Memory usage and execution optimizations
  ▪ Mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO
▪ Training Recipes & Perf Results
  ▪ Pretraining and finetuning: BERT, GPT-2, Turing
▪ Demo: ONNX Runtime Training in Azure Databricks

Intro: ONNX, ONNX Runtime

ONNX: an open and interoperable format for ML models

ONNX Spec
▪ ONNX IR (intermediate representation)
▪ ONNX Operator schema: operation type, attributes, inputs/outputs, shape inference function
▪ https://onnx.ai/
▪ https://github.com/onnx/onnx/blob/master/docs/Operators.md

[Figure: example Gemm node. Inputs A = X (batch x 128), B = weight (128 x 256), C = bias (256); attributes alpha: 0.7, beta: 0.5; output Y (batch x 256)]
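Per the ONNX operator specification, Gemm computes Y = alpha * A * B + beta * C, with the bias C broadcast across the rows of the result. The pure-Python sketch below reproduces that computation at toy size, assuming the default transA = transB = 0 and a 1-D bias of length N; the function name `gemm` is hypothetical and is not part of the ONNX or ONNX Runtime API.

```python
from typing import List

Matrix = List[List[float]]


def gemm(a: Matrix, b: Matrix, c: List[float],
         alpha: float = 1.0, beta: float = 1.0) -> Matrix:
    """Hypothetical reference Gemm: Y = alpha * (A @ B) + beta * C.

    Assumes transA = transB = 0 and a 1-D bias C of length N that is
    broadcast to every row of the (M x N) result.
    """
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: A is {m}x{k}, B has {len(b)} rows")
    n = len(b[0])
    if len(c) != n:
        raise ValueError(f"bias length {len(c)} does not match N={n}")
    y = []
    for i in range(m):
        row = []
        for j in range(n):
            dot = sum(a[i][p] * b[p][j] for p in range(k))
            row.append(alpha * dot + beta * c[j])  # bias broadcast along the batch
        y.append(row)
    return y


# Same roles as in the figure, scaled down: X is (batch x K), W is (K x N), bias is (N).
x = [[1.0, 2.0], [3.0, 4.0]]            # batch = 2, K = 2
w = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]  # K = 2, N = 3
bias = [1.0, 1.0, 1.0]
y = gemm(x, w, bias, alpha=0.7, beta=0.5)
print(len(y), len(y[0]))  # output shape is (batch x N)
```

The output shape check mirrors what the shape inference function in the operator schema does: (batch x K) times (K x N) plus an (N) bias yields (batch x N).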
ONNX Model
▪ Graph composed of computational nodes
▪ Built-in and custom operators

ONNX Runtime (ORT)
▪ Cross-platform accelerator for training and inferencing
▪ Core part of the ML stack at Microsoft for innovations from the company and the industry

ORT Training
▪ Adopted by first-party (1P) and third-party (3P) workloads for acceleration
▪ Current focus on large transformer models (based on demand and acceleration needs)
▪ Extensible and supports PyTorch, Keras/TensorFlow, …

ONNX Runtime for Training

Training & ORT Acceleration
▪ Define Model
▪ Create ORTTrainer using the model
▪ Train loop:
  ▪ Get Data Batch
  ▪ ORTTrainer.train_step() (acceleration scope): Compute Loss, then Compute Gradients & Update Weights
▪ Checkpoint
▪ Evaluate

ORT in PyTorch

PyTorch:

    import torch

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ...
        def forward(self, x):
            ...

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # Training Loop (x, y: a batch of inputs and targets, not shown on the slide)
    for t in range(1000):
        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # reset gradient buffer
        optimizer.zero_grad()
        # backward
        loss.backward()
        # weight update
        optimizer.step()

PyTorch + ONNX Runtime backend:

    import torch
    from onnxruntime.training import ORTTrainer, optim

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ...
        def forward(self, x):
            ...

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    criterion = torch.nn.functional.cross_entropy
    model_description = {'inputs': [('data', ['in', 'batch_size']),
                                    ('target', ['label_x_batch_size'])],
                         'outputs': [('loss', [], True),
                                     ('output', ['out', 'batch_size'])]}
    # learning_rate, x and y are defined elsewhere (not shown on the slide)
    optimizer_config = optim.AdamConfig(lr=learning_rate)
    trainer = ORTTrainer(model,              # model
                         model_description,  # model description
                         optimizer_config,   # optimizer configuration
                         criterion)          # loss function

    # Training Loop
    for t in range(1000):
        # forward + backward + weight update
        loss, y_pred = trainer.train_step(x, y)

ORT Frontend Adapters
[Figure: PyTorch Script → PyTorch ORTTrainer and TF/Keras Script → TF ORTTrainer; each adapter converts the model to ONNX and shares GPU buffers with the ORT TrainingSession (Python API)]
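Because the adapters hand the whole training step to ORT, the ORT version of the loop replaces PyTorch's forward, zero_grad, backward and step with a single `trainer.train_step(x, y)` call that returns the loss and the predictions. The toy sketch below illustrates only that calling contract, in plain Python, for a linear model y = w * x + b with mean-squared-error loss and SGD. `ToyFusedTrainer` is hypothetical: it computes analytic gradients in Python, whereas ORT works on an ONNX export of the model, which this sketch does not model.

```python
from typing import List, Tuple


class ToyFusedTrainer:
    """Hypothetical trainer for y = w * x + b with MSE loss and plain SGD.

    train_step() performs forward, loss, backward and weight update in one
    call and returns (loss, predictions), mirroring the shape of
    `loss, y_pred = trainer.train_step(x, y)` in the ORT code.
    """

    def __init__(self, w: float = 0.0, b: float = 0.0, lr: float = 0.1):
        self.w, self.b, self.lr = w, b, lr

    def train_step(self, x: List[float], y: List[float]) -> Tuple[float, List[float]]:
        n = len(x)
        # forward
        y_pred = [self.w * xi + self.b for xi in x]
        loss = sum((p - t) ** 2 for p, t in zip(y_pred, y)) / n
        # backward: analytic gradients of the mean squared error
        grad_w = sum(2 * (p - t) * xi for p, t, xi in zip(y_pred, y, x)) / n
        grad_b = sum(2 * (p - t) for p, t in zip(y_pred, y)) / n
        # weight update (SGD); gradients are recomputed from scratch on every
        # call, so there is no gradient buffer to reset
        self.w -= self.lr * grad_w
        self.b -= self.lr * grad_b
        return loss, y_pred


x = [0.0, 1.0, 2.0, 3.0]
y = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1
trainer = ToyFusedTrainer(lr=0.05)
for t in range(1000):
    loss, y_pred = trainer.train_step(x, y)
print(round(trainer.w, 3), round(trainer.b, 3))
```

The design point is that the caller no longer orchestrates the individual phases, which leaves the runtime free to optimize across them.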
ONNX Runtime Acceleration & Native Capabilities

Contributors to ORT Acceleration
▪ Graph Optimizations: static graph optimization techniques like constant folding and redundant node elimination
▪ Optimal Gradient Graph: memory and compute optimized using global knowledge of data dependencies
▪ Memory Efficiency: static graph used for preallocation of memory for weights and gradients; memory reuse
▪ CUDA Kernel Optimizations: op fusion; reimplemented cuDNN kernels; removed redundant computation
▪ Other Training Capabilities: mixed precision training; distributed training parallelism modes; gradient checkpointing; AdaSum; DeepSpeed ZeRO

Native Capabilities in ORT
▪ Mixed Precision Training: 16-bit and 32-bit FP types to make training faster and use less memory
▪ Distributed Training Modes: data, horizontal and pipeline parallelism
▪ AdaSum: combines gradients in a novel way to improve convergence, so the model converges faster
▪ DeepSpeed ZeRO (Zero Redundancy Optimizer): optimizer state partitioning, gradient partitioning, parameter partitioning
▪ Gradient Accumulation: computed gradients are accumulated into a gradient buffer using partial execution of the graph, repeated for N steps; the averaged gradients are then used by the optimizer for weight updates
▪ Gradient Checkpoint: stashed activations often dominate memory consumption in training, so discarded activations are recomputed when needed, trading memory usage against computation cost

Code Sample & Training Recipes

BERT Pretraining using ORT
https://github.com/microsoft/onnxruntime-training-examples/

Training Recipes
▪ BERT Pretraining
  ▪ Nvidia’s implementation of BERT pretraining accelerated using ORT
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/nvidia-bert
▪ GPT-2 Finetuning
  ▪ Finetuning of Hugging Face GPT-2 model
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/huggingface-gpt2
▪ Turing Finetuning
  ▪ Finetuning of Microsoft Turing model for abstractive text summarization, sentiment analysis and suggested reply scenarios
  ▪ https://github.com/microsoft/Turing-NLR (private preview)

Performance Improvement Results

BERT Pretraining in 4xDGX-2
Metric | PyTorch 1.5 with NGC 20.03-py3 | PyTorch 1.5 with ONNX Runtime | % Gain with ONNX Runtime
Phase 1 throughput (ex/sec) | 11522.1 | 12826.2 | 11.32%
Phase 2 throughput (ex/sec) | 2150.0 | 2464.1 | 14.61%
Phase 1 time (hours) | 11.12 | 9.99 | 10.16%
Phase 2 time (hours) | 6.62 | 5.77 | 12.84%
Total time (hours) | 17.74 | 15.76 | 11.16%
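The % Gain column uses two definitions: for throughput rows it is the relative increase, (ORT − baseline) / baseline, and for time rows it is the relative reduction, (baseline − ORT) / baseline. The short check below recomputes the column from the table's own numbers under that reading; the helper names are illustrative only.

```python
# Values copied from the BERT pretraining table (4xDGX-2).
# Each entry: (PyTorch 1.5 with NGC 20.03-py3, PyTorch 1.5 with ONNX Runtime)
throughput = {  # examples/sec, higher is better
    "Phase 1": (11522.1, 12826.2),
    "Phase 2": (2150.0, 2464.1),
}
hours = {  # training time in hours, lower is better
    "Phase 1": (11.12, 9.99),
    "Phase 2": (6.62, 5.77),
    "Total": (17.74, 15.76),
}


def throughput_gain(baseline: float, ort: float) -> float:
    """Relative throughput increase, in percent."""
    return (ort - baseline) / baseline * 100


def time_saving(baseline: float, ort: float) -> float:
    """Relative reduction in training time, in percent."""
    return (baseline - ort) / baseline * 100


gains = {f"{k} throughput": round(throughput_gain(*v), 2) for k, v in throughput.items()}
gains.update({f"{k} time": round(time_saving(*v), 2) for k, v in hours.items()})
for name, pct in gains.items():
    print(f"{name}: {pct}%")
```

Every recomputed value matches the table, and the two phase times add up to the reported totals for both configurations.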
PyTorch with ORT can train with 2x the local batch size of PyTorch without ORT (the global batch size was kept the same for comparison).

Perf Improvement with ORT
Model (Scenario) / # Params | Perf improvement w/ ORT
Turing* (pretraining) / 340M | 1.4x
Turing* (pretraining) / 350M | 1.2x
RoBERTa XL (pretraining) / 500M | 3x
RoBERTa XL (finetuning) / 500M | 1.2x
RoBERTa XXL (pretraining) / 1B | 7x
GPT-2 M (pretraining) / 345M | 1.2x
* https://msturing.org/

Demo: ONNX Runtime Training in Azure Databricks
https://github.com/skaarthik/onnxruntime-training-databricks

Summary
▪ Optimize and accelerate model training using ONNX Runtime (ORT)
▪ ORT is used to train very large models for various Microsoft products and services
▪ https://github.com/microsoft/onnxruntime
▪ https://github.com/microsoft/onnxruntime-training-examples

Feedback
Your feedback is important to us. Don’t forget to rate and review the sessions.