Inference in Probabilistic Graphical Models by Graph Neural Networks
Inference in Probabilistic Graphical Models by Graph Neural Networks

arXiv:1803.07710v5 [cs.LG] 27 Jun 2019

KiJung Yoon 1 2 3, Renjie Liao 4 5 6, Yuwen Xiong 4, Lisa Zhang 4 6, Ethan Fetaya 4 6, Raquel Urtasun 4 5 6, Richard Zemel 4 6 7, Xaq Pitkow 1 2

1 Department of Neuroscience, Baylor College of Medicine; 2 Department of Electrical and Computer Engineering, Rice University; 3 Department of Electronic Engineering, Hanyang University; 4 Department of Computer Science, University of Toronto; 5 Uber ATG Toronto; 6 Vector Institute; 7 Canadian Institute for Advanced Research. Correspondence to: KiJung Yoon <[email protected]>, Xaq Pitkow <[email protected]>.

3rd Tractable Probabilistic Modeling Workshop, 36th International Conference on Machine Learning, Long Beach, California, 2019. Copyright 2019 by the author(s).

Abstract

A fundamental computation for statistical inference and accurate decision-making is to compute the marginal probabilities or most probable states of task-relevant variables. Probabilistic graphical models can efficiently represent the structure of such complex data, but performing these inferences is generally difficult. Message-passing algorithms, such as belief propagation, are a natural way to disseminate evidence amongst correlated variables while exploiting the graph structure, but these algorithms can struggle when the conditional dependency graphs contain loops. Here we use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves these inference tasks. We first show that the architecture of GNNs is well-matched to inference tasks. We then demonstrate the efficacy of this inference approach by training GNNs on a collection of graphical models and showing that they substantially outperform belief propagation on loopy graphs. Our message-passing algorithms generalize beyond the training set to larger graphs and graphs with different structure.

1. Introduction

Probabilistic graphical models provide a statistical framework for modelling conditional dependencies between random variables, and are widely used to represent complex, real-world phenomena. Given a graphical model for a distribution p(x), one major goal is to compute the marginal probability distributions p_i(x_i) of task-relevant variables at each node i of the graph: given a loss function, these distributions determine the optimal estimator. Another major goal is to compute the most probable state, x* = argmax_x p(x), or MAP (maximum a posteriori) inference.

For complex models with loopy graphs, exact inferences of these sorts are often computationally intractable, so one generally relies on approximate methods. One important method for computing approximate marginals is the belief propagation (BP) algorithm, which exchanges statistical information among neighboring nodes (Pearl, 1988; Wainwright et al., 2003b). This algorithm performs exact inference on tree graphs, but not on graphs with cycles. Furthermore, the basic update steps in belief propagation may not have efficient or even closed-form solutions, leading researchers to construct BP variants (Sudderth et al., 2010; Ihler & McAllester, 2009; Noorshams & Wainwright, 2013) or generalizations (Minka, 2001).

In this work, we introduce end-to-end trainable inference systems based on Graph Neural Networks (GNNs) (Gori et al., 2005; Scarselli et al., 2009; Li et al., 2016), which are recurrent networks that allow complex transformations between nodes. We show that this network architecture is well suited to message-passing inference algorithms and has a flexibility that gives it wide applicability even in cases where closed-form algorithms are unavailable. These GNNs have vector-valued nodes that can encode probabilistic information about variables in the graphical model. The GNN nodes send and receive messages about those probabilities, and these messages are determined by canonical learned nonlinear transformations of the information sources and the statistical interactions between them. The dynamics of the GNN reflect the flow of probabilistic information throughout the graphical model, and when the model reaches equilibrium, a nonlinear decoder can extract approximate marginal probabilities or states from each node.
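As a rough illustration of this style of architecture, the NumPy sketch below is our own minimal mock-up, not the authors' model: the hidden dimension, the triangle graph, and the message, update, and readout networks (here just random weight matrices standing in for trained parameters) are all hypothetical placeholders. It propagates vector-valued messages along the edges of a small loopy graph and decodes a per-node marginal estimate once the dynamics have run for a fixed number of steps.

```python
import numpy as np

# Illustrative sketch of GNN-style message passing (not the paper's exact model).
# Node states are vectors; messages are nonlinear functions of the two endpoint
# states and an edge feature; a readout maps final states to marginal estimates.

rng = np.random.default_rng(0)
D = 8                                   # hidden dimension (arbitrary choice)
edges = [(0, 1), (1, 2), (2, 0)]        # a small loopy graph (triangle)
n_nodes = 3

# "Learned" parameters: random here, trained end to end in practice.
W_msg = rng.normal(scale=0.1, size=(2 * D + 1, D))   # message network
W_upd = rng.normal(scale=0.1, size=(2 * D, D))       # node-update network
W_out = rng.normal(scale=0.1, size=(D, 2))           # readout to binary marginals

coupling = {e: rng.normal() for e in edges}          # edge features (e.g. couplings)
h = rng.normal(scale=0.1, size=(n_nodes, D))         # initial node states

def relu(z):
    return np.maximum(z, 0.0)

for _ in range(10):                                  # iterate toward equilibrium
    agg = np.zeros_like(h)
    for (i, j) in edges:
        for s, t in [(i, j), (j, i)]:                # messages flow both ways
            m = relu(np.concatenate([h[s], h[t], [coupling[(i, j)]]]) @ W_msg)
            agg[t] += m                              # sum incoming messages
    h = relu(np.concatenate([h, agg], axis=1) @ W_upd)

logits = h @ W_out                                   # nonlinear decoder / readout
marginals = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(marginals)                                     # one approximate marginal per node
```

In the setting described above, the weights would be trained end to end against target marginals or MAP states rather than drawn at random.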
To demonstrate the value of these GNNs for inference in probabilistic graphical models, we create a collection of graphical models, train our networks to perform marginal or MAP inference, and test how well these inferences generalize beyond the training set of graphs. Our results compare quite favorably to belief propagation on loopy graphs.

2. Related Work

Several researchers have used neural networks to implement some form of probabilistic inference. Heess et al. (2013) propose to train a neural network that learns to map message inputs to message outputs for each message operation needed for Expectation Propagation inference, and Lin et al. (2015) suggest learning CNNs for estimating factor-to-variable messages in a message-passing procedure. Mean field networks (Li & Zemel, 2014) and structure2vec (Dai et al., 2016) model the mean field inference steps as feedforward and recurrent networks, respectively.

Another related line of work is on inference machines: Ross et al. (2011) train a series of logistic regressors with hand-crafted features to estimate messages. Wei et al. (2016) apply this idea to pose estimation using convolutional layers, and Deng et al. (2016) introduce sequential inference with recurrent neural networks for the same application domain.

The line of work most similar to our approach is that of GNN-based models. GNNs are essentially an extension of recurrent neural networks that operate on graph-structured inputs (Scarselli et al., 2009; Li et al., 2016). The central idea is to iteratively update hidden states at each GNN node by aggregating incoming messages that are propagated through the graph. Here, expressive neural networks model both the message- and node-update functions. Gilmer et al. (2017) recently provide a good review of several GNN variants and unify them into a model called message-passing neural networks. Bruna & Li (2017) also propose spectral approximations of BP with GNNs to solve the community detection problem. GNNs indeed have a structure similar to the message-passing algorithms used in probabilistic inference. For this reason, GNNs are powerful architectures for capturing statistical dependencies between variables of interest (Bruna et al., 2014; Duvenaud et al., 2015; Li et al., 2016; Marino et al., 2016; Li et al., 2017; Qi et al., 2017; Kipf & Welling, 2017).

3. Background

3.1. Probabilistic graphical models

Probabilistic graphical models simplify a joint probability distribution p(x) over many variables x by factorizing the distribution according to conditional independence relationships. Factor graphs are one convenient, general representation of structured probability distributions. These are undirected, bipartite graphs whose edges connect variable nodes i ∈ V, which encode individual variables x_i, to factor nodes α ∈ F, which encode direct statistical interactions ψ_α(x_α) between groups of variables x_α. (Some of these factors may affect only one variable.) The probability distribution is the normalized product of all factors:

    p(x) = \frac{1}{Z} \prod_{\alpha \in F} \psi_\alpha(x_\alpha)                                    (1)

Here Z is a normalization constant, and x_α is a vector with components x_i for all variable nodes i connected to the factor node α by an edge (i, α).

Our goal is to compute marginal probabilities p_i(x_i), or MAP states x_i*, for such graphical models. For general graphs, these computations require exponentially large resources, summing (integrating) or maximizing over all possible states except the target node: p_i(x_i) = Σ_{x∖x_i} p(x) or x* = argmax_x p(x).
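To make the factor-graph notation and the cost of exact inference concrete, the following self-contained example (our own illustration; the factor values are arbitrary) builds a small binary factor graph with a loop, evaluates the normalized product of factors in Eq. (1), and computes the exact marginals and the MAP state by enumerating all 2^n joint states.

```python
import itertools
import numpy as np

# Toy binary factor graph: variables x0, x1, x2 in {0, 1}, pairwise factors on a
# loop (0-1, 1-2, 2-0) plus a unary factor on x0. Purely illustrative numbers.
factors = {
    (0, 1): lambda x: np.exp(0.8 * (2 * x[0] - 1) * (2 * x[1] - 1)),
    (1, 2): lambda x: np.exp(-0.5 * (2 * x[1] - 1) * (2 * x[2] - 1)),
    (2, 0): lambda x: np.exp(0.3 * (2 * x[2] - 1) * (2 * x[0] - 1)),
    (0,):   lambda x: np.exp(0.4 * (2 * x[0] - 1)),
}

n = 3
states = list(itertools.product([0, 1], repeat=n))   # all 2^n joint states

# Unnormalized probability of a joint state: product of all factors, as in Eq. (1).
def unnorm(x):
    return np.prod([psi(x) for psi in factors.values()])

Z = sum(unnorm(x) for x in states)                   # normalization constant

# Exact marginals p_i(x_i): sum the joint distribution over all other variables.
marginals = np.zeros((n, 2))
for x in states:
    p = unnorm(x) / Z
    for i in range(n):
        marginals[i, x[i]] += p

x_map = max(states, key=unnorm)                      # MAP state: argmax of p(x)

print("Z =", Z)
print("exact marginals:\n", marginals)
print("MAP state:", x_map)
```

This enumeration is exactly the computation that becomes infeasible as the number of variables grows, which motivates message-passing approximations such as BP.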
Belief propagation operates on these factor graphs by constructing messages μ_{i→α} and μ_{α→i} that are passed between variable and factor nodes:

    \mu_{\alpha \to i}(x_i) = \sum_{x_\alpha \setminus x_i} \psi_\alpha(x_\alpha) \prod_{j \in N_\alpha \setminus i} \mu_{j \to \alpha}(x_j)        (2)

    \mu_{i \to \alpha}(x_i) = \prod_{\beta \in N_i \setminus \alpha} \mu_{\beta \to i}(x_i)        (3)

where N_i are the neighbors of i, i.e., the factors that involve x_i, and N_α are the neighbors of α, i.e., the variables that are directly coupled by ψ_α(x_α). The recursive, graph-based structure of these message equations leads naturally to the idea that we could describe these messages and their nonlinear updates using a graph neural network in which GNN nodes correspond to messages, as described in the next section.

Interestingly, belief propagation can also be reformulated entirely without messages: BP operations are equivalent to successively reparameterizing the factors over subgraphs of the original graphical model (Wainwright et al., 2003b). This suggests that we could construct a different mapping between GNNs and graphical models, in which GNN nodes correspond to factor nodes rather than messages. The reparameterization accomplished by BP adjusts only the univariate potentials, since the BP updates leave the multivariate coupling potentials unchanged: after the inference algorithm converges, the estimated marginal joint probability of a factor α, namely B_α(x_α), is given by

    B_\alpha(x_\alpha) = \frac{1}{Z} \psi_\alpha(x_\alpha) \prod_{i \in N_\alpha} \mu_{i \to \alpha}(x_i)        (4)

Observe that all of the messages depend on only one variable at a time; the only term that depends on more than one variable at a time is the interaction factor ψ_α(x_α), which is therefore invariant over time. Since BP does not change these interactions, to imitate the action of BP the GNNs need only represent single variable nodes explicitly, while the nonlinear functions between nodes can account for (and must depend on) their interactions. Our experiments evaluate both of these architectures, with GNNs constructed with latent states that represent either message nodes or single variable nodes.
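For reference, here is a minimal synchronous sum-product BP sketch (again our own illustrative code, not the authors' implementation) that applies the message updates of Eqs. (2) and (3) to the same toy factor graph used in the earlier brute-force listing and reads out approximate node marginals as the normalized product of incoming factor-to-variable messages. Because the graph is loopy, these beliefs are only approximations of the exact marginals computed there.

```python
import numpy as np

# Synchronous sum-product belief propagation on a small binary factor graph,
# following Eqs. (2) and (3). Illustrative sketch with arbitrary couplings.

def pairwise(J):
    s = np.array([-1.0, 1.0])
    return np.exp(J * np.outer(s, s))            # factor table of shape (2, 2)

factors = {
    (0, 1): pairwise(0.8),
    (1, 2): pairwise(-0.5),
    (2, 0): pairwise(0.3),
    (0,):   np.exp(0.4 * np.array([-1.0, 1.0])), # unary factor on x0
}
n_vars = 3
var_to_factors = {i: [a for a in factors if i in a] for i in range(n_vars)}

# Initialize all messages to uniform.
msg_vf = {(i, a): np.ones(2) for a in factors for i in a}   # variable -> factor
msg_fv = {(a, i): np.ones(2) for a in factors for i in a}   # factor -> variable

for _ in range(50):
    # Eq. (2): factor-to-variable messages.
    new_fv = {}
    for a, table in factors.items():
        for i in a:
            t = table.copy()
            for j in a:
                if j != i:                       # multiply in incoming messages
                    shape = [1] * len(a)
                    shape[a.index(j)] = 2
                    t = t * msg_vf[(j, a)].reshape(shape)
            axes = tuple(k for k in range(len(a)) if k != a.index(i))
            m = t.sum(axis=axes) if axes else t  # sum over all variables but x_i
            new_fv[(a, i)] = m / m.sum()         # normalize for numerical stability
    msg_fv = new_fv

    # Eq. (3): variable-to-factor messages.
    for a in factors:
        for i in a:
            m = np.ones(2)
            for b in var_to_factors[i]:
                if b != a:
                    m = m * msg_fv[(b, i)]
            msg_vf[(i, a)] = m / m.sum()

# Node beliefs: normalized product of all incoming factor-to-variable messages.
for i in range(n_vars):
    b = np.ones(2)
    for a in var_to_factors[i]:
        b = b * msg_fv[(a, i)]
    print(f"BP belief for x{i}:", b / b.sum())
```

Normalizing the messages after each update is purely for numerical stability and does not change the fixed points of the iteration.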
3.2.

... form of these functions is canonical, i.e., shared by all graph edges, but the function can also depend on properties of each edge. The function is parameterized by a neural network whose weights are shared across all edges. Eventually, the states of the nodes are interpreted by another trainable 'read-out' ...