UPDATE (March 2021) Topic: Training

• Peng Wang | AI Frameworks @ Microsoft

Common problems impacting ML productivity

Inference latency is too high to put into production

Training in Python but need to deploy into a C#/C++/Java app

Model needs to run on edge/IoT devices

Same model needs to run on different hardware and operating systems

Need to support running models created from several different frameworks

Training very large models takes too long

high-performance engine for models

Flexible
• Supports full ONNX-ML spec (v1.2-1.7)
• Supports CPU, GPU, VPU
• C#, C, C++, Java, JS and Python APIs

Cross Platform
• Works on
  - Mac, Windows
  - x86, x64, ARM
• Also built-in to Windows 10 natively (WinML)

Extensible
• Extensible "execution provider" architecture to plug in custom operators, optimizers, and hardware accelerators

Training (preview)
• Distributed training acceleration on multi-node GPU
• Large scale Transformer models

Mobile (preview)
• Model-specific package
• Reduced size
• Android, iOS, Linux
• X86, ARM

.com/microsoft/onnxruntime

Training Design Principles

Generic framework for training DNNs

Extensible with new kernels, optimization algorithms, etc.

Current implementation optimizes for Transformer-based models.

Adding model support based on customer demand

ONNX Runtime Training (Public Preview)

• Seamless integration with existing training frameworks for accelerated training and fine-tuning of large transformer models
• Incorporates latest algorithms and techniques such as DeepSpeed/ZeRO and Parasail/Adasum
• Integrates with GPU for distributed training

[Diagram: layered stack. Frontend: PyTorch, TF/Other. IR: ONNX. Backend: ONNX Runtime, Graph Compiler (MLIR), Distributed Execution. Accelerator SDKs: CPU, GPU, FPGA, VPU, DML, NPU. MSFT and 3P innovations: MSR DeepSpeed, MSR Parasail, Nvidia Megatron.]

Augmenting ONNX graphs for training

• ORT Training takes an inference ("forward") graph as input
• Training-specific functionality implemented as graph transformations

Pipeline: PyTorch to ONNX Conversion → Forward Graph Optimization → Distributed Communication Setup → Mixed Precision Transformation → Automatic Differentiation → Training Graph Optimization

Forward graph (inference graph)

[Figure: X (batch x 784) → MatMul with weight.0 (784 x 128) → (batch x 128) → Relu → (batch x 128) → MatMul with weight.1 (128 x 10) → predictions (batch x 10)]

(user-supplied)

[Figure: the same graph, with predictions (batch x 10) and labels (batch x 10) feeding SoftmaxCrossEntropy, which outputs loss (1)]

Backward graph (loss function gradient)
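As a rough illustration of the first backward node on this slide, the sketch below computes a softmax cross-entropy loss and its gradient with respect to the predictions, seeded by an upstream gradient of 1. It assumes one-hot labels and a mean over the batch; the slide does not state the reduction, and the function names are hypothetical, not ONNX Runtime kernels.

```python
import math

def softmax(row):
    """Numerically stable softmax of one row of logits."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy over the batch; labels are one-hot rows (assumption)."""
    losses = []
    for row, target in zip(logits, labels):
        probs = softmax(row)
        losses.append(-sum(t * math.log(p) for t, p in zip(target, probs)))
    return sum(losses) / len(losses)

def softmax_cross_entropy_grad(logits, labels, upstream=1.0):
    """Gradient of the mean loss w.r.t. logits, scaled by the upstream gradient."""
    batch = len(logits)
    return [
        [upstream * (p - t) / batch for p, t in zip(softmax(row), target)]
        for row, target in zip(logits, labels)
    ]

# Tiny stand-in for the (batch x 10) tensors on the slide: batch=2, 3 classes.
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
grad = softmax_cross_entropy_grad(logits, labels)  # upstream gradient is the constant 1
print(softmax_cross_entropy(logits, labels), grad)
```

The gradient has the same shape as the predictions, which is why the figure shows a (batch x 10) tensor leaving SoftmaxCrossEntropyGrad.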

[Figure: the forward graph with loss, plus the first backward node: Constant 1 (1) feeds SoftmaxCrossEntropyGrad, which outputs a (batch x 10) gradient]

Backward graph (compute gradients)
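The idea of building the backward graph as a graph transformation can be sketched symbolically: walk the forward nodes in reverse and emit one gradient node per forward node. The `Node` record, the `build_backward` helper, and the intermediate tensor names (`hidden`, `activated`) are hypothetical names for illustration; only the operator names come from the slides.

```python
from collections import namedtuple

# Hypothetical node record: operator type, input names, output names.
Node = namedtuple("Node", ["op", "inputs", "outputs"])

forward_graph = [
    Node("MatMul", ["X", "weight.0"], ["hidden"]),
    Node("Relu", ["hidden"], ["activated"]),
    Node("MatMul", ["activated", "weight.1"], ["predictions"]),
    Node("SoftmaxCrossEntropy", ["predictions", "labels"], ["loss"]),
]

def build_backward(graph, trainable):
    """Emit one '<Op>Grad' node per forward node, visiting nodes in reverse order."""
    needs_grad = set(trainable)
    # Forward sweep: an output needs a gradient if any of its inputs does.
    for node in graph:
        if any(name in needs_grad for name in node.inputs):
            needs_grad.update(node.outputs)
    backward = []
    for node in reversed(graph):
        if not any(name in needs_grad for name in node.outputs):
            continue
        # A gradient node sees the incoming gradients plus the forward inputs.
        grad_inputs = [out + "_grad" for out in node.outputs] + list(node.inputs)
        grad_outputs = [name + "_grad" for name in node.inputs if name in needs_grad]
        backward.append(Node(node.op + "Grad", grad_inputs, grad_outputs))
    return backward

# "loss_grad" is the seed gradient, the Constant 1 node in the figure.
backward_graph = build_backward(forward_graph, trainable=["weight.0", "weight.1"])
for node in backward_graph:
    print(node.op, "->", node.outputs)
```

Note that no gradient is emitted for X or labels, since neither leads back to a trainable weight.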

[Figure: the backward graph extended from the loss back to the weights: SoftmaxCrossEntropyGrad (batch x 10) → MatMulGrad* for the second MatMul, with inputs A (batch x 128), B (128 x 10), C (batch x 10), producing a (128 x 10) weight gradient and a (batch x 128) gradient → ReluGrad → MatMulGrad* for the first MatMul, with inputs A (batch x 784), B (784 x 128), C (batch x 128), producing a (784 x 128) weight gradient]

Backward graph (use existing operators)
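The weight gradient of a MatMul needs no dedicated kernel: for Y = X · W it is Xᵀ · dY, which a general matrix multiply with a transposed first input can compute (ONNX Gemm has a `transA` attribute for this). A minimal pure-Python sketch with tiny stand-in shapes follows; the helper names are illustrative only.

```python
def transpose(matrix):
    """Transpose a row-major nested list."""
    return [list(col) for col in zip(*matrix)]

def matmul(a, b):
    """Plain matrix product of two row-major nested lists."""
    b_cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]

def matmul_weight_grad(x, grad_out):
    """Weight gradient of Y = X @ W, written as one transposed product: X^T @ dY."""
    return matmul(transpose(x), grad_out)

# Stand-in shapes: batch=2, in_features=3, out_features=2
# (the slide uses batch x 784 and batch x 128).
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
grad_out = [[0.1, -0.2], [0.3, 0.4]]
grad_w = matmul_weight_grad(x, grad_out)  # shape (3 x 2), the same shape as W
print(grad_w)
```

The result has the weight's shape, matching the (784 x batch) by (batch x 128) product that yields (784 x 128) in the figure.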

[Figure: the same graph, with the first layer's MatMulGrad* replaced by an existing operator: Gemm with inputs A^T (784 x batch) and B (batch x 128), producing the (784 x 128) weight gradient. The second layer's MatMulGrad*, with inputs A (batch x 128), B (128 x 10), C (batch x 10), is unchanged]

Optimizer (Adam/Lamb)
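For reference, one step of the standard Adam update is sketched below on flat Python lists. Each weight carries two state buffers of its own shape, as the optimizer nodes on this slide do. This is the textbook formulation with default hyperparameters chosen as an assumption; it is not taken from the ONNX Runtime kernel, which may differ in detail.

```python
import math

def adam_step(weight, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on flat lists; returns new (weight, m, v)."""
    new_w, new_m, new_v = [], [], []
    for w, g, m_i, v_i in zip(weight, grad, m, v):
        m_i = beta1 * m_i + (1.0 - beta1) * g      # first-moment estimate
        v_i = beta2 * v_i + (1.0 - beta2) * g * g  # second-moment estimate
        m_hat = m_i / (1.0 - beta1 ** step)        # bias correction
        v_hat = v_i / (1.0 - beta2 ** step)
        new_w.append(w - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_w, new_m, new_v

weight = [0.5, -0.3]
grad = [0.2, -0.1]
m = [0.0, 0.0]  # first-moment state, same shape as the weight
v = [0.0, 0.0]  # second-moment state, same shape as the weight
weight, m, v = adam_step(weight, grad, m, v, step=1, lr=0.01)
print(weight)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each weight moves by roughly `lr` against its gradient.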

[Figure: the same graph, with one Adam Optimizer node per weight. Each node takes the weight gradient, the weight itself (weight.0 (784 x 128) or weight.1 (128 x 10)), and two state tensors of the same shape (momentum.0.u and momentum.0.v, or momentum.1.u and momentum.1.v)]

Training Acceleration
Transformer models

Usage of ORT Training at Microsoft

| Scenario / Model | Team | Improvement |
|---|---|---|
| Pre-training TuringNLR | Office services | From 4 days to ~2 days (1.4x higher throughput) |
| Pre-training RoBERTa-XL as base model | Bing Ads | From 8 days to 4.5 days (1.4x higher throughput) |
| Fine-tuning GPT-2 for word prediction | Office apps | Now able to train; stock PyTorch could not train with data parallelism |
| Pre-training GPT-2 Medium for IntelliSense | Visual Studio | From 8 days to 6.5 days (1.19x higher throughput) |

Nvidia A100

• FP16 and TF32 supported, BF16 is in progress
• Scales up to 512 A100 GPUs

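BERT-L Pretraining (ORT vs. Nvidia PT) Throughput

| Phase | GPUs | Batch size / GPU | Accumulation steps (PT) | Accumulation steps (ORT) | Throughput, seq/sec (PT) | Throughput, seq/sec (ORT) | Speedup (ORT vs PT) |
|---|---|---|---|---|---|---|---|
| Phase 1 | 1 | 65536 | 1024 | 512 | 415 | 509.4 | 22.7% |
| Phase 1 | 4 | 16384 | 256 | 128 | 1618 | 2024.3 | 25.1% |
| Phase 1 | 8 | 8192 | 128 | 64 | 3231 | 4058.5 | 25.6% |
| Phase 2 | 1 | 32768 | 2048 | 1024 | 78 | 96.2 | 23.3% |
| Phase 2 | 4 | 8192 | 512 | 256 | 308 | 382 | 24.0% |
| Phase 2 | 8 | 4096 | 256 | 128 | 620 | 766.9 | 23.7% |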
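The speedup column can be reproduced from the throughput columns as ORT / PT − 1. The sketch below recomputes it and, as an interpretation not stated on the slide, derives a per-step micro-batch by dividing the batch size per GPU by the accumulation steps.

```python
# (phase, GPUs, batch size per GPU, PT accumulation steps, ORT accumulation steps,
#  PT throughput in seq/sec, ORT throughput in seq/sec), copied from the table.
rows = [
    ("Phase 1", 1, 65536, 1024, 512, 415.0, 509.4),
    ("Phase 1", 4, 16384, 256, 128, 1618.0, 2024.3),
    ("Phase 1", 8, 8192, 128, 64, 3231.0, 4058.5),
    ("Phase 2", 1, 32768, 2048, 1024, 78.0, 96.2),
    ("Phase 2", 4, 8192, 512, 256, 308.0, 382.0),
    ("Phase 2", 8, 4096, 256, 128, 620.0, 766.9),
]

def speedup_percent(pt, ort):
    """Relative throughput gain of ORT over PT, in percent, to one decimal."""
    return round((ort / pt - 1.0) * 100.0, 1)

speedups = [speedup_percent(r[5], r[6]) for r in rows]

# Assumed interpretation: micro-batch = batch size per GPU / accumulation steps.
micro_batches = [(r[2] // r[3], r[2] // r[4]) for r in rows]

for row, gain, (pt_mb, ort_mb) in zip(rows, speedups, micro_batches):
    print(row[0], row[1], f"{gain}%", pt_mb, ort_mb)
```

Under that reading, ORT runs half as many accumulation steps per batch, that is, twice the micro-batch of the PyTorch baseline in every row.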

* Both ORT and PyTorch are using mixed precision training with LAMB
* PT numbers are taken from the Nvidia Examples Repo

CUDA Kernel Optimizations

• Transformer models share a "stable and small" set of operators
• Few variations of activation and normalization functions
• Easy to support new models

• Kernel optimizations for BERT are transferable to other models
• Prioritize generally applicable and reusable kernel optimizations
• RoBERTa, GPT-2, and other variants of transformer models run faster with ORT out of the box

• Graph-based optimization, no change in model definition

Memory Optimizations

[Figure: tensor buffers plotted by address space over time, distinguishing first-time allocations from reused buffers]
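A minimal sketch of the buffer-reuse idea on this slide: given each tensor's size and lifetime, assign offsets in one shared arena so that tensors alive at the same time never overlap, and read off the peak size before anything runs. This is a generic greedy first-fit heuristic written for illustration, with hypothetical names; it is not ONNX Runtime's actual planner.

```python
def plan_memory(tensors):
    """Assign each tensor an arena offset so tensors alive together never overlap.

    tensors: list of (name, size, first_use, last_use), with inclusive time steps.
    Returns (offsets, peak), where peak is the arena size needed.
    """
    placed = []   # (offset, size, first_use, last_use)
    offsets = {}
    # Placing large tensors first is a common greedy heuristic.
    for name, size, start, end in sorted(tensors, key=lambda t: -t[1]):
        # Address ranges taken by tensors whose lifetime overlaps this one.
        busy = sorted(
            (off, off + sz) for off, sz, s, e in placed if s <= end and start <= e
        )
        offset = 0
        for lo, hi in busy:
            if offset + size <= lo:
                break  # fits in the gap before this busy range
            offset = max(offset, hi)
        placed.append((offset, size, start, end))
        offsets[name] = offset
    peak = max(off + sz for off, sz, _, _ in placed)
    return offsets, peak

# Hypothetical tensors: three activations with staggered lifetimes and one gradient.
tensors = [
    ("act0", 4, 0, 1),
    ("act1", 4, 1, 2),
    ("act2", 4, 2, 3),
    ("grad", 2, 3, 4),
]
offsets, peak = plan_memory(tensors)
print(offsets, peak)
```

Here `act2` lands on the offset that `act0` has already vacated, so the planned peak is smaller than the sum of all tensor sizes.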

• Optimizing tensor placement in 2D space of Memory-Time
• Heavily reusing allocated buffer space
• Minimizes memory fragmentation
• Predicts peak memory consumption before running the model
• Runs BERT-L at 2x PyTorch's batch size
• Enables training GPT2-Medium on a 16GB V100, where PyTorch runs out of memory (OOM)
• Allows fitting larger models

Front-end integration

PyTorch          TensorFlow

PyTorch model    TF model

Torch IR GraphDef IR

ONNX Runtime

ONNX IR

Autodiff

Graph optimizations (Parallelization, Fusion, DeepSpeed, PipeDream, …)

ORT host runtime (initialization, kernel launch, data feeding, monitoring, …)

CUDA Runtime and Driver (Nvidia GPU)
ROCm Runtime and Driver (AMD GPU)

PyTorch integration: today

PyTorch

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ......

        def forward(self, x):
            ......

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # Training Loop
    for data, target in data_loader:
        # reset gradient buffer
        optimizer.zero_grad()

        # forward
        y_pred = model(data)
        loss = criterion(y_pred, target)

        # backward
        loss.backward()

        # weight update
        optimizer.step()

PyTorch + ONNX Runtime backend

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ......

        def forward(self, x):
            ......

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)

    criterion = torch.nn.CrossEntropyLoss()

    # Describe entire computation to offload
    optimizer = optim.SGDConfig(lr=1e-4)
    model_desc = {"inputs": [("x", ["batch", 784])],
                  "outputs": [("y", ["batch", 10])]}
    trainer = ORTTrainer(model, optimizer, model_desc, criterion)

    # Training Loop
    for data, target in data_loader:
        # forward + backward + weight update
        loss, y_pred = trainer.train_step(data, target)

PyTorch integration: next

PyTorch

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ......

        def forward(self, x):
            ......

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # Training Loop
    for data, target in data_loader:
        # reset gradient buffer
        optimizer.zero_grad()

        # forward
        y_pred = model(data)
        loss = criterion(y_pred, target)

        # backward
        loss.backward()

        # weight update
        optimizer.step()

PyTorch + ONNX Runtime backend

    # Model definition
    class NeuralNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            ......

        def forward(self, x):
            ......

    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
    model = ORTModule(model)  # the only change from the stock PyTorch script
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # Training Loop
    for data, target in data_loader:
        # reset gradient buffer
        optimizer.zero_grad()

        # forward
        y_pred = model(data)
        loss = criterion(y_pred, target)

        # backward
        loss.backward()

        # weight update
        optimizer.step()

Training Examples

| Example | Description |
|---|---|
| getting-started | Get started with ONNX Runtime with a simple PyTorch transformer model |
| nvidia-bert | Using ONNX Runtime Training with BERT pretraining implementation in PyTorch maintained by nvidia |
| huggingface-gpt2 | Using ONNX Runtime Training with GPT2 finetuning for Language Modeling in PyTorch maintained by huggingface |

• GitHub - microsoft/onnxruntime-training-examples: Examples for using ONNX Runtime for model training.

Thanks 谢谢

More to Read
• ONNX Runtime Training Technical Deep Dive - Microsoft Tech Community
• Announcing accelerated training with ONNX Runtime: train models up to 45% faster - Open Source Blog (microsoft.com)