Deep Rao-Blackwellised Particle Filters for Time Series Forecasting
Deep Rao-Blackwellised Particle Filters for Time Series Forecasting

Richard Kurle*†2, Syama Sundar Rangapuram1, Emmanuel de Bezenac†3, Stephan Günnemann2, Jan Gasthaus1
1 AWS AI Labs   2 Technical University of Munich   3 Sorbonne Université

Abstract

This work addresses efficient inference and learning in switching Gaussian linear dynamical systems using a Rao-Blackwellised particle filter and a corresponding Monte Carlo objective. To improve the forecasting capabilities, we extend this classical model by conditionally linear state-to-switch dynamics, while leaving the partial tractability of the conditional Gaussian linear part intact. Furthermore, we use an auxiliary variable approach with a decoder-type neural network that allows for more complex non-linear emission models and multivariate observations. We propose a Monte Carlo objective that leverages the conditional linearity by computing the corresponding conditional expectations in closed form, and a suitable proposal distribution that is factorised similarly to the optimal proposal distribution. We evaluate our approach on several popular time series forecasting datasets as well as image streams of simulated physical systems. Our results show improved forecasting performance compared to other deep state-space model approaches.

1 Introduction

The Gaussian linear dynamical system (GLS) [4, 38, 31] is one of the most well-studied dynamical models, with wide-ranging applications in many domains, including control, navigation, and time-series forecasting. This state-space model (SSM) is described by a (typically Markovian) latent linear process that generates a sequence of observations. The assumption of Gaussian noise and linear state transitions and measurements allows for exact inference of the latent variables (such as filtering and smoothing) and computation of the marginal likelihood for system identification (learning). However, most systems of practical interest are non-linear, requiring more complex models.
Many approximate inference methods have been developed for non-linear dynamical systems. Deterministic methods approximate the filtering and smoothing distributions, e.g. by a Taylor series expansion of the non-linear functions (known as the extended Kalman filter (EKF) and second-order EKF), or by directly approximating these marginal distributions with a Gaussian using moment matching techniques [26, 27, 2, 3, 36]. Stochastic methods such as particle filters or smoothers approximate the filtering and smoothing distributions by a set of weighted samples (particles) using sequential Monte Carlo (SMC) [13, 16]. Furthermore, system identification with deep neural networks has been proposed using stochastic variational inference [17, 15] and variational SMC [23, 29, 20]. A common challenge with both types of approximations is that predictions/forecasts over long forecast horizons suffer from accumulated errors resulting from insufficient approximations at every time step.

Switching Gaussian linear systems (SGLS) [1, 12], which use additional latent variables to switch between different linear dynamics, provide a way to alleviate this problem: the conditional linearity can be exploited by approximate inference algorithms to reduce approximation errors. Unfortunately, this comes at the cost of reduced modelling flexibility compared to more general non-linear dynamical systems. Specifically, we identify the following two weaknesses of the SGLS: i) the switch transition model is assumed independent of the GLS state and observation; ii) conditionally linear emissions are insufficient for modelling complex multivariate data streams such as video data, while more suitable emission models (e.g. using convolutional neural networks) exist.

* Correspondence to [email protected]. † Work done while at AWS AI Labs.
34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.
The first problem has been addressed in [21] through augmentation with a Pólya-gamma-distributed variable and a stick-breaking process. However, this approach uses Gibbs sampling to infer model parameters and thus does not scale well to large datasets. The second problem is addressed in [11] using an auxiliary variable between GLS states and emissions and stochastic variational inference to obtain a tractable objective function for learning. Yet, this model predicts the GLS parameters deterministically using an LSTM with an auto-regressive component, resulting in poor long-term forecasts.

We propose auxiliary variable recurrent switching Gaussian linear systems (ARSGLS), an extension of the SGLS that addresses both weaknesses by building on ideas from [21] and [11]. ARSGLS improves upon the SGLS by incorporating a conditionally linear state-to-switch dependence, which is crucial for accurate long-term out-of-sample predictions (forecasting), and a decoder-type neural network that allows modelling multivariate time-series data with a non-linear emission/measurement process. As in the SGLS, a crucial feature of the ARSGLS is that approximate inference and likelihood estimation can be Rao-Blackwellised, that is, expectations involving the conditionally linear part of the model can be computed in closed form, and only the expectations wrt. the switch variables need to be approximated. We leverage this feature and propose a Rao-Blackwellised filtering objective function and a suitable proposal distribution for this model class. We evaluate our model with two different instantiations of the GLS: in the first scenario, the GLS is implemented with a constrained structure that models time-series patterns such as level, trend, and seasonality [14, 34]; the second scenario considers a general, unconstrained conditional GLS.
We compare our approach to closely related methods such as Deep State Space Models (DeepState) [30] and Kalman Variational Auto-Encoders (KVAE) [11] on 5 popular forecasting datasets and multivariate (image) data streams generated from a physics engine as used in [11]. Our results show that the proposed model achieves improved performance for univariate and multivariate forecasting.

2 Background

2.1 Gaussian Linear Dynamical Systems

The GLS models sequence data using a first-order Markov linear latent and emission process:

x_t = A x_{t-1} + B u_t + w_t,   w_t ∼ N(0, R),   (1a)
y_t = C x_t + D u_t + v_t,   v_t ∼ N(0, Q),   (1b)

where the vectors u, x and y denote the inputs (also referred to as covariates or controls), latent states, and targets (observations). A and C are the transition and emission matrices, B and D are optional control matrices, and R and Q are the state and emission noise covariances. In the following, we will omit the optional inputs u to ease the presentation; however, we use inputs e.g. for our experiments in Sec. 5.2. The appealing property of the GLS is that inference and likelihood computation are analytically tractable using the well-known Kalman filter algorithm, which alternates between a prediction step p(x_t | y_{1:t-1}) = ∫ p(x_t | x_{t-1}) p(x_{t-1} | y_{1:t-1}) dx_{t-1} and an update step p(x_t | y_{1:t}) ∝ p(x_t | y_{1:t-1}) p(y_t | x_t), where p(x_t | x_{t-1}) is the state transition and p(y_t | x_t) is the emission/measurement process corresponding to Eqs. (1a) and (1b), respectively.

2.2 Particle Filters

In non-linear dynamical systems, the filter distribution p(x_t | y_{1:t}) is intractable and needs to be approximated. Particle filters are SMC algorithms that approximate the filter distribution at every time step t by a set of P weighted particles {x^(p)}_{p=1}^P, combining importance sampling and re-sampling.
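The Kalman prediction and update steps of Sec. 2.1 can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation; variable names are our own and control inputs (B, D, u) are omitted, as in the presentation above:

```python
import numpy as np

def kalman_step(m, P, y, A, C, R, Q):
    """One Kalman filter iteration for the GLS of Eqs. (1a)-(1b).

    m, P : mean and covariance of p(x_{t-1} | y_{1:t-1})
    A, C : transition and emission matrices
    R, Q : state and emission noise covariances
    Returns mean and covariance of p(x_t | y_{1:t}).
    """
    # Prediction step: p(x_t | y_{1:t-1}) = N(m_pred, P_pred)
    m_pred = A @ m
    P_pred = A @ P @ A.T + R
    # Update step: condition on y_t via the Kalman gain
    S = C @ P_pred @ C.T + Q             # innovation covariance
    K = P_pred @ C.T @ np.linalg.inv(S)  # Kalman gain
    m_new = m_pred + K @ (y - C @ m_pred)
    P_new = P_pred - K @ S @ K.T
    return m_new, P_new
```

Iterating this function over y_1, ..., y_T yields the filtering distributions; the innovation covariance S also gives the per-step marginal likelihood terms used for system identification.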
Denoting the Dirac measure centred at x_t by δ(x_t), the filter distribution is approximated as

p(x_t | y_{1:t}) ≈ Σ_{p=1}^P w_t^(p) δ(x_t^(p)).   (2)

In first-order Markovian dynamical systems, the importance weights are computed recursively as

w̃_t^(p) = w_{t-1}^(p) γ(x_t^(p), x_{t-1}^(p)),   w_t^(p) = w̃_t^(p) / Σ_{p'=1}^P w̃_t^(p'),   γ(x_t^(p), x_{t-1}^(p)) = p(y_t | x_t^(p)) p(x_t^(p) | x_{t-1}^(p)) / π(x_t^(p) | x_{1:t-1}^(p)),   (3)

where w̃_t^(p) and w_t^(p) denote the unnormalised and normalised importance weights, respectively, γ(x_t^(p), x_{t-1}^(p)) is the incremental importance weight, and π(x_t | x_{1:t-1}^(p)) is the proposal distribution. To alleviate weight degeneracy issues, a re-sampling step is performed if the importance weights satisfy a chosen degeneracy criterion. A common criterion is to re-sample when the effective sample size (ESS) drops below half the number of particles, i.e. P_ESS = (Σ_{p=1}^P (w^(p))^2)^{-1} ≤ P/2.

2.3 Switching Gaussian Linear Systems with Rao-Blackwellised Particle Filters

Switching Gaussian linear systems (SGLS), also referred to as conditional GLS or mixture GLS, are a class of non-linear SSMs that use additional (non-linear) latent variables s_{1:t} that index the parameters (transition, emission, control, and noise covariance matrices) of a GLS, allowing them to "switch" between different (linear) regimes:

x_t = A(s_t) x_{t-1} + B(s_t) u_t + w_t(s_t),   w_t(s_t) ∼ N(0, R(s_t)),   (4)
y_t = C(s_t) x_t + D(s_t) u_t + v_t(s_t),   v_t(s_t) ∼ N(0, Q(s_t)).

The switch variables s_t are typically categorical variables, indexing one of K base matrices; however, other choices are possible (see Sec. 3.1). Omitting the inputs u again, the graphical model factorises as

p(y_{1:T}, x_{0:T}, s_{1:T}) = p(x_0) ∏_{t=1}^T p(y_t | x_t, s_t) p(x_t | x_{t-1}, s_t) p(s_t | s_{t-1}),

where p(s_1 | s_0) = p(s_1) is the switch prior (conditioned on u_1 if inputs are given).
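The weight recursion of Eq. (3) and the ESS-based re-sampling criterion can be sketched as follows. This is a generic bootstrap-proposal illustration (π = p(x_t | x_{t-1}), so the incremental weight reduces to the emission likelihood), not the Rao-Blackwellised filter developed later; the function names `f` and `g_loglik` are placeholders for a user-supplied transition sampler and emission log-likelihood:

```python
import numpy as np

rng = np.random.default_rng(0)

def bootstrap_pf_step(particles, log_w, y, f, g_loglik, threshold=0.5):
    """One particle filter step with the bootstrap proposal.

    particles : array of shape (P, ...) holding x_{t-1}^(p)
    log_w     : unnormalised log importance weights
    f         : stochastic transition sampler, x_t ~ p(x_t | x_{t-1})
    g_loglik  : emission log-likelihood, log p(y_t | x_t)
    """
    P = len(particles)
    # Normalised weights and effective sample size, P_ESS = 1 / sum_p (w^(p))^2
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    ess = 1.0 / np.sum(w ** 2)
    # Re-sample when the ESS drops below half the number of particles
    if ess < threshold * P:
        idx = rng.choice(P, size=P, p=w)
        particles, log_w = particles[idx], np.zeros(P)  # uniform after re-sampling
    # Propagate through the proposal (= transition) and re-weight by the emission,
    # the bootstrap special case of the incremental weight in Eq. (3)
    particles = f(particles)
    log_w = log_w + g_loglik(y, particles)
    return particles, log_w
```

Working in log-space and subtracting the maximum before exponentiating avoids the numerical underflow that raw products of likelihoods quickly cause.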
Note that we use an initial state x_0 without a corresponding observation for convenience. A crucial property of this model is that, while the complete model exhibits non-linear dynamics, given a sample of the trajectory s_{1:T}^(p), the rest of the model is (conditionally) linear. This allows for efficient approximate inference and likelihood estimation. The so-called Rao-Blackwellised particle filter (RBPF) leverages the conditional linearity by approximating expectations, which occur in filtering and likelihood computation, wrt.