Published as a conference paper at ICLR 2020

NEURAL TANGENTS: FAST AND EASY INFINITE NEURAL NETWORKS IN PYTHON

Roman Novak,∗ Lechao Xiao,∗ Jiri Hron†, Jaehoon Lee,

Alexander A. Alemi, Jascha Sohl-Dickstein, Samuel S. Schoenholz∗

Google Brain, †University of Cambridge
{romann, xlc}@google.com, [email protected], {jaehlee, alemi, jaschasd, schsam}@google.com

ABSTRACT

NEURAL TANGENTS is a library for working with infinite-width neural networks. It provides a high-level API for specifying complex and hierarchical neural network architectures. These networks can then be trained and evaluated either at finite width as usual or in their infinite-width limit. Infinite-width networks can be trained analytically using exact Bayesian inference or using gradient descent via the Neural Tangent Kernel. Additionally, NEURAL TANGENTS provides tools to study gradient descent training dynamics of wide but finite networks in either function space or weight space. The entire library runs out-of-the-box on CPU, GPU, or TPU. All computations can be automatically distributed over multiple accelerators with near-linear scaling in the number of devices. NEURAL TANGENTS is available at www.github.com/google/neural-tangents.

We also provide an accompanying interactive Colab notebook1.

1 INTRODUCTION

Deep neural networks (DNNs) owe their success in part to the broad availability of high-level, flexible, and efficient software libraries like Tensorflow (Abadi et al., 2015), Keras (Chollet et al., 2015), PyTorch.nn (Paszke et al., 2017), Chainer (Tokui et al., 2015; Akiba et al., 2017), JAX (Bradbury et al., 2018a), and others. These libraries enable researchers to rapidly build complex models by constructing them out of smaller primitives. The success of new machine learning approaches will similarly depend on developing sophisticated software tools to support them.

1.1 INFINITE-WIDTH BAYESIAN NEURAL NETWORKS

Recently, a new class of machine learning models has attracted significant attention, namely, deep infinitely wide neural networks. In the infinite-width limit, a large class of Bayesian neural networks become Gaussian Processes (GPs) with a specific, architecture-dependent, compositional kernel; these models are called Neural Network Gaussian Processes (NNGPs). This correspondence was first established for shallow fully-connected networks by Neal (1994) and was extended to the multi-layer setting in (Lee et al., 2018; Matthews et al., 2018b). Since then, this correspondence has been expanded to a wide range of nonlinearities (Matthews et al., 2018a; Novak et al., 2019) and architectures including those with convolutional layers (Garriga-Alonso et al., 2019; Novak et al., 2019), residual connections (Garriga-Alonso et al., 2019), pooling (Novak et al., 2019), as well as graph neural networks (Du et al., 2019). The results for individual architectures have subsequently been generalized, and it was shown that a GP correspondence holds for a general class of networks that can be mapped to so-called tensor programs in (Yang, 2019). The recurrence relationship defining

∗Equal contribution. †Work done during an internship at Google Brain. 1 colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb


the NNGP kernel has additionally been extensively studied in the context of mean field theory and initialization (Cho & Saul, 2009; Daniely et al., 2016; Poole et al., 2016; Schoenholz et al., 2016; Yang & Schoenholz, 2017; Xiao et al., 2018; Li & Nguyen, 2019; Pretorius et al., 2018; Hayou et al., 2018; Karakida et al., 2018; Blumenfeld et al., 2019; Hayou et al., 2019).

1.2 INFINITE-WIDTH NEURAL NETWORKS TRAINED BY GRADIENT DESCENT

In addition to enabling a closed form description of Bayesian neural networks, the infinite-width limit has also very recently provided insights into neural networks trained by gradient descent. In the last year, several papers have shown that randomly initialized neural networks trained with gradient descent are characterized by a distribution that is related to the NNGP, and is described by the so-called Neural Tangent Kernel (NTK) (Jacot et al., 2018; Lee et al., 2019; Chizat et al., 2019), a kernel which was implicit in some earlier papers (Li & Liang, 2018; Allen-Zhu et al., 2018; Du et al., 2018a;b; Zou et al., 2019). In addition to this "function space" perspective, a dual, "weight space" view on the wide network limit was proposed in Lee et al. (2019), which showed that networks under gradient descent were well-described by the first-order Taylor series about their initial parameters.
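To make this weight-space picture concrete, the following is a minimal sketch (ours, in plain JAX rather than any particular library API) of the first-order Taylor expansion of a network function f(params, x) about initial parameters params0:

import jax

def linearized(f, params0):
  # Return the first-order Taylor expansion of f(params, x) around params0.
  def f_lin(params, x):
    # delta = params - params0, taken leaf-wise over the parameter pytree.
    delta = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
    # jax.jvp returns (f(params0, x), J(params0) @ delta) in a single pass.
    y0, tangent = jax.jvp(lambda p: f(p, x), (params0,), (delta,))
    return y0 + tangent
  return f_lin

In the wide-network regime studied by Lee et al. (2019), training the original network and training this linearized function produce very similar predictions.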

1.3 PROMISE AND PRACTICAL BARRIERS TO WORKING WITH INFINITE-WIDTH NETWORKS

Combined, these discoveries established infinite-width networks as useful theoretical tools to understand a wide range of phenomena in deep learning. Furthermore, the practical utility of these models has been proven by achieving state-of-the-art performance on image classification benchmarks among GPs without trainable kernels (Garriga-Alonso et al., 2019; Novak et al., 2019; Arora et al., 2019a), and by their ability to match or exceed the performance of finite-width networks in some situations, especially for fully- and locally-connected model families (Lee et al., 2018; Novak et al., 2019; Arora et al., 2019b). However, despite their utility, using NNGPs and NTK-GPs is arduous and can require weeks to months of work by seasoned practitioners. Kernels corresponding to neural networks must be derived by hand on a per-architecture basis. Overall, this process is laborious and error prone, and is reminiscent of the state of neural networks before high quality Automatic Differentiation (AD) packages proliferated.

1.4 SUMMARY OF CONTRIBUTIONS

In this paper, we introduce a new open-source software library called NEURAL TANGENTS targeting JAX (Bradbury et al., 2018a) to accelerate research on infinite limits of neural networks. The main features of NEURAL TANGENTS are:2

• A high-level neural network API for specifying complex, hierarchical models. Networks specified using this API can have their infinite-width NNGP kernel and NTK evaluated analytically (§2.1, Listings 1, 2, 3, §B.2).

• Functions to approximate infinite-width kernels by Monte Carlo sampling for networks whose kernels cannot be constructed analytically. These methods are agnostic to the neural network library used to build the network and are therefore quite versatile (§2.2, Figure 2, §B.5).

• An API to analytically perform inference using infinite-width networks either by computing the Bayesian posterior or by computing the result of continuous gradient descent with an MSE loss. The API additionally includes tools to perform inference by numerically solving the ODEs corresponding to continuous gradient descent, with or without momentum, on arbitrary loss functions, at finite or infinite time (§2.1, Figure 1, §B.4).

• Functions to compute arbitrary-order Taylor series approximations to neural networks about a given setting of parameters to explore the weight space perspective on the infinite-width limit (§B.6, Figure 6).

• Leveraging XLA, our library runs out-of-the-box on CPU, GPU, or TPU. Kernel computations can automatically be distributed over multiple accelerators with near-perfect scaling (§3.2, Figure 5, §B.3).

2 See §A for a comparison of NEURAL TANGENTS against specific relevant prior works.


We begin with three short examples (§2) that demonstrate the ease, efficiency, and versatility of performing calculations with infinite networks using NEURAL TANGENTS. With a high level view of the library in hand, we then dive into a number of technical aspects of our library (§3).

1.5 BACKGROUND

We briefly describe the NNGP (§1.1) and NTK (§1.2).

NNGP. Neural networks are often structured as affine transformations followed by pointwise applications of nonlinearities. Let $z_i^l(x)$ describe the $i$th pre-activation following a linear transformation in the $l$th layer of a neural network. At initialization, the parameters of the network are randomly distributed and so central-limit theorem style arguments can be used to show that the pre-activations become Gaussian distributed with mean zero and are therefore described entirely by their covariance matrix $\mathcal{K}(x, x') = \mathbb{E}[z_i^l(x)\, z_i^l(x')]$. This describes a NNGP with the kernel $\mathcal{K}(x, x')$. One can use the NNGP to make Bayesian posterior predictions at a test point, $x$, which are Gaussian distributed with mean $\mu(x) = \mathcal{K}(x, \mathcal{X})\, \mathcal{K}(\mathcal{X}, \mathcal{X})^{-1}\, \mathcal{Y}$ and variance $\sigma^2(x) = \mathcal{K}(x, x) - \mathcal{K}(x, \mathcal{X})\, \mathcal{K}(\mathcal{X}, \mathcal{X})^{-1}\, \mathcal{K}(\mathcal{X}, x)$, where $(\mathcal{X}, \mathcal{Y})$ is the training set of inputs and targets respectively.

NTK. When neural networks are optimized using continuous gradient descent with learning rate $\eta$ on mean squared error (MSE) loss, the function evaluated on training points evolves as $\partial_t f_t(\mathcal{X}) = -\eta\, J_t(\mathcal{X})\, J_t(\mathcal{X})^T \left( f_t(\mathcal{X}) - \mathcal{Y} \right)$, where $J_t(\mathcal{X})$ is the Jacobian of the output $f_t$ evaluated at $\mathcal{X}$ and $\Theta_t(\mathcal{X}, \mathcal{X}) = J_t(\mathcal{X})\, J_t(\mathcal{X})^T$ is the NTK. In the infinite-width limit, the NTK remains constant ($\Theta_t = \Theta$) throughout training and the time-evolution of the outputs can be solved in closed form as a Gaussian with mean $f_t(x) = \Theta(x, \mathcal{X})\, \Theta(\mathcal{X}, \mathcal{X})^{-1} \left( I - \exp\left[ -\eta\, \Theta(\mathcal{X}, \mathcal{X})\, t \right] \right) \mathcal{Y}$.
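For reference, here is a minimal NumPy sketch (ours, not the library's API) of these closed-form expressions, assuming the kernel matrices $\mathcal{K}(\mathcal{X},\mathcal{X})$, $\mathcal{K}(x,\mathcal{X})$, $\mathcal{K}(x,x)$, $\Theta(\mathcal{X},\mathcal{X})$, $\Theta(x,\mathcal{X})$ and the targets $\mathcal{Y}$ have already been computed:

import numpy as np

def nngp_posterior(k_test_train, k_train_train, k_test_test, y_train):
  # mu(x) = K(x, X) K(X, X)^{-1} Y
  mean = k_test_train @ np.linalg.solve(k_train_train, y_train)
  # sigma^2(x) = K(x, x) - K(x, X) K(X, X)^{-1} K(X, x)
  cov = k_test_test - k_test_train @ np.linalg.solve(k_train_train, k_test_train.T)
  return mean, cov

def ntk_mean_prediction(theta_test_train, theta_train_train, y_train, t, lr):
  # f_t(x) = Theta(x, X) Theta(X, X)^{-1} (I - exp[-lr * Theta(X, X) * t]) Y
  evals, evecs = np.linalg.eigh(theta_train_train)
  decay = evecs @ np.diag(1.0 - np.exp(-lr * evals * t)) @ evecs.T
  return theta_test_train @ np.linalg.solve(theta_train_train, decay @ y_train)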

2 EXAMPLES

We begin by applying NEURAL TANGENTS to several example tasks. While these tasks are designed for pedagogy rather than research novelty, they are nonetheless emblematic of problems regularly faced in research. We emphasize that without NEURAL TANGENTS, it would be necessary to derive the kernels for each architecture by hand.

2.1 INFERENCE WITH AN INFINITELY WIDE NEURAL NETWORK

We begin by training an infinitely wide neural network with gradient descent and comparing the result to training an ensemble of wide-but-finite networks. This example is worked through in detail in the Colab notebook.3

We train on a synthetic dataset with training data drawn from the process $y_i = \sin(x_i) + \epsilon_i$ with $x_i \sim \mathrm{Uniform}(-\pi, \pi)$ and $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$ independently and identically distributed. To train an infinite neural network with Erf activations4 on this data using gradient descent and an MSE loss we write the following:

from neural_tangents import predict, stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05))

y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, 'ntk',
                                     diag_reg=1e-4, compute_cov=True)
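The snippet above assumes x_train, y_train and x_test already exist; a minimal sketch (ours) of generating the synthetic dataset described earlier, with illustrative sizes and noise scale, could be:

import numpy as np

rng = np.random.RandomState(0)
sigma = 0.1  # illustrative noise scale
x_train = rng.uniform(-np.pi, np.pi, size=(20, 1))
y_train = np.sin(x_train) + sigma * rng.randn(20, 1)
x_test = np.linspace(-np.pi, np.pi, 100).reshape(-1, 1)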

The gp_inference call above analytically generates the predictions that would result from performing gradient descent for an infinite amount of time. However, it is often desirable to investigate finite-time learning dynamics of deep networks. This is also supported in NEURAL TANGENTS as illustrated in the following snippet:

3 www.colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb
4 Error function, a nonlinearity similar to tanh; see §D for other implemented nonlinearities, including Relu.


Figure 1: Training dynamics for an ensemble of finite-width networks compared with an infinite network. Left: Mean and variance of the train and test MSE loss evolution throughout training. Right: Comparison between the predictions of the trained infinite network and the respective ensemble of finite-width networks. The shaded region and the dashed lines denote two standard deviations of uncertainty in the predictions for the infinite network and the ensemble respectively.

predict_fn = predict.gradient_descent_mse_gp(kernel_fn, x_train, y_train, x_test,
                                             'ntk', diag_reg=1e-4, compute_cov=True)
y_mean, y_var = predict_fn(t=100)  # Predict the distribution at t = 100.
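The same predict_fn can be evaluated on a grid of training times to trace out dynamics curves like those in Figure 1; an illustrative usage sketch (ours):

import numpy as np

# Analytic mean and covariance of the infinite network's test predictions over time.
times = np.logspace(-2, 3, 20)  # illustrative time grid
dynamics = [predict_fn(t=t) for t in times]  # list of (y_mean, y_var) pairs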

The above specification sets the hidden layer widths to 2048, which has no effect on the infinite-width inference, but the init_fn and apply_fn here correspond to ordinary finite-width networks. In Figure 1 we compare the result of this exact inference with training an ensemble of one hundred of these finite-width networks by looking at the training curves and output predictions of both models. We see excellent agreement between exact inference using the infinite-width model and the result of training an ensemble using gradient descent.
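For completeness, a minimal sketch (ours) of training a single finite-width member of such an ensemble with full-batch gradient descent on the MSE loss, reusing the init_fn and apply_fn defined above (the learning rate and step count are illustrative):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
_, params = init_fn(key, x_train.shape)

def mse_loss(params, x, y):
  return 0.5 * jnp.mean((apply_fn(params, x) - y) ** 2)

grad_loss = jax.jit(jax.grad(mse_loss))
learning_rate = 0.1  # illustrative

for _ in range(1000):
  grads = grad_loss(params, x_train, y_train)
  params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

Repeating this over many random keys yields the ensemble whose predictions are compared against the infinite-width results in Figure 1.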

2.2 AN INFINITELY WIDE RESNET

The above example considers a relatively simple network on a synthetic task. In practice we may want to consider real-world architectures and see how close they are to their infinite-width limit. For this task we study a variant of an infinite-channel Wide Residual Network (Zagoruyko & Komodakis, 2016) (WRN-28-∞). We first define both finite and infinite models within Listing 1.

We now study how quickly the kernel of the finite-channel WideResNet approaches its infinite-channel limit. We explore two different axes along which convergence takes place: first, as a function of the number of channels (as measured by the widening factor, k) and second, as a function of the number of finite-network Monte Carlo samples we average over. NEURAL TANGENTS makes it easy to compute MC averages of finite kernels using the following snippet:

import neural_tangents as nt

kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, rng_key, n_samples)
sampled_kernel = kernel_fn(x, x)

The convergence is shown in Figure 2. We see that as the number of samples is increased or the network is made wider, the empirical kernel approaches the kernel of the infinite network. As noted in Novak et al. (2019), for any finite widening factor the MC estimate is biased. Here, however, the bias is small relative to the variance, and the distance between the empirical kernel and the exact kernel decreases with the number of samples.
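Concretely, the Monte Carlo estimator averages empirical kernels of randomly initialized finite networks: the output covariance for the NNGP and the Jacobian outer product for the NTK. A simplified sketch of this idea (ours; monte_carlo_kernel_fn is the efficient, general-purpose version) for a small fully-connected model:

import jax
import jax.numpy as jnp
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 3))  # 4 inputs with 3 features

def empirical_kernels(key):
  _, params = init_fn(key, x.shape)
  z = apply_fn(params, x)  # outputs, shape (4, 1)
  jac = jax.jacobian(lambda p: apply_fn(p, x))(params)  # pytree of Jacobians
  j = jnp.concatenate([leaf.reshape(x.shape[0], -1)
                       for leaf in jax.tree_util.tree_leaves(jac)], axis=1)
  return z @ z.T, j @ j.T  # one-sample NNGP and NTK estimates

samples = [empirical_kernels(k)
           for k in jax.random.split(jax.random.PRNGKey(0), 128)]
nngp_mc = jnp.mean(jnp.stack([s[0] for s in samples]), axis=0)
ntk_mc = jnp.mean(jnp.stack([s[1] for s in samples]), axis=0)

# Relative squared Frobenius distance to the analytic NNGP, as plotted in Figure 2.
nngp_exact = kernel_fn(x, x, 'nngp')
distance = jnp.sum((nngp_mc - nngp_exact) ** 2) / jnp.sum(nngp_exact ** 2)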

2.3 COMPARISON OF NEURAL NETWORK ARCHITECTURES AND TRAINING SET SIZES

The above examples demonstrate how one might construct a complicated architecture and perform inference using NEURAL TANGENTS. Next we train a range of architectures on CIFAR-10 and compare their performance as a function of dataset size. In particular, we compare a fully-connected network,


a convolutional network whose penultimate layer vectorizes the image, and the wide residual network described above. In each case, we perform exact infinite-time inference using the analytic infinite-width NNGP or NTK. For each architecture we perform a hyperparameter search over the depth of the network, selecting the depth that maximizes the marginal log likelihood on the training set.
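The selection criterion is the Gaussian marginal likelihood of the training targets under the model's kernel; a minimal NumPy sketch (ours) of the negative marginal log-likelihood, given a precomputed train-train kernel matrix and targets of shape (n, num_outputs), could be:

import numpy as np

def marginal_nll(k_train_train, y_train, diag_reg=1e-4):
  n, num_outputs = y_train.shape
  k = k_train_train + diag_reg * np.eye(n)  # diagonal regularizer for stability
  alpha = np.linalg.solve(k, y_train)
  _, logdet = np.linalg.slogdet(k)
  # Sum over independent output dimensions sharing the same kernel.
  return 0.5 * np.sum(y_train * alpha) + 0.5 * num_outputs * (logdet + n * np.log(2 * np.pi))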

def WideResNetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = (stax.Identity() if not channel_mismatch else
              stax.Conv(channels, (3, 3), strides, padding='SAME'))
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResNetGroup(n, channels, strides=(1, 1)):
  blocks = [WideResNetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResNetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResNet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResNetGroup(block_size, int(16 * k)),
      WideResNetGroup(block_size, int(32 * k), (2, 2)),
      WideResNetGroup(block_size, int(64 * k), (2, 2)),
      stax.GlobalAvgPool(),
      stax.Dense(num_classes))

init_fn, apply_fn, kernel_fn = WideResNet(block_size=4, k=1, num_classes=10)

Listing 1: Definition of an infinitely WideResNet. This snippet simultaneously defines a finite (init_fn, apply_fn) and an infinite (kernel_fn) model. This model is used in Figures 2 and 3.
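The kernel_fn returned in Listing 1 plugs directly into the inference API from §2.1; a brief usage sketch (ours), assuming CIFAR-10 batches x_train, y_train, x_test with images of shape (32, 32, 3) and zero-mean regression targets as in Figure 3:

from neural_tangents import predict

# NNGP (Bayesian) inference with the analytic WideResNet kernel; swap 'nngp' for 'ntk'
# to obtain the infinite-time gradient descent predictions instead.
y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, 'nngp',
                                     diag_reg=1e-4, compute_cov=True)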

[Figure 2 shows two panels (NNGP and NTK); horizontal axis: log2(#Samples), vertical axis: log2(widening factor), color: log10(Distance).]

Figure 2: Convergence of the Monte Carlo (MC) estimates of the WideResNet WRN-28-k (where k is the widening factor) NNGP and NTK kernels (computed with monte_carlo_kernel_fn) to their analytic values (WRN-28-∞, computed with kernel_fn), as the network gets wider by increasing the widening factor (vertical axis) and as more random networks are averaged over (horizontal axis). Experimental detail. The kernel is computed in 32-bit precision on a 100 × 50 batch of 8 × 8-downsampled CIFAR10 (Krizhevsky, 2009) images. For sampling efficiency, for NNGP the output of the penultimate layer was used, and for NTK the output layer was assumed to be of dimension 1 (all logits are i.i.d. conditioned on a given input). The displayed distance is the relative Frobenius norm squared, i.e. $\|\mathcal{K}_{k,n} - \mathcal{K}\|_F^2 / \|\mathcal{K}\|_F^2$, where $k$ is the widening factor and $n$ is the number of samples.


[Figure 3 panels: Classification Error, Mean Squared Error, and Marginal NLL (nats, right panel) versus Training Set Size (2^6 to 2^15), for FC, CONV, and WRESNET models with NNGP and NTK inference.]

Figure 3: CIFAR-10 classification with varying neural network architectures. NEURAL TANGENTS simplifies experimentation with architectures. Here we use infinite-time NTK inference and full Bayesian NNGP inference on CIFAR-10 for a Fully Connected network (FC, Listing 3), a Convolutional network without pooling (CONV, Listing 2), and a Wide Residual Network w/ pooling (WRESNET, Listing 1). As is common in prior work (Lee et al., 2018; Novak et al., 2019), the classification task is treated as MSE regression on zero-mean targets like $(-0.1, \ldots, -0.1, 0.9, -0.1, \ldots, -0.1)$. For each training set size, the best model in the family is selected by minimizing the mean negative marginal log-likelihood (NLL, right) on the training set.

The results are shown in Figure3. We see that in each case the performance of the model increases approximately logarithmically in the size of the dataset. Moreover, we observe a clear hierarchy of performance, especially at large dataset size, in terms of architecture (FC < CONV < WRESNET w/ pooling).

3 IMPLEMENTATION: TRANSFORMING TENSOR OPS TO KERNEL OPS

Neural networks are compositions of basic tensor operations such as: dense or convolutional affine transformations, application of pointwise nonlinearities, pooling, or normalization. For most networks without weight tying between layers, the kernel computation can also be written compositionally and there is a direct correspondence between tensor operations and kernel operations (see §3.1 for an example).

The core logic of NEURAL TANGENTS is a set of translation rules that sends each tensor operation acting on a finite-width layer to a corresponding transformation of the kernel for an infinite-width network. This is illustrated in Figure 4 for a simple convolutional architecture. In the associated table, we compare tensor operations (second column) with corresponding transformations of the NT and NNGP kernel tensors (third and fourth columns respectively). See §D for a list of all tensor operations for which translation rules are currently implemented.

One subtlety to consider when designing networks is that most infinite-width results require nonlinear transformations to be preceded by affine transformations (either dense or convolutional). This is because infinite-width results often assume that the pre-activations of nonlinear layers are approximately Gaussian. Randomness in weights and biases causes the output of infinite affine layers to satisfy this Gaussian requirement. Fortunately, prefacing nonlinear operations with affine transformations is common practice when designing neural networks, and NEURAL TANGENTS will raise an error if this requirement is not satisfied.
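For instance, every stax layer is itself a triple consisting of a finite-width initializer, a finite-width forward pass, and the corresponding kernel op, and stax.serial composes the kernel ops in the same order as the tensor ops; a small illustrative snippet (ours):

from neural_tangents import stax

# Each layer is an (init_fn, apply_fn, kernel_fn) triple.
dense_init, dense_apply, dense_kernel_op = stax.Dense(512, W_std=1.0, b_std=0.05)
relu_init, relu_apply, relu_kernel_op = stax.Relu()

# Composing tensor ops (Dense followed by Relu) composes the corresponding kernel ops.
_, _, kernel_fn = stax.serial(stax.Dense(512, W_std=1.0, b_std=0.05), stax.Relu())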

3.1 A TASTE OF TENSOR-TO-KERNEL OPS TRANSLATION

To get some intuition behind the translation rules, we consider the case of a nonlinearity followed by a dense layer. Let $z = z(\mathcal{X}, \theta) \in \mathbb{R}^{d \times n}$ be the preactivations resulting from $d$ distinct inputs $\mathcal{X}$ at a node in some hidden layer of a neural network. Suppose $z$ has NNGP kernel and NTK given by

$$\mathcal{K}_z = \mathbb{E}_\theta\left[ z_i z_i^T \right], \qquad \Theta_z = \mathbb{E}_\theta\left[ \frac{\partial z_i}{\partial \theta} \left( \frac{\partial z_i}{\partial \theta} \right)^T \right] \qquad (1)$$

where $z_i \in \mathbb{R}^d$ is the $i$th neuron and $\theta$ are the parameters in the network up until $z$. Here $d$ is the cardinality of the network inputs $\mathcal{X}$ and $n$ is the number of neurons in the $z$ node.

We assume $z$ is a mean zero multivariate Gaussian. We wish to compute the kernel corresponding to $h = \mathrm{Dense}(\sigma_\omega, \sigma_b)(\phi(z))$ by computing the kernels of $y = \phi(z)$ and $h = \mathrm{Dense}(\sigma_\omega, \sigma_b)(y)$ separately. Here,

$$h = \mathrm{Dense}(\sigma_\omega, \sigma_b)(y) \equiv \frac{1}{\sqrt{n}}\, \sigma_\omega W y + \sigma_b \beta, \qquad (2)$$

and the variables $W_{ij}$ and $\beta_i$ are i.i.d. Gaussian $\mathcal{N}(0, 1)$. We will compute the kernel operations, denoted $\phi^*$ and $\mathrm{Dense}(\sigma_\omega, \sigma_b)^*$, induced by the tensor operations $\phi$ and $\mathrm{Dense}(\sigma_\omega, \sigma_b)$. Finally, we will compute the kernel operation associated with the composition, $(\mathrm{Dense}(\sigma_\omega, \sigma_b) \circ \phi)^* = \mathrm{Dense}(\sigma_\omega, \sigma_b)^* \circ \phi^*$.

[Figure 4 body: a diagram of a 2-hidden-layer 1D CNN (inputs, convolution, nonlinearity, convolution, nonlinearity, dense readout) shown alongside the corresponding recursive NNGP/NTK kernel computation, followed by a per-layer table listing each tensor op and its corresponding NNGP and NTK kernel ops.]

Figure 4: An example of the translation of a convolutional neural network into a sequence of kernel operations. We demonstrate how the compositional nature of a typical NN computation on its inputs induces a corresponding compositional computation on the NNGP and NT kernels. Presented is a 2-hidden-layer 1D CNN with nonlinearity $\phi$, performing regression on the 10-dimensional outputs $z^2$ for each of the 4 inputs $\mathcal{X} = (x_1, x_2, x_3, x_4)$ from the dataset. To declutter notation, unit weight and zero bias variances are assumed in all layers. Top: the recursive output ($z^2$) computation induces a respective recursive NNGP kernel ($\mathcal{K}^2 \otimes I_{10}$) computation (the NTK computation being similar, not shown). Bottom: explicit listing of tensor and corresponding kernel ops in each layer. See Table 1 for operation definitions. Illustration and description adapted from Figure 3 in Novak et al. (2019).
We introduce a Monte Carlo estimation method for NN-GP kernels which2 2 are computationally im- ValuestheValues reported kernel reported translation are for are a forinvariant.3-layer a 3-layer model However, model applied appliedit requiresto a 2K/4K to a 2K/4K train/validationd train/validationmemory subset tol compute subset ofl CIFAR10 of theL CIFAR102 L 2 details. § ◦ practicalthe hyperparameters tothe compute hyperparameters analytically, we probe weor probe it isfor small itwhich is small relative we relative do to not the knowto variance. theO variance.the|X Inanalytic | particular, In particular, form.K Similarn,MK(✓n,M) in spirit(K✓) K downsampleddownsampled to 8 to88. See8. Figure See Figure7 for7 similarfor similar2 results results with other with other architectures architectures and A.7.2 and 1A.7.2forF 1forF to traditionalis nearlysample-sampleis nearly random constant constant feature for covariance constant for methods constantMn of (theRahimi.Mn We GP. thus (or We & Rechttreat thusd Mn treat, 2007perasMn covariance), the theas effective core the effective entry idea sample is in to an sample instantiate size iterative for size the or for many Montedis- the Monte A.3 STRIDED CONVOLUTIONS AND AVERAGE POOLING IN INTERMEDIATE LAYERS First we compute the NNGP and NTexperimental kernelsexperimental for details.y. details.⇥ To⇥ compute y noteO that from its definition, § § CarlotributedCarlo kernel kernel setting). estimate. estimate. It Increasing is impractical IncreasingM and toM use reducingand this reducing methodn cann reduce tocan analytically reduce memory memory evaluate cost, though cost, the though GP, potentially and potentially we at at A.2 RELATIONSHIP TO DEEP SIGNAL PROPAGATION K Our analysis in the main text can easily be extended to cover average pooling and strided convolu- 8 Spatiallythe expenseproposethe local expense of to average increased useT of a increased Montepooling compute Carlo in compute intermediary time approach and time bias. layers (seeandT bias. can4). be constructed in a similar fashion ( A.3). We tions (applied before the pointwise nonlinearity). Recall that conditioned on Kl the pre-activation § § Thel recurrenced1 relation linking the GP kernel at layer l +1tod that2 d of1 layer l following from Equa- y = φ(z) = Efocusθ φ on( globalz)i φ average(z)i pooling= inE thisθ[φ work(zi to) moreφ(z effectivelyi) ] = isolate( thez) effects. of pooling from other(3) aspects z (x) R is a zero-mean multivariate Gaussian. Let B R ⇥ denote a linear operator. Then 2. Subsampling one particular pixel: take h = e↵, pool pool tionj 102(i.e. Kl+1 =( ) Kl ) is precisely the covariance2 map examined in a series of related K K randomIn a non-distributedrandomInfinite a non-distributedwidthfinite width networks setting, networks setting, the and MC-GP use the andthe MC-GP use reduces empirical the reduces empirical the uncentered memoryT the uncentered memoryK requirements covariances requirements covariances of to activations compute of to activations compute to estimate tofrom estimatefrom l d2 of the model like local connectivity or equivariance. Bz (x) R 1is a zero-meanC A Gaussian,1 and the covariance is the Montethe2 Monte Carlo-GP2 Carlo-GP (MC-GP)2 (MC-GP)2 kernel, kernel, GP GP papersj on2 signal propagation (Xiao et al., 2018; Poole et al., 2016; Schoenholz et al., 2017; Lee d2 tod2 to + n2 + ndn2 +e,↵ makingnd , making2 theL evaluation+1 the evaluation of2 CNN-GPs of CNN-GPs with pooling with pooling practical. practical. 
Since φ does not introduce any new variables Θy can be computed as,! K ↵,↵ + b . (18) et al., 2018) (modulo notationall differences;l T denotedl as F , orl e.g. l ? inT Xiaol etT al. (2018)). O |XO | |X | O |XO | |X | K1 ⌘ M 1 nM n E !l,bl Bzj (x) Bzj (x0) K = BE !l,bCl zj (x)Azj (xC0) K B . (20) ⇣ ⇣ ⌘ ⌘⇣ ⇣ ⌘ ⌘1 1 In those works,{ the} action of this map on hidden-state covariance{ } matrices was interpreted as defin- l l 7⇥ ⇤ l l l l This approach makesKn,MK usen,M of(x, only x0)( onex, x0 pixel-pixel) covariance,yc↵ (xy;c✓↵m()x andy; c✓↵m requires)(xyc0;↵✓m(x) the0; ✓m same) amount(19) (19) ing a dynamical systemh whosel large-depthl behaviori informs aspectsh of trainability.i In particular, T ↵,↵0 ↵,↵0 T 0 0 One can easilyl+1 see that Bzj K l are i.i.d.l multivariate Gaussian as well. 5D5DISCUSSIONISCUSSION ⌘ Mn⌘ Mn as l , K =( ) K j K K⇤ , i.e. the covariance approaches a fixed point ∂φ(zi) ∂φ(zi) of memory as vectorization∂zi ( ∂z3.1)i to compute.m=1 c=1m=1 c=1 !1 1 C A 1 ⇡ 1 ⌘ 1 Θy = θ = θ diag(φ˙(z⇥ i))⇥ ⇤ ⇤ § diagX X(Xφ˙(zXi)) = ˙ ( z) Θz. KStrided⇤ . The convolution convergence. Strided to a fixed convolution point is is problematic equivalent to for a non-strided learning because convolution the hidden composed states with no E E longer1 contain information that can distinguish different pairs of inputs. It is similarly problematic ∂θ ∂θ where5.1where5.1 B✓AYESIANconsists B✓AYESIANconsists ofCNNM ofdrawsSWITHMANYCHANNELSAREIDENTICALTOLOCALLYCONNECTEDCNNM∂θdraws ofSWITHMANYCHANNELSAREIDENTICALTOLOCALLYCONNECTED the∂θ of weights the weights and biases and biases from their from prior theirT distribution, priorK distribution, ✓m ✓pm(✓),p and(✓), and subsampling. Let s N denote size of the stride. Then the strided convolution is equivalent to " #We compare" the performance of presented strategies in Figure 1.# Note that all described⇠ strategies⇠ for GPs, as the kernel2 becomes pathological as it approaches a fixed point. Precisely, in the chaotic   n is thenNETWORKSis width theNETWORKS width or numberIN or THE numberIN of ABSENCE channels THE of ABSENCE channels in OF hidden POOLING in OF hidden POOLING layers. layers. The MC-GP The MC-GP kernel kernel converges converges to the to analytic the analytic choosing B as follows: Bij = (is j) for i 0, 1,...(d2 1) . admit stacking additional, FC, layers on top whilel retainingl l l the GP equivalence, using a derivation regime outputs of the GP become asymptotically 2 { decorrelated and} therefore independent, while in kernelkernel with increasing with increasing width, width,limn limKn n,MK=n,MK =inK probability.in probability. analogous to 2 (Lee et al., 2018; Matthews!1 !1 et al., 2018b1 ). 1 theAverage ordered pooling regime. Average they approach pooling perfect with stride correlations and of window1. Either size ofws theseis equivalent scenarios tocaptures choosing no Taken together these equations implyLocally that,Locally Connected§ Connected Networks Networks (LCNs) (LCNs) (Fukushima (Fukushima, 1975,;1975Lecun; Lecun, 1989,)1989 are CNNs) are CNNs without without weight weight Bij =1/ws for i =0, 1,...(d2 1) and j = is, . . . , (is + ws 1). Kl Kl Var KVarl Kl= Var= VarKl (✓K) l/M(✓) /M information about the training data in the kernel and makes learning infeasible. Forsharing finiteForsharing between finitewidth between width networks, spatial networks, spatial locations. the locations. 
uncertainty the LCNs uncertainty LCNs preserve in preserven,M in theisn,M connectivity theis connectivityn,M pattern,n,M pattern, and✓ thusn and✓ topology, thusn topology,. From of. a From of a l l 1 1 l l 1 1 ThisND convolutions. problem can beNote ameliorated that our analysis by judicious in the hyperparameter main text (1D) easily selection, extends which to higher-dimensionalcan reduce the rate DanielyDaniely et al. ( et2016 al. (),2016 we), know we know that Var that✓ VarK ✓(✓K) (✓) , which, which leads leadsto Var to✓[KVar✓[K] ] . For . For 4MCNN.ONTECNN. However,C However,ARLO they EVALUATION do they not do possess not possess the˙ OF equivariance INTRACTABLEthen equivariancen/ propertyn / property⇥nGP of a⇥KERNELS⇤ CNN of a⇤ CNN– if an⇥ – input ifn,M an⇥ ⇤ inputis/n,M translated,Mn⇤ is/ translated,Mn ofconvolutions exponential by convergence replacing integer to the fixed pixel point. indices For and hyperpameters sizes d, ↵, chosenwith tupleson a critical (see also line separatingFigure 4). ( y, Θy) = φfinitethe∗ ( latentfinitethenz, K, latent representationlΘn,zK representation)isl alsois a also in biased( an a LCN in biasedz estimate) an, will LCN estimate( be of will completelyzK) bel of, completely whereKΘl ,z different, where the bias different, the rather depends bias rather depends than solely also than solely onbeing also network onbeing(4) translated. network translated. width. width. K K n,M n,M≡ T K T K⇥ ⇥ ⇤ ⇤ two untrainable phases, the convergence rates slow to polynomial, and very deep networks can be We doWe not do currently not currently have an have analytic an analytic form forform1 this for bias,1 this butbias, we but can we see can in see Figures in Figures2 and27andthat7 forthat for trained, and inference with deep NN-GP kernels can be performed – see Table 3. We introduceThe CNN-GPThe a CNN-GP Monte predictions Carlo predictions estimation without without spatial method spatial pooling for pooling NN-GP in 3.1 in kernelsand3.13.2.2and whichdepend3.2.2 aredepend computationallyonly on only sample-sample on sample-sample im- 2 2 18 5  T  the hyperparametersthe0 hyperparameters0 T we probe we probe it is small it is small relative relative to the§ to variance. the§ variance.§ In§ particular, In particular,Kl K(✓l) (K✓)L KL ˙ practicalcovariances, tocovariances, compute and analytically, do and not do depend not or depend on for pixel-pixel which on pixel-pixel we covariances.do not covariances. know LCNs the analytic LCNs destroy destroy form. pixel-pixel Similar pixel-pixeln,M covariances:n,M in spirit covariances:F F (Σ) E φ(u)φ(u) , (Σ) E φ (u)φ (u) , u (0, Σ), as in (Lee et al., 2019). 1 1 A.3 STRIDED CONVOLUTIONS AND AVERAGE POOLING IN INTERMEDIATE LAYERS to traditionalis nearlyL isLCN nearly randomL constantLCN constant feature for constant for methods constantMn (Rahimi.Mn We. thus We & Rechttreat thusMn treat, 2007asMn), the theas effective core the effective idea sample is to sample instantiate size for size the for many Monte the Monte T ≡ T ≡ K K (x, x0)=0(x, x0)=0, for ↵∼, for= N↵0 =and↵0 allandx, allx0 x, x0 and L>and L>0. However,0. However, LCNs LCNs preserve preserve the the Carlo1 Carlo kernel↵,↵10 kernel↵ estimate.,↵0 estimate. 
Increasing Increasing6 M6 andM reducingand reducing2n canX 2n reducecanX reduce memory memory cost, though cost, though potentially potentially at at LCN LCN CNN CNN Our analysis in the main text can easily be extended to cover average pooling and strided convolu- 8 L L L L l Spatiallythe⇥covariances expense⇤the⇥covariances local expense⇤ of between average increased of between increased pooling input compute input examples in compute intermediary examples time at and time every bias. layersat and every pixel: bias. can pixel: beK constructed↵K,↵ (x,↵, xin↵0)=( ax, similar x0)=K fashion↵K,↵ ( (x,↵A.3,↵ x0)()..x, We As x0). a As a tions (applied before the pointwise nonlinearity). Recall that conditioned on K the pre-activation 1 1 1 1 § l d1 d2 d1 focus on global average pooling in this work to more effectively isolate the effects of pooling from other aspects z (x) R is a zero-mean multivariate Gaussian. Let B R ⇥ denote a linear operator. Then result,result, in the in absence the7 absence of pooling, of pooling, LCN-GPs LCN-GPs and CNN-GPs and CNN-GPs are identical. are identical. Moreover, Moreover, LCN-GPs LCN-GPspool withpool with j 2 2 of theIn model a non-distributedIn like a non-distributed local connectivity setting, setting, orthe equivariance. MC-GP the MC-GP reduces reduces the memory⇥ the memory⇤⇥ requirements⇤ requirements to⇥ compute⇤ to⇥ compute⇤ from from l d2 GP GP Bzj (x) R is a zero-mean Gaussian, and the covariance is 2 d2 2 d2 2 + n2 + ndn2 + nd 2 to to , making, making the evaluation the evaluation of CNN-GPs of CNN-GPs with pooling with pooling practical. practical. l l T l l l T l T O |XO | |X | O |XO | |X | 8 8 E !l,bl Bzj (x) Bzj (x0) K = BE !l,bl zj (x) zj (x0) K B . (20) ⇣ ⇣ ⌘ ⌘⇣ ⇣ ⌘ ⌘ { } { } 7 h i h i One can easily see that Bzl Kl are i.i.d. multivariate Gaussian as well. 5D5DISCUSSIONISCUSSION j j Strided convolution. Strided convolution is equivalent to a non-strided convolution composed with 5.15.1 BAYESIAN BAYESIANCNNSWITHMANYCHANNELSAREIDENTICALTOLOCALLYCONNECTEDCNNSWITHMANYCHANNELSAREIDENTICALTOLOCALLYCONNECTED subsampling. Let s N denote size of the stride. Then the strided convolution is equivalent to choosing B as follows:2 B = (is j) for i 0, 1,...(d 1) . NETWORKSNETWORKS, IN THE, IN ABSENCE THE ABSENCE OF POOLING OF POOLING ij 2 { 2 } Average pooling. Average pooling with stride s and window size ws is equivalent to choosing LocallyLocally Connected Connected Networks Networks (LCNs) (LCNs) (Fukushima (Fukushima, 1975,;1975Lecun; Lecun, 1989,)1989 are CNNs) are CNNs without without weight weight Bij =1/ws for i =0, 1,...(d2 1) and j = is, . . . , (is + ws 1). sharingsharing between between spatial spatial locations. locations. LCNs LCNs preserve preserve the connectivity the connectivity pattern, pattern, and thus and topology, thus topology, of a of a ND convolutions. Note that our analysis in the main text (1D) easily extends to higher-dimensional CNN.CNN. However, However, they do they not do possess not possess the equivariance the equivariance property property of a CNN of a CNN– if an – input if an inputis translated, is translated, convolutions by replacing integer pixel indices and sizes d, ↵, with tuples (see also Figure 4). the latentthe latent representation representation in an LCN in an will LCN be will completely be completely different, different, rather rather than also than being also being translated. translated. 
The CNN-GPThe CNN-GP predictions predictions without without spatial spatial pooling pooling in 3.1 inand3.13.2.2and depend3.2.2 depend only on only sample-sample on sample-sample 18 covariances,covariances, and do and not do depend not depend on pixel-pixel on pixel-pixel covariances.§ covariances.§ LCNs§ LCNs§ destroy destroy pixel-pixel pixel-pixel covariances: covariances: L LCNL LCN K K (x, x0)=0(x, x0)=0, for ↵, for= ↵0 =and↵0 allandx, allx0 x, x0 and L>and L>0. However,0. However, LCNs LCNs preserve preserve the the 1 ↵,↵10 ↵,↵0 6 6 2 X 2 X L LCNL LCN L CNNL CNN ⇥covariances⇤⇥covariances⇤ between between input input examples examples at every at every pixel: pixel:K K (x, x0)=(x, x0)=K K (x, x0)(.x, As x0). a As a 1 ↵,↵1 ↵,↵ 1 ↵,↵1 ↵,↵ result,result, in the in absence the absence of pooling, of pooling, LCN-GPs LCN-GPs and CNN-GPs and CNN-GPs are identical. are identical. Moreover, Moreover, LCN-GPs LCN-GPs with with ⇥ ⇤⇥ ⇤ ⇥ ⇤⇥ ⇤


will be the translation rule for a pointwise nonlinearity. Note that Equation 4 only has an analytic expression for a small set of activation functions φ. Next we consider the case of a dense operation. Using the independence between the weights, the biases, and y, it follows that

$$\mathcal{K}_h = \mathbb{E}_{W,\beta,\theta}\!\left[h_i h_i^T\right] = \sigma_\omega^2\, \mathbb{E}_\theta\!\left[y_i y_i^T\right] + \sigma_b^2 = \sigma_\omega^2 \mathcal{K}_y + \sigma_b^2. \qquad (5)$$

Finally, the NTK of h can be computed as a sum of two terms:

$$\Theta_h = \mathbb{E}_{W,\beta,\theta}\!\left[\frac{\partial h_i}{\partial (W, \beta)} \left(\frac{\partial h_i}{\partial (W, \beta)}\right)^{T}\right] + \mathbb{E}_{W,\beta,\theta}\!\left[\frac{\partial h_i}{\partial \theta} \left(\frac{\partial h_i}{\partial \theta}\right)^{T}\right] = \sigma_\omega^2 \mathcal{K}_y + \sigma_b^2 + \sigma_\omega^2 \Theta_y. \qquad (6)$$

This gives the translation rule for the dense layer in terms of $\mathcal{K}_y$ and $\Theta_y$ as

$$(\mathcal{K}_h, \Theta_h) = \text{Dense}(\sigma_\omega, \sigma_b)^*(\mathcal{K}_y, \Theta_y) \equiv \left(\sigma_\omega^2 \mathcal{K}_y + \sigma_b^2,\; \sigma_\omega^2 \mathcal{K}_y + \sigma_b^2 + \sigma_\omega^2 \Theta_y\right). \qquad (7)$$
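For concreteness, the two rules above can be written out directly. The following is a minimal NumPy sketch (not the library's implementation; the function names t_erf, t_dot_erf, phi_star, and dense_star are illustrative), using the standard closed forms of T and Ṫ for the Erf nonlinearity, with K and Θ treated as dense kernel matrices:

import numpy as np

def t_erf(k):
    # T(K) for phi = erf: E[erf(u) erf(u)^T] with u ~ N(0, K), in closed form.
    d = 1 + 2 * np.diag(k)
    return 2 / np.pi * np.arcsin(2 * k / np.sqrt(np.outer(d, d)))

def t_dot_erf(k):
    # T-dot(K) for phi = erf: E[erf'(u) erf'(u)^T] with u ~ N(0, K), in closed form.
    d = 1 + 2 * np.diag(k)
    return 4 / np.pi / np.sqrt(np.outer(d, d) - 4 * k ** 2)

def phi_star(k, theta):
    # Equation (4): translation rule for a pointwise nonlinearity.
    return t_erf(k), t_dot_erf(k) * theta

def dense_star(k, theta, s_w=1.0, s_b=0.0):
    # Equation (7): translation rule for a dense layer with weight/bias std s_w, s_b.
    k_out = s_w ** 2 * k + s_b ** 2
    return k_out, k_out + s_w ** 2 * theta

Composing them as dense_star(*phi_star(k, theta)) then mirrors the composition rule (Dense(σ_ω, σ_b) ∘ φ)* = Dense(σ_ω, σ_b)* ∘ φ* used above.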

3.2 PERFORMANCE

Our library performs a number of automatic performance optimizations without sacrificing flexibility.

Leveraging block-diagonal covariance structure. A common computational challenge with GPs is inverting the training set covariance matrix. Naively, for a classification task with C classes and training set X, the NNGP and NTK covariances have shape |X|C × |X|C. For CIFAR-10, this would be 500,000 × 500,000. However, if a fully-connected readout layer is used (an extremely common design in classification architectures), the C logits are i.i.d. conditioned on the input x. This results in outputs that are normally distributed with a block-diagonal covariance matrix of the form Σ ⊗ I_C, where Σ has shape |X| × |X| and I_C is the C × C identity matrix. This reduces the computational complexity and storage in many common cases by an order of magnitude, which makes closed-form exact inference feasible in these cases.

Automatically tracking only the smallest necessary subset of intermediary covariance entries. For most architectures, especially convolutional ones, the main computational burden lies in constructing the covariance matrix (as opposed to inverting it). Specifically, for a convolutional network of depth l, constructing the output covariance matrix Σ involves computing l intermediate layer covariance matrices Σ^l of size |X|d × |X|d (see Listing 1 for a model requiring this computation), where d is the total number of pixels in the intermediate layer outputs (e.g. d = 1024 in the case of CIFAR-10 with SAME padding). However, as Xiao et al. (2018); Novak et al. (2019); Garriga-Alonso et al. (2019) remarked, if no pooling is used in the network the output covariance Σ can be computed using only the stack of d |X| × |X|-blocks of Σ^l, bringing the time and memory cost from O(|X|² d²) down to O(|X|² d) per layer (see Figure 4 and Listing 2 for models admitting this optimization). Finally, if the network has no convolutional layers, the cost further reduces to O(|X|²) (see Listing 3 for an example). These choices are performed automatically by NEURAL TANGENTS to achieve efficient computation and a minimal memory footprint.

Expressing covariance computations as 2D convolutions with optimal layout. A key insight for high performance in convolutional models is that the covariance propagation operator A for convolutional layers can be expressed in terms of 2D convolutions when it operates on both the full |X|d × |X|d covariance matrix Σ and on the d diagonal |X| × |X|-blocks. This allows utilization of modern hardware accelerators, many of which target 2D convolutions as their primary machine learning application.

Simultaneous NNGP and NT kernel computations. As the NTK computation requires the NNGP covariance as an intermediary computation, the NNGP covariance is computed together with the NTK at no extra cost. This is especially convenient for researchers looking to investigate similarities and differences between these two infinite-width NN limits.

Automatic batching and parallelism across multiple devices. In most cases, as the dataset or model becomes large, it is impossible to perform the entire kernel computation at once. Additionally, in many cases it is desirable to parallelize the kernel computation across devices (CPUs, GPUs, or TPUs). NEURAL TANGENTS provides an easy way to perform both of these common tasks using a single batch decorator shown below:



Figure 5: Performance scaling with batch size (left) and number of GPUs (right). Shows time per entry needed to compute the analytic NNGP and NTK covariance matrices (using kernel_fn ) in a 21-layer ReLU network with global average pooling. Left: Increasing the batch size when computing the covariance matrix in blocks allows for a significant performance increase until a certain threshold when all cores in a single GPU are saturated. Simpler models are expected to have better scaling with batch size. Right: Time-per-sample scales linearly with the number of GPUs, demonstrating near-perfect hardware utilization.

batched_kernel_fn = nt.batch(kernel_fn, batch_size)
batched_kernel_fn(x, x) == kernel_fn(x, x)  # True!

This code works with either analytic kernels or empirical kernels. By default, it automatically shares the computation over all available devices. We plot the performance as a function of batch size and number of accelerators when computing the theoretical NTK of a 21-layer convolutional network in Figure 5, observing near-perfect scaling with the number of accelerators.

Op fusion. JAX and XLA allow end-to-end compilation of the whole kernel computation and/or inference. This enables the XLA compiler to fuse low-level ops into custom model-specific accelerator kernels, as well as to eliminate overhead from op-by-op dispatch to an accelerator. In a similar vein, we allow the covariance tensor to change its order of dimensions from layer to layer, with the order tracked and parsed as additional metadata under the hood. This eliminates redundant transpositions6 by adjusting the computation performed by each layer based on the input metadata.
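As a small self-contained check of the block-diagonal reduction described earlier in this section (an illustration, not library code), the following NumPy snippet verifies that solving against the full |X|C × |X|C covariance Σ ⊗ I_C is equivalent to a single solve against the |X| × |X| matrix Σ; the sizes and random data here are arbitrary:

import numpy as np

rng = np.random.default_rng(0)
n, C = 50, 10                        # |X| training points, C classes
A = rng.normal(size=(n, n))
Sigma = A @ A.T + n * np.eye(n)      # a well-conditioned |X| x |X| covariance
Y = rng.normal(size=(n, C))          # targets, one column per class

# Naive approach: materialize the (|X| C) x (|X| C) covariance and solve.
big = np.kron(Sigma, np.eye(C))
alpha_naive = np.linalg.solve(big, Y.reshape(-1)).reshape(n, C)

# Block-diagonal shortcut: a single |X| x |X| solve handles all classes at once.
alpha_fast = np.linalg.solve(Sigma, Y)

assert np.allclose(alpha_naive, alpha_fast)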

4 CONCLUSION

We believe NEURAL TANGENTS will enable researchers to quickly and easily explore infinite-width networks. By democratizing this previously challenging model family, we hope that researchers will begin to use infinite neural networks, in addition to their finite counterparts, when faced with a new problem domain (especially in cases that are data-limited). In addition, we are excited to see novel uses of infinite networks as theoretical tools to gain insight and clarity into many of the hard theoretical problems in deep learning. Going forward, there are significant additions to NEURAL TANGENTS that we are exploring. There are more layers we would like to add in the future (§D) that will enable an even larger range of infinite network topologies. Additionally, there are further performance improvements we would like to implement, to allow experimenting with larger models and datasets. We invite the community to join our efforts by contributing new layers to the library (§B.7), or by using it for research and providing feedback!

ACKNOWLEDGMENTS

We thank Yasaman Bahri for frequent discussion and useful feedback on the manuscript. We additionally appreciate both Yasaman Bahri and Greg Yang for their ongoing contributions to improve the library. We thank Sergey Ioffe for feedback on the text, as well as Ravid Ziv and Jeffrey Pennington for discussion and feedback on early versions of the library.

6These transpositions could not be automatically fused by the XLA compiler.


REFERENCES

Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S. Corrado, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Ian Goodfellow, Andrew Harp, Geoffrey Irving, Michael Isard, Yangqing Jia, Rafal Jozefowicz, Lukasz Kaiser, Manjunath Kudlur, Josh Levenberg, Dan Mané, Rajat Monga, Sherry Moore, Derek Murray, Chris Olah, Mike Schuster, Jonathon Shlens, Benoit Steiner, Ilya Sutskever, Kunal Talwar, Paul Tucker, Vincent Vanhoucke, Vijay Vasudevan, Fernanda Viégas, Oriol Vinyals, Pete Warden, Martin Wattenberg, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL http://tensorflow.org/. Software available from tensorflow.org. Takuya Akiba, Keisuke Fukuda, and Shuji Suzuki. ChainerMN: Scalable Distributed Deep Learning Framework. In Proceedings of Workshop on ML Systems in The Thirty-first Annual Conference on Neural Information Processing Systems (NIPS), 2017. URL http://learningsys.org/nips17/ assets/papers/paper_25.pdf. Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over- parameterization. In International Conference on Machine Learning, 2018. Anonymous. Infinite attention: Nngp and ntk for deep attention networks. In International Conference on Machine Learning (ICML), 2020. submission under review. Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Ruslan Salakhutdinov, and Ruosong Wang. On exact computation with an infinitely wide neural net. In Advances In Neural Information Processing Systems, 2019a. Sanjeev Arora, Simon S. Du, Zhiyuan Li, Ruslan Salakhutdinov, Ruosong Wang, and Dingli Yu. Harnessing the power of infinitely wide deep nets on small-data tasks, 2019b. Yaniv Blumenfeld, Dar Gilboa, and Daniel Soudry. A mean field theory of quantized deep networks: The quantization-depth trade-off. arXiv preprint arXiv:1906.00771, 2019. James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclau- rin, and Skye Wanderman-Milne. JAX: composable transformations of Python+NumPy programs, 2018a. URL http://github.com/google/jax. James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclau- rin, and Skye Wanderman-Milne. Stax, a flexible neural net specification library in jax, 2018b. URL https://github.com/google/jax/blob/master/jax/experimental/stax.py. Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. arXiv preprint arXiv:1812.07956, 2019. Youngmin Cho and Lawrence K Saul. Kernel methods for deep learning. In Advances In Neural Information Processing Systems, 2009. François Chollet et al. Keras. https://keras.io, 2015. Amit Daniely, Roy Frostig, and Yoram Singer. Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity. In Advances In Neural Information Processing Systems, pp. 2253–2261, 2016. Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. arXiv preprint arXiv:1811.03804, 2018a. Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054, 2018b. Simon S Du, Kangcheng Hou, Russ R Salakhutdinov, Barnabas Poczos, Ruosong Wang, and Keyulu Xu. Graph neural tangent kernel: Fusing graph neural networks with graph kernels. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. 
Garnett (eds.), Advances in Neural Information Processing Systems 32, pp.


5724–5734. Curran Associates, Inc., 2019. URL http://papers.nips.cc/paper/ 8809-graph-neural-tangent-kernel-fusing-graph-neural-networks-with-graph-kernels. pdf. Jacob R Gardner, Geoff Pleiss, David Bindel, Kilian Q Weinberger, and Andrew Gordon Wilson. Gpytorch: Blackbox matrix-matrix gaussian process inference with gpu acceleration. In Advances in Neural Information Processing Systems, 2018. Adrià Garriga-Alonso, Carl Edward Rasmussen, and Laurence Aitchison. Deep convolutional networks as shallow gaussian processes. In International Conference on Learning Representations, 2019. GPy. GPy: A gaussian process framework in python. http://github.com/SheffieldML/GPy, 2012. Soufiane Hayou, Arnaud Doucet, and Judith Rousseau. On the selection of initialization and for deep neural networks. arXiv preprint arXiv:1805.08266, 2018. Soufiane Hayou, Arnaud Doucet, and Judith Rousseau. Mean-field behaviour of neural tangent kernel for deep neural networks, 2019. Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in neural information processing systems, 2018. Ryo Karakida, Shotaro Akaho, and Shun-ichi Amari. Universal statistics of fisher information in deep neural networks: mean field approach. 2018. Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009. Jaehoon Lee, Yasaman Bahri, Roman Novak, Sam Schoenholz, Jeffrey Pennington, and Jascha Sohl-dickstein. Deep neural networks as gaussian processes. In International Conference on Learning Representations, 2018. Jaehoon Lee, Lechao Xiao, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl- Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. In Advances in neural information processing systems, 2019. Ping Li and Phan-Minh Nguyen. On random deep weight-tied autoencoders: Exact asymptotic analysis, phase transitions, and implications to training. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=HJx54i05tX. Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, pp. 8157–8166, 2018. Alexander G. de G. Matthews, Mark van der Wilk, Tom Nickson, Keisuke. Fujii, Alexis Boukouvalas, Pablo Le‘on-Villagr‘a, Zoubin Ghahramani, and James Hensman. GPflow: A Gaussian process library using TensorFlow. Journal of Machine Learning Research, 18(40):1–6, apr 2017. URL http://jmlr.org/papers/v18/16-537.html. Alexander G de G Matthews, Mark Rowland, Jiri Hron, Richard E Turner, and Zoubin Ghahramani. Gaussian process behaviour in wide deep neural networks. arXiv preprint arXiv:1804.11271, 2018a. Alexander G. de G. Matthews, Jiri Hron, Mark Rowland, Richard E. Turner, and Zoubin Ghahramani. Gaussian process behaviour in wide deep neural networks. In International Conference on Learning Representations, 2018b. Radford M. Neal. Priors for infinite networks (tech. rep. no. crg-tr-94-1). University of Toronto, 1994. Roman Novak, Lechao Xiao, Jaehoon Lee, Yasaman Bahri, Greg Yang, Jiri Hron, Daniel A. Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Bayesian deep convolutional networks with many channels are gaussian processes. In International Conference on Learning Representations, 2019.


Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. In NIPS-W, 2017.

Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. In Advances In Neural Information Processing Systems, 2016.

Arnu Pretorius, Elan van Biljon, Steve Kroon, and Herman Kamper. Critical initialisation for deep signal propagation in noisy rectifier neural networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (eds.), Advances in Neural Information Processing Systems 31, pp. 5717–5726. Curran Associates, Inc., 2018. URL http://papers.nips.cc/paper/7814-critical-initialisation-for-deep-signal-propagation-in-noisy-rectifier-neural-networks.pdf.

Samuel S Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein. Deep information propagation. arXiv preprint arXiv:1611.01232, 2016.

Seiya Tokui, Kenta Oono, Shohei Hido, and Justin Clayton. Chainer: a next-generation open source framework for deep learning. In Proceedings of Workshop on Machine Learning Systems (LearningSys) in The Twenty-ninth Annual Conference on Neural Information Processing Systems (NIPS), 2015. URL http://learningsys.org/papers/LearningSys_2015_paper_33.pdf.

Lechao Xiao, Yasaman Bahri, Jascha Sohl-Dickstein, Samuel Schoenholz, and Jeffrey Pennington. Dynamical isometry and a mean field theory of CNNs: How to train 10,000-layer vanilla convolutional neural networks. In International Conference on Machine Learning, 2018.

Lechao Xiao, Jeffrey Pennington, and Samuel S Schoenholz. Disentangling trainability and generalization in deep learning. arXiv preprint arXiv:1912.13053, 2019.

Ge Yang and Samuel Schoenholz. Mean field residual networks: On the edge of chaos. In Advances In Neural Information Processing Systems, 2017.

Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760, 2019.

Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In Proceedings of the British Machine Vision Conference (BMVC), 2016.

Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent optimizes over- parameterized deep relu networks. Machine Learning, Oct 2019. ISSN 1573-0565. doi: 10.1007/s10994-019-05839-6. URL https://doi.org/10.1007/s10994-019-05839-6.

APPENDIX

A NEURAL TANGENTS AND PRIOR WORK

Here we briefly discuss the differences between NEURAL TANGENTS and the relevant prior work.

1. Prior benchmarks in the domain of infinitely wide neural networks. Various prior works have evaluated convolutional and fully-connected models on certain datasets (Lee et al., 2018; Matthews et al., 2018b;a; Novak et al., 2019; Garriga-Alonso et al., 2019; Arora et al., 2019a). While these efforts must have required implementing certain parts of our library, to our knowledge such prior efforts were either not open-sourced or not comprehensive / user-friendly / scalable enough to be used as a user-facing library. In addition, all of the works above used their own separate implementation, which further highlights a need for a more general approach.


2. Code released by Lee et al. (2019). Lee et al. (2019) have released code along with their paper submission, which is a strict and minor subset of our library. More specifically, at the time of the submission, Lee et al. (2019) have released code equivalent to nt.linearize, nt.empirical_ntk_fn, nt.predict.gradient_descent_mse, nt.predict.gradient_descent, and nt.predict.momentum. Every other part of the library (most notably, nt.stax) is new in this submission and was not used by Lee et al. (2019) or any other prior work. At the time of writing, NEURAL TANGENTS differs from the code released by Lee et al. (2019) by about +9,500/−2,500 lines of code.

3. GPy (2012), GPFlow (Matthews et al., 2017), GPyTorch (Gardner et al., 2018), and other GP packages. While various packages allowing for kernel construction, optimization, and inference with Gaussian Processes exist, none of them allow easy construction of the very specific kernels corresponding to infinite neural networks (NNGP/NTK; nt.stax), nor do they provide the tools and convenience for studying wide but finite networks and their training dynamics (nt.taylor_expand, nt.predict, nt.monte_carlo_kernel_fn). On the other hand, NEURAL TANGENTS does not provide any tools for approximate inference with these kernels.

B LIBRARY DESCRIPTION

NEURAL TANGENTS provides a high-level interface for specifying analytic, infinite-width, Bayesian and gradient descent trained neural networks as Gaussian Processes. This interface closely follows the stax API (Bradbury et al., 2018b) in JAX.

B.1 NEURAL NETWORKS WITH JAX

stax represents each component of a network as two functions: init_fn and apply_fn . These components can be composed in serial or in parallel to produce new network components with their own init_fn and apply_fn . In this way, complicated neural network architectures can be specified hierarchically. Calling init_fn on a random seed and an input shape generates a random draw of trainable parameters for a neural network. Calling apply_fn on these parameters and a batch of inputs returns the outputs of the given finite neural network.

from jax.experimental import stax

init_fn, apply_fn = stax.serial(stax.Dense(512), stax.Relu, stax.Dense(10))
_, params = init_fn(key, (-1, 32 * 32 * 3))
fx_train, fx_test = apply_fn(params, x_train), apply_fn(params, x_test)

B.2 INFINITE NEURAL NETWORKS WITH NEURAL TANGENTS

We extend stax layers to return a third function kernel_fn , which represents the covariance functions of the infinite NNGP and NTK networks of the given architecture (recall that since infinite networks are GPs, they are fully defined by their covariance functions, assuming 0 mean as is common in the literature).

from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))

We demonstrate a specification of a more complicated architecture (WideResNet) in Listing1. kernel_fn accepts two batches of inputs x1 and x2 and returns their NNGP covariance and NTK matrices as kernel_fn(x1, x2).nngp and kernel_fn(x1, x2).ntk respectively, which


can then be used to make posterior test set predictions as the mean of a conditional multivariate normal, $\mathcal{Y}_{\text{test}} = \mathcal{K}(\mathcal{X}_{\text{test}}, \mathcal{X}_{\text{train}})\, \mathcal{K}(\mathcal{X}_{\text{train}}, \mathcal{X}_{\text{train}})^{-1}\, \mathcal{Y}_{\text{train}}$:

from jax.numpy.linalg import inv

y_test = kernel_fn(x_test, x_train).ntk @ inv(kernel_fn(x_train, x_train).ntk) @ y_train

Note that the above code does not do Cholesky decomposition and is presented merely to show the mathematical expression. We provide an efficient GP inference method in the predict submodule:

import neural_tangents as nt

y_test = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test, get='ntk',
                                 diag_reg=1e-4, compute_cov=False)
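For reference, the Cholesky-based computation alluded to above can be sketched with jax.scipy directly (an illustrative snippet, not the library's internals; the absolute 1e-4 jitter here is an assumption, roughly analogous to diag_reg):

import jax.numpy as jnp
import jax.scipy.linalg as jsl

k_dd = kernel_fn(x_train, x_train).ntk
k_td = kernel_fn(x_test, x_train).ntk
# Factor the (regularized) train-train kernel once and reuse it for the solve.
chol = jsl.cho_factor(k_dd + 1e-4 * jnp.eye(k_dd.shape[0]))
y_test = k_td @ jsl.cho_solve(chol, y_train)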

B.3 COMPUTING INFINITE NETWORK KERNELS IN BATCHES AND IN PARALLEL

Naively, the kernel_fn will compute the whole kernel in a single call on one device. However, for large datasets or complicated architectures, it is often necessary to distribute the calculation in some way. To do this, we introduce a batch decorator that takes a kernel_fn and returns a new kernel_fn with the exact same signature. The new function computes the kernel in batches and automatically parallelizes the calculation over however many devices are available, with near-perfect speedup scaling with the number of devices (Figure5, right).

import neural_tangents as nt

kernel_fn = nt.batch(kernel_fn, batch_size=32)

Note that batching is often used to compute large covariance matrices that may not even fit on a GPU/TPU device, and instead need to be stored in CPU RAM and used for inference from there. This is easy to achieve by simply specifying nt.batch(..., store_on_device=False). Once the matrix is stored in RAM, inference will be performed with a CPU when nt.predict methods are called. As mentioned in §3.2, for many architectures (notably convolutional, and especially those with pooling), inference cost can be small relative to kernel construction, even when running on CPU (for example, it takes less than 3 minutes to execute jax.scipy.linalg.solve(..., sym_pos=True) on a 45,000 × 45,000 training covariance matrix and a 45,000 × 10 training target matrix).

B.4 TRAINING DYNAMICS OF INFINITE NETWORKS

In addition to closed-form multivariate Gaussian posterior prediction, it is also interesting to consider network predictions following continuous gradient descent. To facilitate this we provide several functions to compute predictions following gradient descent with an MSE loss, gradient descent with an arbitrary loss, or momentum with an arbitrary loss. The first case is handled analytically, while the latter two are computed by numerically integrating the corresponding differential equation. For example, the following code will compute the function evaluation on train and test points following gradient descent for some time training_time.

import neural_tangents as nt

predictor = nt.predict.gradient_descent_mse(kernel_fn(x_train, x_train), y_train,
                                            kernel_fn(x_test, x_train))
fx_train, fx_test = predictor(training_time, fx_train, fx_test)

B.5 INFINITE NETWORKS OF ANY ARCHITECTURE THROUGH SAMPLING

Of course, there are cases where the analytic kernel cannot be computed. To support these situations, we provide utility functions to efficiently compute Monte Carlo estimates of the NNGP covariance and NTK. These functions work with neural networks constructed using any neural network library.


[Figure 6 plot: batch mean squared error vs. training epoch for the original function and its 0th-order (constant), 1st-order (linear), and 2nd-order (quadratic) approximations.]

Figure 6: Training a neural network and its various approximations using nt.taylor_expand . Presented is a 5-layer Erf-neural network of width 512 trained on MNIST using SGD with mo- mentum, along with its constant (0th order), linear (1st order), and quadratic (2nd order) Taylor expansions about the initial parameters. As training progresses (left to right), lower-order expansions deviate from the original function faster than higher-order ones.

from jax import random
from jax.experimental import stax
import neural_tangents as nt

init_fn, apply_fn = stax.serial(stax.Dense(64), stax.BatchNorm(), stax.Sigmoid,
                                stax.Dense(1))
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key=random.PRNGKey(1),
                                     n_samples=128)
kernel = kernel_fn(x_train, x_train)

We demonstrate convergence of the Monte Carlo kernel estimates to the closed-form analytic kernels in the case of a WideResNet in Figure2.

B.6 WEIGHTS OF WIDE BUT FINITE NETWORKS

While most of NEURAL TANGENTS is devoted to a function-space perspective—describing the distribution of function values on finite collections of training and testing points—we also provide tools to investigate a dual weight-space perspective described in Lee et al. (2019). Convergence of the training dynamics to NTK dynamics coincides with networks being described by a linear approximation about their initial set of parameters. We provide decorators linearize and taylor_expand to approximate functions to linear order and to arbitrary order respectively. Both functions take an apply_fn and return a new apply_fn that computes the series approximation.

import neural_tangents as nt

taylor_apply_fn = nt.taylor_expand(apply_fn, params, order)
fx_train_approx = taylor_apply_fn(new_params, x_train)
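Similarly, a sketch of the linearize decorator mentioned above (reusing apply_fn, params, new_params, and x_train from the listing; we assume the returned function shares apply_fn's signature):

import neural_tangents as nt

linear_apply_fn = nt.linearize(apply_fn, params)
fx_train_lin = linear_apply_fn(new_params, x_train)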

These act exactly like normal JAX functions and, in particular, can be plugged into gradient descent, which we demonstrate in Figure6.
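For instance, one step of plain gradient descent on the approximate function looks like ordinary JAX code (a minimal sketch; the MSE loss and the 0.1 learning rate are arbitrary choices, not the Figure 6 setup):

import jax
import jax.numpy as jnp

def mse_loss(p, x, y):
    # MSE between the Taylor-expanded network outputs and the targets.
    return 0.5 * jnp.mean((taylor_apply_fn(p, x) - y) ** 2)

grads = jax.grad(mse_loss)(params, x_train, y_train)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)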

B.7 EXTENDING NEURAL TANGENTS

Many neural network layers admit a sensible infinite-width limit behavior in the Bayesian and continuous gradient descent regimes as long as the multivariate central limit theorem applies to their outputs conditioned on their inputs. To add such a layer to NEURAL TANGENTS, one only has to implement it as a method in nt.stax with the following signature:


@_layer  # an internal decorator taking care of certain boilerplate.
NewLayer(layer_params: Any) -> (init_fn: function, apply_fn: function, kernel_fn: function)

Here init_fn and apply_fn are initialization and the forward pass methods of the finite width layer implementation (see §B.1). If the layer of interest already exists in JAX, there is no need to implement these methods and the user can simply return the respective methods from jax.experimental.stax (see nt.stax.Flatten for an example; in fact the majority of nt.stax layers call the original jax.experimental.stax layers for finite width layer methods). In this case what remains is to implement the kernel_fn method with signature

kernel_fn(input_kernel: nt.utils.Kernel)-> output_kernel: nt.utils.Kernel

Here both input_kernel and output_kernel are namedtuples containing the NNGP and NTK covariance matrices, as well as additional metadata useful for computing the kernel propagation operation. The specific operation to be performed should be derived by the user in the context of the particular operation that the finite width layer performs. This transformation could be as simple as an affine map on the kernel matrices, but could also be analytically intractable. Once implemented, the correctness of the implementation can be very easily tested by extending the nt.tests.stax_test with the new layer, to test the agreement with large-width empirical NNGP and NTK kernels.
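As an illustration of this interface, a hypothetical parameter-free layer Scale(c), which multiplies activations by a constant c, could be sketched as follows. This is not library code: it assumes the Kernel namedtuple exposes the .nngp and .ntk fields documented above (other metadata is left untouched via _replace), and that the @_layer decorator described above would additionally be applied:

def Scale(c):
    def init_fn(rng, input_shape):
        return input_shape, ()    # no trainable parameters

    def apply_fn(params, inputs, **kwargs):
        return c * inputs         # finite-width forward pass

    def kernel_fn(input_kernel):
        # Scaling activations by c scales both second moments (NNGP and NTK) by c**2.
        return input_kernel._replace(nngp=c ** 2 * input_kernel.nngp,
                                     ntk=c ** 2 * input_kernel.ntk)

    return init_fn, apply_fn, kernel_fn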

C ARCHITECTURE SPECIFICATIONS

from neural_tangents import stax

def ConvolutionalNetwork(depth, W_std=1.0, b_std=0.0):
    layers = []
    for _ in range(depth):
        layers += [stax.Conv(1, (3, 3), W_std, b_std, padding='SAME'), stax.Relu()]
    layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)]
    return stax.serial(*layers)

Listing 2: All-convolutional model (ConvOnly) definition used in Figure3 .

from neural_tangents import stax

def FullyConnectedNetwork(depth, W_std=1.0, b_std=0.0):
    layers = [stax.Flatten()]
    for _ in range(depth):
        layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
    layers += [stax.Dense(1, W_std, b_std)]
    return stax.serial(*layers)

Listing 3: Fully-connected (FC) model definition used in Figure3 .
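Either model can then be used like any other nt.stax network; for example (a usage sketch, with x_train standing in for a batch of inputs):

init_fn, apply_fn, kernel_fn = FullyConnectedNetwork(depth=3)
k = kernel_fn(x_train, x_train)
nngp, ntk = k.nngp, k.ntk  # analytic NNGP and NTK covariances of the training set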


D IMPLEMENTED AND COMING SOON FUNCTIONALITY

The following layers7 are currently implemented, with translation rules given in Table1:

• serial
• parallel
• FanOut
• FanInSum
• FanInConcat
• Dense
• Conv with arbitrary filter shapes, strides, dimension numbers, and padding8
• Relu
• LeakyRelu
• Abs
• ABRelu 9
• Erf
• Identity
• Flatten
• AvgPool
• GlobalAvgPool
• SumPool
• GlobalSumPool
• Dropout
• LayerNorm
• GlobalSelfAttention (Anonymous, 2020)

The following is in our near-term plans:

• Exp, Elu, Selu, Gelu
• Apache Beam support.

The following layers do not have known closed-form expressions for infinite network covariances, and respective infinite networks have to be estimated empirically (via nt.monte_carlo_kernel_fn ) or using other approximations (not currently implemented):

• Sigmoid , Tanh ,10 Swish ,11 Softmax , LogSoftMax , Softplus , MaxPool .

7 Abs , ABRelu , Erf , GlobalAvgPool , GlobalSumPool , LayerNorm , and GlobalSelfAttention are only available in our library nt.stax and not in jax.experimental.stax . 8In addition to SAME and VALID , we support CIRCULAR padding, which is especially convenient for theoretical analysis and was used by Xiao et al.(2018) and Novak et al.(2019). 9 ABRelu is a more flexible extension of LeakyRelu , computed as a min (x, 0) + b max (x, 0). 10 Sigmoid and Tanh are similar to Erf which does have a covariance expression and is implemented. 11 Swish is similar to Gelu .


Tensor Op (X)                        | NNGP Op (K)              | NTK Op (Θ)
Dense(σ_ω, σ_b)                      | σ_ω² K + σ_b²            | (σ_ω² K + σ_b²) + σ_ω² Θ
φ                                    | T(K)                     | Ṫ(K) ⊙ Θ
Dropout(ρ)                           | K + (1/ρ − 1) Diag(K)    | Θ + (1/ρ − 1) Diag(Θ)
Conv(σ_ω, σ_b)                       | σ_ω² A(K) + σ_b²         | σ_ω² A(K) + σ_b² + σ_ω² A(Θ)
Flatten                              | Tr(K)                    | Tr(K + Θ)
AvgPool(s, q, p)                     | AvgPool(s, q, p)(K)      | AvgPool(s, q, p)(K + Θ)
GlobalAvgPool                        | GlobalAvgPool(K)         | GlobalAvgPool(K + Θ)
SumPool(s, q, p)                     | SumPool(s, q, p)(K)      | SumPool(s, q, p)(K + Θ)
GlobalSumPool                        | GlobalSumPool(K)         | GlobalSumPool(K + Θ)
Attn(σ_QK, σ_OV) (Anonymous, 2020)   | Attn(σ_QK, σ_OV)(K)      | 2 Attn(σ_QK, σ_OV)(K) + Attn(σ_QK, σ_OV)(Θ)
FanInSum(X_1, ..., X_n)              | Σ_{j=1}^{n} K_j          | Σ_{j=1}^{n} Θ_j
FanOut(n)                            | [K] ∗ n                  | [Θ] ∗ n

Table 1: Translation rules (§3) converting tensor operations into operations on NNGP and NTK kernels. Here the input tensor X is assumed to have shape |X| × H × W × C (dataset size, height, width, number of channels), and the full NNGP and NT kernels K and Θ are considered to be of shape (|X| × H × W)² (in practice shapes of |X|² × H × W and |X|² are also possible, depending on which optimizations in §3.2 are applicable). Notation details. The Tr, GlobalAvgPool, and GlobalSumPool ops are assumed to act on all spatial axes (with sizes H and W in this example), producing a |X|²-kernel. Similarly, the AvgPool and SumPool ops are assumed to act on all spatial axes as well, applying the specified strides s, pooling window sizes q, and padding strategy p to the respective axes pairs in K and Θ (acting as 4D pooling with replicated parameters of the 2D version). T and Ṫ are defined identically to Lee et al. (2019) as T(Σ) = E[φ(u)φ(u)^T], Ṫ(Σ) = E[φ'(u)φ'(u)^T], u ∼ N(0, Σ). These expressions can be evaluated in closed form for many nonlinearities, and preserve the shape of the kernel. The A op is defined similarly to Novak et al. (2019); Xiao et al. (2018) as [A(Σ)]^{w,w'}_{h,h'}(x, x') = Σ_{dh,dw} [Σ]^{w+dw,w'+dw}_{h+dh,h'+dh}(x, x') / q², where the summation is performed over the convolutional filter receptive field with q pixels (we assume unit strides and circular padding in this expression, but generalization to other settings is trivial and supported by the library). [Σ]∗n = [Σ, ..., Σ] (n-fold replication). For LayerNorm, FanInConcat, and Attn (Anonymous, 2020) translation rules we refer the reader to our code at https://github.com/google/neural-tangents, as these ops are challenging to express concisely using current notation. See Figure 4 for an example of applying the translation rules to a specific model, and §3.1 for deriving a sample translation rule. See §D for the full list of currently implemented translations.


[Figure 7 plots. Top row, panels "Independent" and "Full Test Covariance": test NLL (nats) vs. training set size for FC-NTK, FC-NNGP, CONV-NTK, and CONV-NNGP. Bottom row, panels "NTK" and "NNGP": condition number vs. training set size for FC, Conv, and WResNet kernels and their test covariances.]

Figure 7: Predictive negative log-likelihoods and condition numbers. Top. Test negative log-likelihoods for the NNGP posterior and the Gaussian predictive distribution of the NTK at infinite training time on CIFAR-10 (test set of 2000 points). Fully-connected (FC, Listing 3) and convolutional network without pooling (CONV, Listing 2) models are selected based on train marginal negative log-likelihoods in Figure 3. Bottom. Condition numbers of the covariance matrices corresponding to the NTK/NNGP, as well as of the respective predictive covariance on the test set. Ill-conditioning of Wide Residual Network kernels due to pooling layers (Xiao et al., 2019) could be the cause of numerical issues when evaluating the predictive NLL for these kernels.
