Accelerated Training of Transformer Models

Kaarthik Sivashanmugam – Principal Engineering Manager

Sherlock Huang – Principal Engineer, Azure AI - Frameworks

Agenda

▪ ONNX Runtime for Training
  ▪ Introduction
  ▪ Integration with training frameworks
▪ Acceleration & Native Capabilities
  ▪ Memory usage and execution optimizations
  ▪ Mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO
▪ Training Recipes & Perf Results
  ▪ Pretraining and finetuning: BERT, GPT-2, Turing
▪ Demo: ONNX Runtime Training in Azure Databricks

Intro: ONNX, ONNX Runtime

ONNX: an open and interoperable format for ML models

ONNX Spec

The ONNX spec defines an IR (intermediate representation) and operator schemas. Each operator schema specifies the operation type, its attributes, its inputs/outputs, and a shape inference function.

[Figure: Gemm operator example – inputs A (batch x 128), weight B (128 x 256), bias (256); attributes alpha: 0.7, beta: 0.5; output Y (batch x 256)]

https://onnx.ai/
https://github.com/onnx/onnx/blob/master/docs/Operators.md
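As an illustration of the operator schema above, here is a minimal sketch that builds this Gemm node and a one-node ONNX graph with the public onnx.helper API. The graph name and the zero-valued weight/bias are placeholders for illustration only; shapes and attribute values mirror the figure.

    import numpy as np
    import onnx
    from onnx import TensorProto, helper, numpy_helper

    # Gemm computes Y = alpha * (A @ B) + beta * bias
    gemm = helper.make_node(
        "Gemm", inputs=["A", "B", "bias"], outputs=["Y"],
        alpha=0.7, beta=0.5)

    # Symbolic "batch" dimension, as in the figure
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, ["batch", 128])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", 256])

    # Weight and bias stored as initializers (placeholder values)
    B = numpy_helper.from_array(np.zeros((128, 256), dtype=np.float32), name="B")
    bias = numpy_helper.from_array(np.zeros(256, dtype=np.float32), name="bias")

    graph = helper.make_graph([gemm], "gemm_example", [A], [Y], initializer=[B, bias])
    model = helper.make_model(graph)
    onnx.checker.check_model(model)  # validates the node against the operator schema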

ONNX Model

▪ Graph composed of computational nodes
▪ Built-in and custom operators

ONNX Runtime (ORT)

Cross-platform accelerator for training and inferencing

Core part of the ML stack at Microsoft, supporting innovations from the company and the industry

ORT Training:
▪ Adopted by 1P and 3P workloads for acceleration
▪ Current focus on large transformer models (based on demand and acceleration needs)
▪ Extensible; supports PyTorch, TensorFlow, …

ONNX Runtime for Training

Training & ORT Acceleration

▪ Define Model → Create ORTTrainer using the model
▪ Train Loop:
  ▪ Get Data Batch
  ▪ Compute Loss
  ▪ Compute Gradients & Update Weights
  (Compute Loss and Compute Gradients & Update Weights run inside ORTTrainer.train_step() – the acceleration scope)
▪ Checkpoint
▪ Evaluate

ORT in PyTorch

PyTorch:

    import torch

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ...
        def forward(self, x):
            ...

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # Training Loop
    for t in range(1000):
        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y)

        # reset gradient buffer
        optimizer.zero_grad()

        # backward
        loss.backward()

        # weight update
        optimizer.step()

PyTorch + ONNX Runtime backend:

    import torch
    from onnxruntime.training import ORTTrainer, optim

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ...
        def forward(self, x):
            ...

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    criterion = torch.nn.functional.cross_entropy
    model_description = {'inputs':  [('data', ['in', 'batch_size']),
                                     ('target', ['label_x_batch_size'])],
                         'outputs': [('loss', [], True),
                                     ('output', ['out', 'batch_size'])]}
    optimizer_config = optim.AdamConfig(lr=learning_rate)

    trainer = ORTTrainer(model, model_description, optimizer_config, criterion)

    # Training Loop
    for t in range(1000):
        # forward + backward + weight update
        loss, y_pred = trainer.train_step(x, y)

ORT Frontend Adapters
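The frontend adapters hand the user's model to ORT by converting it to ONNX. As a standalone illustration of that export step, here is a minimal sketch using the public torch.onnx.export API (a hypothetical example with a stand-in model; ORTTrainer performs this conversion internally, so this is not the adapter code itself):

    import torch

    # Hypothetical PyTorch module standing in for the user's model
    model = torch.nn.Linear(128, 256)
    dummy_input = torch.randn(4, 128)  # example batch used to trace the graph

    # Export the traced graph to the ONNX format consumed by ONNX Runtime
    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["data"],
        output_names=["output"],
        dynamic_axes={"data": {0: "batch"}, "output": {0: "batch"}},
    )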

[Diagram: a PyTorch script or TF/Keras script drives a framework-specific ORTTrainer frontend; the model is converted to ONNX and handed, along with GPU buffers, to the ORT TrainingSession Python API, which runs on ONNX Runtime.]

Acceleration & Native Capabilities

Contributors to ORT Acceleration

▪ Graph Optimizations: Static graph optimization techniques like constant folding and redundant node elimination
▪ Gradient Graph: Memory and compute optimized using global knowledge of data dependencies; removed redundant computation
▪ Optimal Memory Efficiency: Static graph used for preallocation of memory for weights and gradients; memory reuse
▪ CUDA Kernel Optimizations: Op fusion; reimplemented cuDNN kernels
▪ Other Training Capabilities: Mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO

Native Capabilities in ORT

▪ Mixed Precision Training: 16-bit and 32-bit FP types to make training faster and use less memory
▪ Distributed Training Modes: Parallelism modes: data, horizontal and pipeline
▪ AdaSum: Combines gradients in a novel way to improve convergence; model converges faster
▪ Gradient Accumulation: Computed gradients are accumulated into a gradient buffer using partial execution of the graph, repeated for N steps; averaged gradients are used in the optimizer for weight updates (see the sketch below)
▪ Gradient Checkpoint: Stashed activations often dominate memory consumption in training; recompute discarded activations when needed; trade-off between memory usage and computation cost
▪ DeepSpeed ZeRO Redundancy Optimizer: Optimizer state partitioning, gradient partitioning, parameter partitioning

Code Sample & Training Recipes

BERT Pretraining using ORT
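Before the full recipes below, here is a minimal sketch of the gradient accumulation pattern described in the capabilities list above, written in plain PyTorch for clarity. The model, data, and step counts are placeholders for illustration; this is not the ORT API, where the equivalent accumulation happens inside train_step.

    import torch

    # Hypothetical tiny model and optimizer; stand-ins for illustration
    model = torch.nn.Linear(128, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    accumulation_steps = 4  # N partial steps per weight update

    optimizer.zero_grad()
    for step in range(1000):
        x, y = torch.randn(16, 128), torch.randn(16, 10)  # fake micro-batch
        loss = torch.nn.functional.mse_loss(model(x), y)

        # Scale so the accumulated gradient equals the average over N micro-batches
        (loss / accumulation_steps).backward()  # gradients add up in the .grad buffers

        if (step + 1) % accumulation_steps == 0:
            optimizer.step()       # weight update with the averaged gradients
            optimizer.zero_grad()  # reset the gradient buffer for the next N steps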

https://github.com/microsoft/onnxruntime-training-examples/

Training Recipes

▪ BERT Pretraining
  ▪ Nvidia's implementation of BERT pretraining accelerated using ORT
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/nvidia-bert
▪ GPT-2 Finetuning
  ▪ Finetuning of Hugging Face GPT-2 model
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/huggingface-gpt2
▪ Turing Finetuning
  ▪ Finetuning of Microsoft Turing model for abstractive text summarization, sentiment analysis and suggested reply scenarios
  ▪ https://github.com/microsoft/Turing-NLR (private preview)

Performance Improvement Results

BERT Pretraining on 4x DGX-2

Metric                      | PyTorch 1.5 with NGC 20.03-py3 | PyTorch 1.5 with ONNX Runtime | % Gain with ONNX Runtime
Phase 1 throughput (ex/sec) | 11522.1                        | 12826.2                       | 11.32%
Phase 2 throughput (ex/sec) | 2150.0                         | 2464.1                        | 14.61%
Phase 1 time (hours)        | 11.12                          | 9.99                          | 10.16%
Phase 2 time (hours)        | 6.62                           | 5.77                          | 12.84%
Total time (hours)          | 17.74                          | 15.76                         | 11.16%

PyTorch with ORT can train with 2x the local batch size of PyTorch without ORT (the global batch size was kept the same for comparison).

Perf Improvement with ORT

Model (scenario) / # params     | Perf improvement w/ ORT
Turing* (pretraining) / 340M    | 1.4x
Turing* (pretraining) / 350M    | 1.2x
RoBERTa XL (pretraining) / 500M | 3x
RoBERTa XL (finetuning) / 500M  | 1.2x
RoBERTa XXL (pretraining) / 1B  | 7x
GPT-2 M (pretraining) / 345M    | 1.2x

* https://msturing.org/

Demo: ONNX Runtime Training in Azure Databricks
https://github.com/skaarthik/onnxruntime-training-databricks

Summary

▪ Optimize and accelerate model training using ONNX Runtime (ORT)
▪ ORT is used to train very large models behind various Microsoft products/services
▪ https://github.com/microsoft/onnxruntime
▪ https://github.com/microsoft/onnxruntime-training-examples

Feedback

Your feedback is important to us. Don’t forget to rate and review the sessions.