
Accelerated Training of Transformer Models
Kaarthik Sivashanmugam, Principal Engineering Manager
Sherlock Huang, Principal Engineer
Azure AI - Frameworks

Agenda
▪ ONNX Runtime for Training
  ▪ Introduction
  ▪ Integration with training frameworks
▪ Acceleration & Native Capabilities
  ▪ Memory usage and execution optimizations
  ▪ Mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO
▪ Training Recipes & Perf Results
  ▪ Pretraining and finetuning: BERT, GPT-2, Turing
▪ Demo: ONNX Runtime Training in Azure Databricks

Intro: ONNX, ONNX Runtime
ONNX: an open and interoperable format for ML models (https://onnx.ai/)
▪ ONNX Spec
  ▪ ONNX IR (intermediate representation)
  ▪ ONNX Operator schema: operation type, attributes, inputs/outputs, shape inference function (https://github.com/onnx/onnx/blob/master/docs/Operators.md)
▪ ONNX Model
  ▪ Graph composed of computational nodes
  ▪ Built-in and custom operators
Figure: an example Gemm node. Inputs: A = X (batch x 128), B = weight (128 x 256), C = bias (256). Attributes: alpha: 0.7, beta: 0.5. Output: Y (batch x 256).

ONNX Runtime (ORT)
▪ Cross-platform accelerator for training and inferencing
▪ Core part of the ML stack at Microsoft for innovations from the company and the industry

ORT Training
▪ Adopted by first-party (1P) and third-party (3P) workloads for acceleration
▪ Current focus on large transformer models (based on demand and acceleration needs)
▪ Extensible; supports PyTorch, Keras/TensorFlow, …

ONNX Runtime for Training
Training & ORT Acceleration
▪ Define the model
▪ Create an ORTTrainer using the model
▪ Training loop:
  ▪ Get a data batch
  ▪ Call ORTTrainer.train_step(), which computes the loss, computes the gradients, and updates the weights; this step is the acceleration scope
▪ Checkpoint
▪ Evaluate
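Per the ONNX operator spec, Gemm computes Y = alpha · A' · B' + beta · C, where A' and B' are A and B optionally transposed (transA/transB, off by default) and C must be broadcastable to (M, N). Below is a minimal pure-Python sketch of the Gemm node from the ONNX intro, covering both the computation and the shape-inference rule that an operator schema supplies. It assumes no transposes and a bias of shape (N,), as in the slide's example; the helper names `gemm` and `gemm_output_shape` are hypothetical, and this is not ONNX Runtime's kernel.

```python
from typing import List, Sequence, Tuple, Union

Matrix = List[List[float]]
Dim = Union[int, str]  # a concrete size, or a symbolic one such as "batch"


def gemm(a: Matrix, b: Matrix, c: Sequence[float],
         alpha: float = 1.0, beta: float = 1.0) -> Matrix:
    """Y = alpha * (A @ B) + beta * C, with C of shape (N,) broadcast across the M rows.
    Assumes transA = transB = 0 (no transposes)."""
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: A has {k} columns, B has {len(b)} rows")
    n = len(b[0])
    if len(c) != n:
        raise ValueError(f"bias has length {len(c)}, expected {n}")
    return [
        [alpha * sum(a[i][p] * b[p][j] for p in range(k)) + beta * c[j] for j in range(n)]
        for i in range(m)
    ]


def gemm_output_shape(a_shape: Tuple[Dim, Dim], b_shape: Tuple[Dim, Dim],
                      c_shape: Tuple[Dim, ...]) -> Tuple[Dim, Dim]:
    """Shape inference: (M x K) and (K x N) give (M x N); only a (N,) bias is handled here."""
    m, k = a_shape
    k_b, n = b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    if tuple(c_shape) != (n,):
        raise ValueError(f"unsupported bias shape {c_shape} for N={n}")
    return (m, n)


if __name__ == "__main__":
    # The slide's example: X (batch x 128) times weight (128 x 256) plus bias (256)
    print(gemm_output_shape(("batch", 128), (128, 256), (256,)))
    # A tiny numeric case using the slide's attributes alpha=0.7, beta=0.5
    print(gemm([[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], [10.0, 20.0],
               alpha=0.7, beta=0.5))
```

The shape function keeps the symbolic "batch" dimension as-is, which is how a batch-size-agnostic output shape such as (batch x 256) can be inferred before any data is seen.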
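ORT in PyTorch
PyTorch:

```python
import torch

# Model definition
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        ...

    def forward(self, x):
        ...

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Training loop (x, y: an input batch and its targets, not shown)
for t in range(1000):
    # forward
    y_pred = model(x)
    loss = criterion(y_pred, y)
    # reset gradient buffer
    optimizer.zero_grad()
    # backward
    loss.backward()
    # weight update
    optimizer.step()
```

PyTorch + ONNX Runtime backend:

```python
import torch
from onnxruntime.training import ORTTrainer, optim

# Model definition
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        ...

    def forward(self, x):
        ...

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
criterion = torch.nn.functional.cross_entropy

# Names and (symbolic) shapes of the model's inputs and outputs;
# True marks 'loss' as the loss output
model_description = {'inputs': [('data', ['in', 'batch_size']),
                                ('target', ['label_x_batch_size'])],
                     'outputs': [('loss', [], True),
                                 ('output', ['out', 'batch_size'])]}

# Optimizer configuration (learning_rate not shown)
optimizer_config = optim.AdamConfig(lr=learning_rate)
trainer = ORTTrainer(model, model_description, optimizer_config, criterion)

# Training loop (x, y: an input batch and its targets, not shown)
for t in range(1000):
    # forward + backward + weight update
    loss, y_pred = trainer.train_step(x, y)
```

ORT Frontend Adapters
▪ PyTorch script → ORTTrainer (PyTorch frontend) → model exported to ONNX
▪ TF/Keras script → ORTTrainer (TF frontend) → model exported to ONNX
▪ Both frontends exchange GPU buffers with the ORT TrainingSession Python API, which runs on ONNX Runtime

Acceleration & Native Capabilities
Contributors to ORT Acceleration
▪ Graph optimizations: static graph optimization techniques such as constant folding and redundant node elimination
▪ Optimal gradient graph: memory and compute optimized using global knowledge of data dependencies
▪ Memory efficiency: the static graph is used to preallocate memory for weights and gradients; memory reuse
▪ CUDA kernel optimizations: op fusion; reimplemented cuDNN kernels; removed redundant computation
▪ Other training capabilities: mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO

Native Capabilities in ORT
▪ Mixed precision training: uses 16-bit and 32-bit floating-point types to make training faster and use less memory
▪ Distributed training modes: data parallelism, plus horizontal and pipeline parallelism using partial execution of the graph
▪ DeepSpeed ZeRO Redundancy Optimizer: optimizer state partitioning, gradient partitioning, parameter partitioning
▪ AdaSum: combines gradients in a novel way to improve convergence, so the model converges faster
▪ Gradient accumulation: computed gradients are accumulated into a gradient buffer, repeated for N steps; the averaged gradients are then used by the optimizer for weight updates
▪ Gradient checkpoint: stashed activations often dominate memory consumption in training, so discarded activations are recomputed when needed; a trade-off between memory usage and computation cost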
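The slide describes AdaSum only as combining gradients "in a novel way". To make that concrete, here is a minimal sketch contrasting it with plain averaging, assuming the pairwise rule published in the AdaSum paper (Maleki et al., "Scaling Distributed Training with Adaptive Summation"). The function names `dot`, `average`, `adasum`, and `adasum_tree` are hypothetical, the code works on Python lists, and it is not how ORT implements AdaSum internally. The binary-tree reduction over workers assumes a power-of-two worker count.

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def average(g1: Vector, g2: Vector) -> Vector:
    """Conventional data-parallel combination: element-wise mean."""
    return [(a + b) / 2.0 for a, b in zip(g1, g2)]


def adasum(g1: Vector, g2: Vector) -> Vector:
    """Pairwise AdaSum rule:

        (1 - g1.g2 / (2 |g1|^2)) * g1 + (1 - g1.g2 / (2 |g2|^2)) * g2

    Orthogonal gradients are summed; identical gradients are averaged.
    """
    d = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    # A zero gradient contributes nothing, so its coefficient is irrelevant; use 1.0.
    c1 = 1.0 - d / (2.0 * n1) if n1 > 0.0 else 1.0
    c2 = 1.0 - d / (2.0 * n2) if n2 > 0.0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


def adasum_tree(grads: List[Vector]) -> Vector:
    """Combine one gradient per worker by applying adasum pairwise, level by level.
    Assumes a power-of-two number of workers."""
    n = len(grads)
    if n == 0 or n & (n - 1):
        raise ValueError("expected a power-of-two number of gradients")
    level = grads
    while len(level) > 1:
        level = [adasum(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]


if __name__ == "__main__":
    g1, g2 = [1.0, 0.0], [0.0, 1.0]  # two workers with orthogonal gradients
    print("average:", average(g1, g2))
    print("adasum: ", adasum(g1, g2))
```

For orthogonal gradients, averaging halves each worker's contribution while AdaSum keeps both in full; for identical gradients the two rules agree. This adaptive behavior between summing and averaging is what lets larger effective batches keep converging well.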
Code Sample & Training Recipes
BERT pretraining using ORT: https://github.com/microsoft/onnxruntime-training-examples/

Training Recipes
▪ BERT Pretraining
  ▪ NVIDIA's implementation of BERT pretraining, accelerated using ORT
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/nvidia-bert
▪ GPT-2 Finetuning
  ▪ Finetuning of the Hugging Face GPT-2 model
  ▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/huggingface-gpt2
▪ Turing Finetuning
  ▪ Finetuning of the Microsoft Turing model for abstractive text summarization, sentiment analysis, and suggested-reply scenarios
  ▪ https://github.com/microsoft/Turing-NLR (private preview)

Performance Improvement Results
BERT Pretraining on 4x DGX-2

| Metric | PyTorch 1.5 with NGC 20.03-py3 | PyTorch 1.5 with ONNX Runtime | % Gain with ONNX Runtime |
|---|---|---|---|
| Phase 1 throughput (ex/sec) | 11522.1 | 12826.2 | 11.32% |
| Phase 2 throughput (ex/sec) | 2150.0 | 2464.1 | 14.61% |
| Phase 1 time (hours) | 11.12 | 9.99 | 10.16% |
| Phase 2 time (hours) | 6.62 | 5.77 | 12.84% |
| Total time (hours) | 17.74 | 15.76 | 11.16% |

PyTorch with ORT can train with 2x the local batch size of PyTorch without ORT (the global batch size was kept the same for comparison).

Perf Improvement with ORT

| Model (scenario) | # Params | Perf improvement w/ ORT |
|---|---|---|
| Turing* (pretraining) | 340M | 1.4x |
| Turing* (pretraining) | 350M | 1.2x |
| RoBERTa XL (pretraining) | 500M | 3x |
| RoBERTa XL (finetuning) | 500M | 1.2x |
| RoBERTa XXL (pretraining) | 1B | 7x |
| GPT-2 M (pretraining) | 345M | 1.2x |

* https://msturing.org/
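The % Gain column in the BERT table appears to use two conventions: throughput rows read as a relative increase (with ORT ÷ baseline − 1), while time rows read as the fraction of baseline time saved ((baseline − with ORT) ÷ baseline). That reading is inferred from the numbers rather than stated on the slide; the short check below recomputes every entry from the table values. The function and variable names are hypothetical.

```python
def throughput_gain(baseline: float, with_ort: float) -> float:
    """Percent increase in examples/sec over the baseline."""
    return (with_ort / baseline - 1.0) * 100.0


def time_saving(baseline_hours: float, with_ort_hours: float) -> float:
    """Percent of the baseline wall-clock time that is saved."""
    return (baseline_hours - with_ort_hours) / baseline_hours * 100.0


# (metric, PyTorch 1.5 with NGC 20.03-py3, PyTorch 1.5 with ONNX Runtime), from the table
THROUGHPUT = [("Phase 1 throughput", 11522.1, 12826.2),
              ("Phase 2 throughput", 2150.0, 2464.1)]
TIME = [("Phase 1 time", 11.12, 9.99),
        ("Phase 2 time", 6.62, 5.77),
        ("Total time", 17.74, 15.76)]

if __name__ == "__main__":
    for name, base, ort in THROUGHPUT:
        print(f"{name}: {throughput_gain(base, ort):.2f}%")
    for name, base, ort in TIME:
        print(f"{name}: {time_saving(base, ort):.2f}%")
```

Both formulas reproduce the table's % Gain column to two decimals. Note that a time saving is not the same number as a speedup: saving 10.16% of the Phase 1 time corresponds to a speedup of 11.12 / 9.99 ≈ 1.113x.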
Demo: ONNX Runtime Training in Azure Databricks
https://github.com/skaarthik/onnxruntime-training-databricks

Summary
▪ Optimize and accelerate model training using ONNX Runtime (ORT)
▪ ORT is used to train very large models that power various Microsoft products and services
▪ https://github.com/microsoft/onnxruntime
▪ https://github.com/microsoft/onnxruntime-training-examples

Feedback
Your feedback is important to us. Don't forget to rate and review the sessions.