Kotlin∇: A shape-safe DSL for differentiable programming
Breandan Considine (McGill University), Michalis Famelis (Université de Montréal), Liam Paull (Université de Montréal)
Abstract
Kotlin is a statically-typed programming language with support for embedded domain-specific languages, asynchronous programming, and multi-platform compilation. In this work, we present an algebraically-based implementation of automatic differentiation (AD) with shape-safe tensor operations, written in pure Kotlin. Our approach differs from existing AD frameworks in that Kotlin∇ is the first shape-safe AD library fully compatible with the Java type system, requiring no metaprogramming, reflection or compiler intervention to use. A working prototype is available: https://github.com/breandan/kotlingrad.
1 Introduction
Many existing AD frameworks are implemented in dynamically-typed languages, like Python. Some frameworks are written in statically-typed languages, but only consider primitive data types, and do not attempt to verify the shape of multidimensional arrays. Those that do either rely on dynamic type checking or are written in relatively esoteric languages like Haskell (Piñeyro et al., 2019). In our work, we demonstrate a shape-safe AD framework which supports static type checking and inference on array programs in a widely-used programming language called Kotlin.
Differentiable programming has a rich history among dynamic languages like Python, Lua and JavaScript, with early implementations including projects like Theano (Bergstra et al., 2010), Torch (Collobert et al., 2002), and TensorFlow (Abadi et al., 2016). Similar ideas have arisen in statically-typed, functional languages, such as Haskell’s Stalin∇ (Pearlmutter & Siskind, 2008b), DiffSharp in F# (Baydin et al., 2015) and recently Swift (Lattner & Wei, 2018). However, the majority of existing AD libraries have a loosely or dynamically typed DSL, and few support shape-safe array programming in a widely-adopted programming language.
To our knowledge, Kotlin has no prior AD implementation. However, the language has several useful features for implementing a native AD framework. Kotlin∇ primarily relies on the following language features:
• Operator overloading and infix functions allow a concise notation for defining arithmetic operations on algebraic structures, i.e. groups, rings and fields (Niculescu, 2011).
• λ-functions support functional programming, following Pearlmutter & Siskind (2008a,b); Siskind & Pearlmutter (2008); Elliott (2009, 2018), et al.
• Extension functions support extending classes with new fields and methods which can be exposed to external callers without requiring sub-classing or inheritance.
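The three language features listed above can be seen together in a small example. The sketch below is only an illustration: the `Vec2` type, the `dot` function and `squaredNorm` are hypothetical names chosen here, not part of Kotlin∇.

```kotlin
// Hypothetical Vec2 type, for illustration only; it is not part of Kotlin∇.
data class Vec2(val x: Double, val y: Double) {
    // Operator overloading: `a + b` and `a * k` resolve to these member functions.
    operator fun plus(other: Vec2) = Vec2(x + other.x, y + other.y)
    operator fun times(k: Double) = Vec2(x * k, y * k)

    // Infix function: callable as `a dot b`.
    infix fun dot(other: Vec2): Double = x * other.x + y * other.y
}

// Extension function: lets callers write `k * v` without subclassing Double.
operator fun Double.times(v: Vec2): Vec2 = v * this

// λ-function: a first-class function value built from the operations above.
val squaredNorm: (Vec2) -> Double = { v -> v dot v }

fun main() {
    val a = Vec2(1.0, 2.0)
    val b = Vec2(3.0, 4.0)
    println(a + b)           // Vec2(x=4.0, y=6.0)
    println(2.0 * a)         // Vec2(x=2.0, y=4.0)
    println(a dot b)         // 11.0
    println(squaredNorm(b))  // 25.0
}
```

Because `Double.times(Vec2)` is declared as an extension, scalar multiplication reads the same on either side of the operator, which is the kind of notational symmetry an algebraic DSL tends to need.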
Kotlin∇ is an embedded domain-specific language (eDSL). Embedded programs may appear structurally and behave semantically unlike native code, but are syntactically valid by definition. eDSLs are often used to implement declarative languages, such as SQL/LINQ (Meijer et al., 2006), OptiML (Sujeeth et al., 2011) and other fluent interfaces (Fowler, 2005). With a sufficiently expressive host language, one can implement any other language as a library, without needing to write a lexer, parser, compiler or interpreter. With proper type constraints, users will receive code completion and static analysis from their favorite development tools, with no further effort required.
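To make the point about type constraints concrete, the following minimal sketch shows how an ordinary library can use the host type system to reject ill-shaped programs at compile time. It assumes a phantom type parameter that tags each vector with its length. The names `Dim`, `D2`, `D3` and `Vec` are hypothetical and should not be read as the Kotlin∇ API.

```kotlin
// Hypothetical sketch; Dim, D2, D3 and Vec are assumed names, not the Kotlin∇ API.

// Type-level dimension tags. They carry no data and exist only for the type checker.
sealed class Dim
object D2 : Dim()
object D3 : Dim()

// A vector whose length is tracked by the phantom type parameter D.
class Vec<D : Dim> private constructor(val values: List<Double>) {
    // Addition only accepts a vector with the same dimension tag.
    operator fun plus(other: Vec<D>): Vec<D> =
        Vec<D>(values.zip(other.values) { a, b -> a + b })

    // The dot product is likewise restricted to vectors of equal length.
    infix fun dot(other: Vec<D>): Double =
        values.zip(other.values) { a, b -> a * b }.sum()

    companion object {
        // The factories are the only way to build a Vec, so tag and length always agree.
        fun of(x: Double, y: Double): Vec<D2> = Vec<D2>(listOf(x, y))
        fun of(x: Double, y: Double, z: Double): Vec<D3> = Vec<D3>(listOf(x, y, z))
    }
}

fun main() {
    val u = Vec.of(1.0, 2.0)
    val v = Vec.of(3.0, 4.0)
    val w = Vec.of(1.0, 2.0, 3.0)

    println((u + v).values)  // [4.0, 6.0]
    println(u dot v)         // 11.0
    println(w.values.size)   // 3
    // u + w                 // does not compile: Vec<D3> is not a Vec<D2>
}
```

Since the mismatch is an ordinary type error, an IDE would be expected to flag the commented-out line while the user is typing, with no plugin or compiler extension involved.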
Submitted to Program Transformations for Machine Learning workshop at NeurIPS 2019, Vancouver, Canada.
Figure 1: Adapted from van Merriënboer et al. (2018). Kotlin∇ models are data structures, constructed by an eDSL. These are compiled into dataflow graphs at runtime, which are eagerly optimized and lazily evaluated.
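The idea of a model as a data structure that is simplified eagerly and evaluated lazily can be sketched with a small expression tree. This is a simplified symbolic-differentiation example under assumed names (`Expr`, `Const`, `Var`, `Sum`, `Prod`). It is not the Kotlin∇ implementation and supports only sums and products.

```kotlin
// Hypothetical sketch; Expr, Const, Var, Sum and Prod are illustrative names, not Kotlin∇ classes.
sealed class Expr {
    // Eager optimization: the operators simplify while the graph is being built.
    operator fun plus(other: Expr): Expr = when {
        this == Const(0.0) -> other
        other == Const(0.0) -> this
        this is Const && other is Const -> Const(value + other.value)
        else -> Sum(this, other)
    }

    operator fun times(other: Expr): Expr = when {
        this == Const(0.0) || other == Const(0.0) -> Const(0.0)
        this == Const(1.0) -> other
        other == Const(1.0) -> this
        this is Const && other is Const -> Const(value * other.value)
        else -> Prod(this, other)
    }

    // Lazy evaluation: nothing is computed until variable bindings are supplied.
    fun eval(env: Map<String, Double>): Double = when (this) {
        is Const -> value
        is Var -> env.getValue(name)
        is Sum -> left.eval(env) + right.eval(env)
        is Prod -> left.eval(env) * right.eval(env)
    }

    // Symbolic derivative with respect to a variable; the result is another Expr.
    fun d(wrt: Var): Expr = when (this) {
        is Const -> Const(0.0)
        is Var -> if (this == wrt) Const(1.0) else Const(0.0)
        is Sum -> left.d(wrt) + right.d(wrt)
        is Prod -> left.d(wrt) * right + left * right.d(wrt)
    }
}

data class Const(val value: Double) : Expr()
data class Var(val name: String) : Expr()
data class Sum(val left: Expr, val right: Expr) : Expr()
data class Prod(val left: Expr, val right: Expr) : Expr()

fun main() {
    val x = Var("x")
    val y = Var("y")
    val z = x * x + y * Const(3.0)          // builds a graph; no arithmetic happens yet
    val dzdx = z.d(x)                       // zero and one terms are removed during construction
    println(dzdx)                           // Sum(left=Var(name=x), right=Var(name=x))
    println(dzdx.eval(mapOf("x" to 2.0)))   // 4.0
}
```

Placing the simplification rules inside the operators keeps the derivative graph small without a separate optimization pass, at the cost of fixing the rewrite rules at construction time.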
2 Usage
Kotlin∇ allows users to implement differentiable programs by composing simple functions to form more complex ones. Operations on functions with an incompatible output shape will fail to compile. Valid expressions are lazily evaluated inside a type-safe numerical context at runtime.
with(DoublePrecision) {                       // Use double-precision numerics
  val x = variable("x")                       // Declare immutable input variables
  val y = variable("y")                       // (these are just symbolic placeholders)
  val z = sin(10 * (x * x + pow(y, 2))) / 10  // Lazy expression
  val dz_dx = d(z) / d(x)                     // Leibniz derivative notation
  val d2z_dxdy = d(dz_dx) / d(y)              // Mixing higher-order partials
  val d3z_d2xdy = grad(d2z_dxdy)[x]           // Gradient indexing operator
  plot3D(d3z_d2xdy, -1.0, 1.0)                // Plot in -1