Front-End Supports in FlexFlow: Python, TensorFlow Keras, PyTorch, ONNX
Wei Wu and Mandeep Baines

Overview of FlexFlow’s Structure
Software stack, top to bottom:
• Front ends: Keras, PyTorch, ONNX
• Native Python API
• C++ API, alongside the Mapper/Parallelizer API
• FlexFlow Runtime
• Legion Runtime (https://legion.stanford.edu)

Outline
• Overview of Python Interface
• Native Python/C++ APIs
• Keras Support
• PyTorch Support
• ONNX Support

Overview of Python Integration
• CFFI (C Foreign Function Interface)
• Minimal overhead
• Thin layer that interacts directly with the C++ API
• Supports interaction with NumPy arrays
• Uses the Legion built-in Python interpreter
(Diagram: Keras, PyTorch, ONNX → Native Python API → CFFI → C++ API)

Overview of Python Interface (con’t)
Note: this is not the default Python interpreter. The “flexflow_python” binary contains:
• FlexFlow
• The Legion runtime
• The Legion Python interpreter

Native Python/C++ APIs
• Import the Python module
• Top-level task: execution starts here
• Main function

Native Python/C++ APIs – Model Creation
Configurations for running the model are set from the command line (batch size, # GPUs per node, # nodes, …)
• Create a FlexFlow model
• Create the input tensor (NCHW format)
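NCHW orders a 4-D tensor's dimensions as batch (N), channels (C), height (H), width (W), with width varying fastest in memory. A minimal sketch, assuming a dense row-major layout; the helper `nchw_offset` is hypothetical and not part of FlexFlow:

```python
from typing import Tuple


def nchw_offset(shape: Tuple[int, int, int, int], n: int, c: int, h: int, w: int) -> int:
    """Flat offset of element (n, c, h, w) in a dense, row-major NCHW tensor (hypothetical helper)."""
    N, C, H, W = shape
    if not (0 <= n < N and 0 <= c < C and 0 <= h < H and 0 <= w < W):
        raise IndexError("index out of range for shape")
    # W is the fastest-varying dimension, then H, then C, then N.
    return ((n * C + c) * H + h) * W + w
```

For a tensor of shape (2, 3, 4, 5), one step along W moves the offset by 1, along H by 5, along C by 20 (H × W), and along N by 60 (C × H × W); the last element, (1, 2, 3, 4), sits at offset 119.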
Then add operators to the model; each operator takes an input tensor and returns an output tensor.

Native Python/C++ APIs – Model Initialization
• Compile the model (lazy initialization)
• Create data loaders from NumPy arrays
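A data loader's core job is to slice aligned input and label arrays into minibatches. A minimal sketch, assuming in-memory samples; `MinibatchLoader` is hypothetical and not FlexFlow's loader class, and plain Python lists stand in for NumPy arrays to keep it dependency-free (slicing along a NumPy array's first axis behaves the same way):

```python
from typing import Iterator, Sequence, Tuple


class MinibatchLoader:
    """Hypothetical loader: yields aligned (inputs, labels) minibatches of a fixed size."""

    def __init__(self, inputs: Sequence, labels: Sequence, batch_size: int):
        if len(inputs) != len(labels):
            raise ValueError("inputs and labels must have the same number of samples")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.inputs = inputs
        self.labels = labels
        self.batch_size = batch_size

    def num_batches(self) -> int:
        # Drop the trailing partial batch so every batch has exactly batch_size samples.
        return len(self.inputs) // self.batch_size

    def __iter__(self) -> Iterator[Tuple[list, list]]:
        for b in range(self.num_batches()):
            start = b * self.batch_size
            end = start + self.batch_size
            yield list(self.inputs[start:end]), list(self.labels[start:end])
```

Dropping the trailing partial batch is one design choice that suits a model configured for a fixed batch size; FlexFlow's own loaders may handle the remainder differently.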
Finally, initialize the model.

Native Python/C++ APIs – Train the Model
The simplest way to train the model is the fit function.
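Conceptually, a fit-style function repeats the same steps for every minibatch of every epoch: forward pass, reset gradients, backward pass, parameter update. A minimal sketch, assuming a toy one-variable linear model trained with SGD on mean squared error; `ToyLinearModel`, `fit`, and the method names are hypothetical and only illustrate the loop structure, not FlexFlow's implementation:

```python
from typing import List, Sequence, Tuple


class ToyLinearModel:
    """Hypothetical stand-in for a compiled model: y = a * x + b, trained with SGD on MSE."""

    def __init__(self, lr: float):
        self.a = 0.0
        self.b = 0.0
        self.lr = lr
        self.grad_a = 0.0
        self.grad_b = 0.0

    def forward(self, xs: List[float]) -> List[float]:
        return [self.a * x + self.b for x in xs]

    def zero_gradients(self) -> None:
        self.grad_a = 0.0
        self.grad_b = 0.0

    def backward(self, xs: List[float], ys: List[float], preds: List[float]) -> None:
        # Accumulate the gradient of the mean squared error w.r.t. a and b.
        n = len(xs)
        for x, y, p in zip(xs, ys, preds):
            err = p - y
            self.grad_a += 2.0 * err * x / n
            self.grad_b += 2.0 * err / n

    def update(self) -> None:
        # Plain SGD step.
        self.a -= self.lr * self.grad_a
        self.b -= self.lr * self.grad_b


def fit(model: ToyLinearModel, batches: Sequence[Tuple[List[float], List[float]]], epochs: int) -> None:
    """Hypothetical fit: the same four steps for every minibatch of every epoch."""
    for _ in range(epochs):
        for xs, ys in batches:
            preds = model.forward(xs)
            model.zero_gradients()
            model.backward(xs, ys, preds)
            model.update()
```

On data drawn from y = 2x + 1, for example the minibatches ([0, 1], [1, 3]) and ([2, 3], [5, 7]) with lr = 0.05 for 1000 epochs, the parameters converge to a ≈ 2 and b ≈ 1.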
Alternatively, implement a customized training procedure.

Native Python API vs C++ API
(Side-by-side code panels: Python API and C++ API)

Status of Native Python/C++ APIs
• Operators supported:
  • Add, Subtract, Multiply, Divide
  • Exp, ReLU, Sigmoid, Tanh, ELU
  • Conv2D, Pool2D
  • Flat
  • Dense
  • Embedding
  • Batch_Norm, Batch_Matmul
  • Concat, Split
  • Reshape, Transpose
  • Dropout
  • Softmax

Keras Support
No changes to the model are needed!

Challenges for Supporting PyTorch in FlexFlow
PyTorch allows users to construct DNN models dynamically, so the operators executed can change from one input to the next.
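A toy illustration in plain Python (no PyTorch; every name here is hypothetical) of why a dynamically constructed model has no single fixed graph: the operators that run depend on the input values, so two inputs can produce two different operator sequences.

```python
import math
from typing import List


def toy_dynamic_model(x: List[float], trace: List[str]) -> List[float]:
    """Hypothetical define-by-run model; appends the name of each operator it runs to trace."""
    trace.append("dense")
    h = [2.0 * v for v in x]  # stand-in for a Dense layer with a fixed weight of 2.0
    if sum(h) > 0:  # control flow that depends on the input values
        trace.append("relu")
        h = [max(0.0, v) for v in h]
    else:
        trace.append("tanh")
        h = [math.tanh(v) for v in h]
    return h


def operators_run(x: List[float]) -> List[str]:
    """Return the operator sequence toy_dynamic_model executes for input x."""
    trace: List[str] = []
    toy_dynamic_model(x, trace)
    return trace
```

Here `operators_run([1.0, 2.0])` yields `["dense", "relu"]` while `operators_run([-1.0, -2.0])` yields `["dense", "tanh"]`. A static parallelization optimizer needs one graph, and data-dependent branches like this one are also what tracing tools such as torch.fx's symbolic tracing cannot capture as a single graph.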
FlexFlow, however, optimizes DNN parallelization statically and therefore requires a fixed DNN model.

PyTorch Support
Pipeline: PyTorch model → (torch.fx) → FlexFlow graph representation → FlexFlow model

PyTorch Support (con’t)
• Define the PyTorch model
• Choose the output file
• Convert the model to the graph representation and write it to the output file

PyTorch Support (con’t)
• Create the input tensor
• Create the model from the graph representation
• Compile the model (lazy initialization) …
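The file-based hand-off in these steps can be sketched as follows: the PyTorch side writes the graph representation to a file, and the FlexFlow side reads it back and adds one operator per node to the model. This is a minimal sketch, assuming a made-up one-node-per-line text format; it does not reproduce FlexFlow's actual file format. `write_graph`, `read_graph`, `RecordingModel`, and `build_model` are all hypothetical, and `RecordingModel` only records operator names where a real model would add operators.

```python
import io
from typing import Callable, Dict, List, TextIO, Tuple

# Hypothetical format: one node per line, "name; op; comma-separated input names".
Node = Tuple[str, str, List[str]]


def write_graph(nodes: List[Node], out: TextIO) -> None:
    for name, op, inputs in nodes:
        out.write(f"{name}; {op}; {','.join(inputs)}\n")


def read_graph(src: TextIO) -> List[Node]:
    nodes: List[Node] = []
    for line in src:
        line = line.strip()
        if not line:
            continue
        name, op, inputs = (field.strip() for field in line.split(";"))
        nodes.append((name, op, [i for i in inputs.split(",") if i]))
    return nodes


class RecordingModel:
    """Hypothetical stand-in for a FlexFlow model: records each operator it is asked to add."""

    def __init__(self) -> None:
        self.layers: List[str] = []

    def dense(self, inp: str) -> str:
        self.layers.append("dense")
        return f"dense({inp})"

    def relu(self, inp: str) -> str:
        self.layers.append("relu")
        return f"relu({inp})"


def build_model(nodes: List[Node], model: RecordingModel, input_tensor: str) -> str:
    """Replay nodes in file order; this sketch only handles single-input operators."""
    builders: Dict[str, Callable[[str], str]] = {"dense": model.dense, "relu": model.relu}
    tensors: Dict[str, str] = {}
    for name, op, inputs in nodes:
        if op == "input":
            tensors[name] = input_tensor
        elif op == "output":
            return tensors[inputs[0]]
        else:
            tensors[name] = builders[op](tensors[inputs[0]])
    raise ValueError("graph has no output node")
```

For example, writing the nodes `x` (input), `fc1` (dense of `x`), `act` (relu of `fc1`), and `y` (output of `act`) to an `io.StringIO` and reading them back returns the same node list; building from it adds a dense and then a relu operator and returns `"relu(dense(input_tensor))"`. Keeping the two sides connected only through a file means the PyTorch conversion and the FlexFlow model construction can run as separate programs.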