Front-End Supports in FlexFlow: Python, TensorFlow, PyTorch, ONNX

Wei Wu and Mandeep Baines

Overview of FlexFlow’s Structure

[Diagram: layers of the stack, top to bottom]
• Keras, PyTorch, ONNX
• Native Python API
• C++ API; Mapper/Parallelizer API
• FlexFlow Runtime
• Legion Runtime (https://legion.stanford.edu)

Outline

• Overview of Python Interface
• Native Python/C++ APIs
• Keras Support
• PyTorch Support
• ONNX Support

Overview of Python Integration

• CFFI (C Foreign Function Interface)
• Minimal overhead
• Thin Native Python API
• Interacts directly with the C++ API
• Supports interaction with NumPy arrays
• Uses the Legion built-in Python interpreter

[Diagram: Keras, PyTorch, ONNX → Native Python API → CFFI → C++ API]

Overview of Python Interface (cont’d)

Not the default Python interpreter

The “flexflow_python” binary contains:
• FlexFlow
• The Legion runtime
• The Legion Python interpreter

Native Python/C++ APIs

import python module

top level task: code starts from here
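The code that these labels annotated is missing from the extracted slide. A minimal sketch of the layout they describe (module import, top-level task, main function) might look as follows. It assumes the package is importable as `flexflow.core` under `flexflow_python`; the body of `top_level_task` is a placeholder, not the original slide code.

```python
# Layout sketch of a FlexFlow Python script (placeholder body).
try:
    # Assumption: this import succeeds only where FlexFlow is available,
    # for example when the script is run with the flexflow_python binary.
    import flexflow.core as ff
except ImportError:
    ff = None  # lets the skeleton load under the default interpreter too


def top_level_task():
    """Top-level task: model code starts from here."""
    # Placeholder: a real script would build, compile, and train a model here.
    return "top-level task ran (FlexFlow available: {})".format(ff is not None)


if __name__ == "__main__":
    # Main function: hand control to the top-level task.
    print(top_level_task())
```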

main function

Native Python/C++ APIs – Model Creation

configurations for running the model, taken from the command line (batch size, # GPUs per node, # nodes, …)

create a FlexFlow model

NCHW format

create input tensor

[Code labels: output, input]
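As a worked illustration of the NCHW layout, the sketch below builds the dimension list for an input tensor and computes how a convolution changes its shape, using the standard output-size formula floor((size + 2·padding − kernel) / stride) + 1. The helper names and the example numbers are assumptions chosen for illustration; they are not taken from the slides or from the FlexFlow API.

```python
def nchw_dims(batch_size, channels, height, width):
    """Return tensor dimensions in NCHW order: batch, channels, height, width."""
    return [batch_size, channels, height, width]


def window_output_size(size, kernel, stride, padding):
    """Output size along one spatial axis for a conv or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_output_dims(input_dims, out_channels, kernel, stride, padding):
    """Output dimensions (NCHW) of a 2D convolution with a square kernel."""
    n, _, h, w = input_dims
    return [
        n,
        out_channels,
        window_output_size(h, kernel, stride, padding),
        window_output_size(w, kernel, stride, padding),
    ]


if __name__ == "__main__":
    x = nchw_dims(batch_size=64, channels=3, height=32, width=32)
    y = conv2d_output_dims(x, out_channels=32, kernel=3, stride=1, padding=1)
    print(x, "->", y)  # [64, 3, 32, 32] -> [64, 32, 32, 32]
```

A 2×2 pooling window with stride 2 and no padding would halve each spatial axis: `window_output_size(32, 2, 2, 0)` gives 16.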

Add operators to the model

Native Python/C++ APIs – Model Initialization

compile the model (lazy initialization)

numpy arrays

create data loaders
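To make the role of a data loader concrete, here is a toy sketch that serves fixed-size batches from an in-memory dataset. It uses plain Python lists in place of NumPy arrays so that it has no dependencies. The class and method names are hypothetical and are not FlexFlow's API, and dropping the final partial batch is a simplifying assumption.

```python
class ToyDataLoader:
    """Hypothetical stand-in for a data loader: serves fixed-size batches."""

    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size
        # Only full batches are served, so every batch has the same size.
        self.num_batches = len(samples) // batch_size
        self.next_index = 0

    def reset(self):
        """Rewind to the first batch (call at the start of each epoch)."""
        self.next_index = 0

    def next_batch(self):
        """Return the next full batch and advance the cursor."""
        if self.next_index >= self.num_batches:
            raise IndexError("epoch finished; call reset() first")
        start = self.next_index * self.batch_size
        self.next_index += 1
        return self.samples[start:start + self.batch_size]


if __name__ == "__main__":
    loader = ToyDataLoader(list(range(10)), batch_size=4)
    for _ in range(loader.num_batches):
        print(loader.next_batch())
```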

Initialize the model

Native Python/C++ APIs – Train the Model

• Use the fit function

• Implement a customized training procedure

Native Python API vs C++ API

Python API

C++ API

Status of Native Python/C++ APIs

• Operators supported
  • Add, Subtract, Multiply, Divide
  • Exp, ReLU, Sigmoid, Tanh, ELU
  • Conv2D, Pool2D
  • Flat
  • Dense
  • Embedding
  • Batch_Norm, Batch_Matmul
  • Concat, Split
  • Reshape, Transpose
  • Dropout
  • Softmax

Keras Support

No changes to the model!

Challenges for Supporting PyTorch in FlexFlow

• PyTorch allows users to dynamically construct DNN models

• But FlexFlow optimizes DNN parallelization statically and requires a fixed DNN model

PyTorch Support

[Diagram: PyTorch Model → (.fx) → FlexFlow Graph Representation → FlexFlow Model]

PyTorch Support (cont’d)

PyTorch model

output file

Convert to graph representation

PyTorch Support (cont’d)

create input tensor

create the model from the graph representation

compile the model (lazy initialization) …
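The PyTorch flow described above (run a dynamically written model once, save a fixed graph to a file, rebuild a model from that file) can be illustrated with a toy tracer. This is only a sketch of the general idea: it is not torch.fx, and the `name;op;inputs` line format and every class and function name in it are hypothetical, not FlexFlow's graph representation.

```python
import operator
import os
import tempfile

OPS = {"add": operator.add, "mul": operator.mul}


class Graph:
    """Static graph: an ordered list of (output_name, op, input_names)."""

    def __init__(self):
        self.nodes = []

    def add_node(self, op, inputs):
        name = "t{}".format(len(self.nodes))
        self.nodes.append((name, op, inputs))
        return Proxy(self, name)

    def to_lines(self):
        return ["{};{};{}".format(n, op, ",".join(ins)) for n, op, ins in self.nodes]

    @classmethod
    def from_lines(cls, lines):
        graph = cls()
        for line in lines:
            name, op, ins = line.split(";")
            graph.nodes.append((name, op, ins.split(",")))
        return graph


class Proxy:
    """Symbolic value: operations on it are recorded instead of computed."""

    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def _record(self, op, other):
        other_name = other.name if isinstance(other, Proxy) else repr(other)
        return self.graph.add_node(op, [self.name, other_name])

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def trace(fn, input_name="x"):
    """Run fn once on a Proxy so that its operations land in a Graph."""
    graph = Graph()
    fn(Proxy(graph, input_name))
    return graph


def run(graph, inputs):
    """Evaluate a graph; names not found in env are numeric constants."""
    env = dict(inputs)
    for name, op, ins in graph.nodes:
        args = [env[i] if i in env else float(i) for i in ins]
        env[name] = OPS[op](*args)
    return env[graph.nodes[-1][0]]


def model(x):
    # Ordinary dynamic Python: the loop is unrolled while tracing.
    for _ in range(2):
        x = x * 2 + 1
    return x


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "model_graph.txt")
    with open(path, "w") as f:                     # output file
        f.write("\n".join(trace(model).to_lines()) + "\n")
    with open(path) as f:                          # rebuild from the file
        rebuilt = Graph.from_lines(line.strip() for line in f if line.strip())
    print(run(rebuilt, {"x": 3}), model(3))
```

The loop in `model` disappears from the saved graph because it is executed at trace time; only the four recorded operations remain, which is what makes the result a fixed graph.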