Message Passing Stein Variational Gradient Descent

Jingwei Zhuo*, Chang Liu, Jiaxin Shi, Jun Zhu, Ning Chen and Bo Zhang
Department of Computer Science and Technology, Tsinghua University.
* [email protected]

Motivations & Preliminaries

Variational inference: to approximate an intractable distribution $p(x)$ with $q(x)$ in some tractable family $\mathcal{Q}$ by

$$q(x) = \operatorname*{argmin}_{q \in \mathcal{Q}} \mathrm{KL}(q(x) \,\|\, p(x)).$$

Stein Variational Gradient Descent (SVGD): Using a set of particles $\{x^{(i)}\}_{i=1}^{M}$ (with the empirical distribution $\hat{q}_M(x) = \frac{1}{M}\sum_{i=1}^{M} \delta(x - x^{(i)})$) as approximation for $p(x)$, updated iteratively by

$$x^{(i)} \leftarrow x^{(i)} + \epsilon\, \hat{\phi}(x^{(i)}),$$

where

$$\hat{\phi}(x) = \mathbb{E}_{y \sim \hat{q}_M}\left[ k(x, y) \nabla_y \log p(y) + \nabla_y k(x, y) \right].$$

$k(x, y)$ is a positive definite kernel, e.g., the RBF kernel $k(x, y) = \exp(-\|x - y\|_2^2 / (2h))$.

– Remark: for $M = 1$, $\hat{\phi}(x) = \nabla_x \log p(x)$, i.e., MAP.

$\hat{\phi}$ is an unbiased estimate of $\phi$, the steepest direction to reduce $\mathrm{KL}(q \| p)$ in a reproducing kernel Hilbert space (RKHS) $\mathcal{H}^D$,

$$\phi(x) = \operatorname*{argmin}_{\|\phi\|_{\mathcal{H}^D} \le 1} \nabla_\epsilon \mathrm{KL}(q_{[T]} \| p)\big|_{\epsilon = 0},$$

where $q_{[T]}$ is the density of $T(x) = x + \epsilon \phi(x)$ when the density of $x$ is $q$.

– Convergence Condition: $\phi(x) \equiv 0$, which holds if and only if $q = p$ with a proper choice of $k(x, y)$.

$\phi$ can be decomposed into two parts:
– Kernel Smoothed Gradient (KSG): $G(x; p, q) = \mathbb{E}_{y \sim q}[k(x, y) \nabla_y \log p(y)]$.
– Repulsive Force (RF): $R(x; q) = \mathbb{E}_{y \sim q}[\nabla_y k(x, y)]$.

Particle Degeneracy of SVGD

We observe particle degeneracy of SVGD, even for inferring $p(x) = \mathcal{N}(x | 0, I)$ with $M = 50, 100, 200$ particles.

Figure 1: Top figures: Estimated variance and mean; Bottom figures: Magnitude of RF and KSG, at both the beginning (dotted; B) and the end of iterations (solid; E). (Axes: dimensionality D against dim-averaged marginal variance, dim-averaged mean, PAMRF and PAKSG; curves for 50, 100 and 200 particles.)

Explanations:

1. KSG alone corresponds to mode-seeking, i.e.,

$$\frac{G(x; p, q)}{\|G(x; p, q)\|_{\mathcal{H}^D}} = \operatorname*{argmax}_{\|\phi\|_{\mathcal{H}^D} \le 1} \nabla_\epsilon \mathbb{E}_{z \sim q_{[T]}}[\log p(z)]\big|_{\epsilon = 0}$$

is the steepest direction for maximizing $\mathbb{E}_{x \sim q}[\log p(x)]$ (instead of minimizing $\mathrm{KL}(q \| p)$), which leads $q(x)$ to collapse to the modes of $p(x)$ at convergence.

2. RF is critical for SVGD to minimize $\mathrm{KL}(q \| p)$, but its magnitude $\|R(x; q)\|_\infty$ may correlate negatively with dimensionality $D$. E.g., for the RBF kernel with any $h$, $\|R(x; q)\|_\infty \le \ldots$ (the bound, which involves $\|x - y\|_\infty$, is truncated in the source).

Theoretical Analysis

So, for which $q$ does the RF $R(x; q)$ suffer from such a negative correlation with the dimensionality $D$?

Theorem 1 (Gaussian) Given the RBF kernel $k(x, y)$ and $q(y) = \mathcal{N}(y | \mu, \Sigma)$, the repulsive force satisfies

$$\|R(x; q)\|_\infty \le \frac{\sqrt{D}}{\lambda_{\min}(\Sigma)\left(\frac{D}{2} + 1\right)\left(1 + \frac{2}{D}\right)^{D/2}} \|x - \mu\|_\infty,$$

where $\lambda_{\min}(\Sigma)$ is the smallest eigenvalue of $\Sigma$. By using $\lim_{x \to 0}(1 + x)^{1/x} = e$, we have $\|R(x; q)\|_\infty \lesssim \|x - \mu\|_\infty / (\lambda_{\min}(\Sigma)\sqrt{D})$.

Theorem 2 (Bounded) Let $k(x, y)$ be an RBF kernel. Suppose $q(y)$ is supported on a bounded set $\mathcal{X}$ which satisfies $\|y\|_\infty \le C$ for $y \in \mathcal{X}$, and $\mathrm{Var}(y_d | y_1, \ldots, y_{d-1}) \ge C_0$ almost surely for any $1 \le d \le D$. Let $\{x^{(i)}\}_{i=1}^{M}$ be a set of samples of $q$ and $\hat{q}_M$ the corresponding empirical distribution. Then, for any $\|x\|_\infty \le C$ and $\alpha, \delta \in (0, 1)$, there exists $D_0 > 0$ such that for any $D > D_0$,

$$\|R(x; \hat{q}_M)\|_\infty \le \frac{2}{e D^{\alpha}}$$

holds with probability at least $1 - \delta$.

Message Passing SVGD

Key Idea: We decompose $\mathrm{KL}(q \| p)$ based on the structural information of $p(x) \propto \prod_{F \in \mathcal{F}} \psi_F(x_F)$, i.e.,

$$\mathrm{KL}(q \| p) = \mathrm{KL}(q(x_{\neg d}) \| p(x_{\neg d})) + \mathbb{E}_{q(x_{\neg d})}\left[ \mathrm{KL}(q(x_d | x_{\neg d}) \| p(x_d | x_{\Gamma_d})) \right],$$

where $\Gamma_d = \bigcup_{F \ni d} F \setminus \{d\}$ is the Markov blanket (MB) of $d$ such that $p(x_d | x_{\neg d}) = p(x_d | x_{\Gamma_d})$.

MP-SVGD: To apply SVGD with local kernel $k_d(x_{S_d}, y_{S_d})$ ($S_d = \{d\} \cup \Gamma_d$) to optimize $q(x_d | x_{\neg d})$ while keeping $q(x_{\neg d})$ fixed, which results in the following.

Theorem 3 Let $z = T(x) = [x_1, \ldots, T_d(x_d), \ldots, x_D]^\top$ with $T_d : x_d \to x_d + \epsilon \phi_d(x_{S_d})$, $S_d = \{d\} \cup \Gamma_d$, where $\phi_d \in \mathcal{H}_d$ is associated with the local kernel $k_d : \mathcal{X}_{S_d} \times \mathcal{X}_{S_d} \to \mathbb{R}$. Then, we have

$$\nabla_\epsilon \mathrm{KL}(q_{[T]} \| p) = \nabla_\epsilon \mathbb{E}_{q(z_{\Gamma_d})}\left[ \mathrm{KL}\left( q_{[T_d]}(z_d | z_{\Gamma_d}) \,\|\, p(z_d | z_{\Gamma_d}) \right) \right],$$

and

$$\phi_d(x_{S_d}) = \operatorname*{argmin}_{\|\phi_d\|_{\mathcal{H}_d} \le 1} \nabla_\epsilon \mathrm{KL}(q_{[T]} \| p)\big|_{\epsilon = 0} = \mathbb{E}_{y_{S_d} \sim q}\left[ k_d(x_{S_d}, y_{S_d}) \nabla_{y_d} \log p(y_d | y_{\Gamma_d}) + \nabla_{y_d} k_d(x_{S_d}, y_{S_d}) \right].$$

– Convergence Condition: $\phi_d(x_{S_d}) \equiv 0, \forall d$, which holds if and only if $q(x_d | x_{\Gamma_d}) = p(x_d | x_{\Gamma_d})$ with a proper choice of $k_d(x_{S_d}, y_{S_d})$.

Final Algorithms: To approximate $p(x) \propto \prod_{F \in \mathcal{F}} \psi_F(x_F)$ with a set of particles $\{x^{(i)}\}_{i=1}^{M}$ (with the empirical distribution $\hat{q}_M(x) = \frac{1}{M}\sum_{i=1}^{M} \delta(x - x^{(i)})$), updated iteratively by

$$x_d^{(i)} \leftarrow x_d^{(i)} + \epsilon\, \hat{\phi}_d(x_{S_d}^{(i)}),$$

where

$$\hat{\phi}_d(x_{S_d}) = \mathbb{E}_{y_{S_d} \sim \hat{q}_M}\left[ k_d(x_{S_d}, y_{S_d}) \nabla_{y_d} \log p(y_d | y_{\Gamma_d}) + \nabla_{y_d} k_d(x_{S_d}, y_{S_d}) \right].$$

Experiments: Synthetic Results

Figure 2: A qualitative comparison of inference methods with 100 particles (except EP) for estimating marginal densities $p(x_d)$ of three randomly selected nodes. (Methods shown: HMC, SVGD, MP-SVGD-s, MP-SVGD-m, EPBP, EP and ground truth.)

Figure 3: A quantitative comparison of inference methods with varying number of particles. Performance is measured by the MSE of the estimation of the expectation $\mathbb{E}_{x \sim \hat{q}_M}[f(x)]$ for test functions $f(x) = x$, $x^2$, $1/(1 + \exp(\omega \circ x + b))$ and $\cos(\omega \circ x + b)$, arranged from left to right, where $\circ$ denotes the element-wise product and $\omega, b \in \mathbb{R}^{100}$ with $\omega_d \sim \mathcal{N}(0, 1)$ and $b_d \sim \mathrm{Uniform}[0, 2\pi]$, $\forall d \in \{1, \ldots, 100\}$. (Axes: particle size M against $\log_{10}$ MSE; methods shown: HMC, SVGD, MP-SVGD-s, MP-SVGD-m and EPBP.)

Figure 4: PAMRF $\frac{1}{M}\sum_{i=1}^{M} \|R(x^{(i)}; \hat{q}_M)\|_r$ for converged $\{x^{(i)}\}_{i=1}^{M}$ with $r = \infty$ (left) and $r = 2$ (right). The grid size ranges from $2 \times 2$ to $10 \times 10$ and the number of particles is denoted in the bracket. (Axes: dimensionality D against PAMRF; curves for SVGD, MP-SVGD-s and MP-SVGD-m with 100, 200 and 300 particles.)

Experiments: Image Denoising

Targets: $p(x | y) \propto p(x) p(y | x)$, where
– $p(y | x) = \mathcal{N}(y | x, \sigma_n^2 I)$: Noise distribution.
– $p(x) \propto \exp\left(-\frac{\epsilon \|x\|_2^2}{2}\right) \prod_{C \in \mathcal{C}} \prod_{i=1}^{N} \phi(J_i^\top x_C; \alpha_i)$: Fields of Experts (FOE) model, where $\phi(J_i^\top x_C; \alpha_i) = \sum_{j=1}^{J} \alpha_{ij}\, \mathcal{N}(J_i^\top x_C | 0, \sigma_i^2 / s_j)$.

Methods: We compare SVGD, MP-SVGD and Gibbs sampling with auxiliary variables (the original inference method).

Evaluation: peak signal-to-noise ratio (PSNR) and structural similarity index (SSIM).

Figure 5: Denoising results for Lena using 50 particles, $256 \times 256$ pixels, $\sigma_n = 10$. The number in bracket is PSNR and SSIM. Upper Row: The full size image; Bottom Row: The $50 \times 50$ patches. Panels: Clean; Noisy (28.16, 0.678); MAP (31.05, 0.867); SVGD (31.89, 0.890); MP-SVGD (33.20, 0.911); Aux. Gibbs (32.95, 0.901).

| Inference | avg. PSNR, $\sigma_n = 10$ | avg. PSNR, $\sigma_n = 20$ | avg. SSIM, $\sigma_n = 10$ | avg. SSIM, $\sigma_n = 20$ |
|---|---|---|---|---|
| MAP | 30.27 | 26.48 | 0.855 | 0.720 |
| Aux. Gibbs | 32.09 | 28.32 | 0.904 | 0.808 |
| Aux. Gibbs (M = 50) | 31.87 | 28.05 | 0.898 | 0.795 |
| Aux. Gibbs (M = 100) | 31.98 | 28.17 | 0.901 | 0.801 |
| SVGD (M = 50) | 31.58 | 27.86 | 0.894 | 0.766 |
| SVGD (M = 100) | 31.65 | 27.90 | 0.896 | 0.767 |
| MP-SVGD (M = 50) | 32.09 | 28.21 | 0.905 | 0.808 |
| MP-SVGD (M = 100) | 32.12 | 28.27 | 0.906 | 0.809 |
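The SVGD update from the preliminaries, together with its split into KSG and RF, can be sketched in NumPy as follows. This is a minimal illustration rather than the authors' code: the function names `svgd_direction` and `run_svgd`, the fixed bandwidth `h`, and the default step size `eps` are assumptions made here for the example.

```python
import numpy as np

def svgd_direction(X, grad_log_p, h):
    """Return (phi_hat, ksg, rf) evaluated at every particle.

    X has shape (M, D); grad_log_p maps (M, D) -> (M, D).
    The kernel is the RBF kernel k(x, y) = exp(-||x - y||^2 / (2h)).
    """
    M = X.shape[0]
    diff = X[:, None, :] - X[None, :, :]                  # diff[i, j] = x_i - x_j
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2.0 * h))   # K[i, j] = k(x_i, x_j)
    # Kernel smoothed gradient: (1/M) sum_j k(x_i, x_j) grad log p(x_j)
    ksg = K @ grad_log_p(X) / M
    # Repulsive force: (1/M) sum_j grad_y k(x_i, y)|_{y = x_j},
    # where grad_y k(x, y) = (x - y) / h * k(x, y) for the RBF kernel
    rf = np.sum(K[:, :, None] * diff, axis=1) / (h * M)
    return ksg + rf, ksg, rf

def run_svgd(X, grad_log_p, h=1.0, eps=0.1, n_iters=100):
    """Iterate x_i <- x_i + eps * phi_hat(x_i) (hypothetical defaults)."""
    X = X.copy()
    for _ in range(n_iters):
        phi, _, _ = svgd_direction(X, grad_log_p, h)
        X = X + eps * phi
    return X
```

For the standard Gaussian target used in the degeneracy study, `grad_log_p` would be `lambda X: -X`. With a single particle the RF term vanishes and the direction reduces to the gradient of log p, matching the M = 1 remark.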
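To see how the Theorem 1 bound scales, the coefficient in front of the norm of x - mu can be evaluated directly and compared with the rate 2 / (e sqrt(D) lambda_min) that follows from the limit (1 + 2/D)^(D/2) -> e and D/2 + 1 ~ D/2. The helper names below are hypothetical, and the snippet only evaluates the formula as reconstructed in Theorem 1.

```python
import math

def gaussian_rf_factor(D, lam_min=1.0):
    """Coefficient of ||x - mu||_inf in the Theorem 1 bound."""
    return math.sqrt(D) / (lam_min * (D / 2 + 1) * (1 + 2 / D) ** (D / 2))

def asymptotic_rf_factor(D, lam_min=1.0):
    """Large-D rate 2 / (e * sqrt(D) * lam_min) implied by the limit argument."""
    return 2.0 / (math.e * math.sqrt(D) * lam_min)

if __name__ == "__main__":
    for D in (1, 10, 100, 1000):
        print(D, gaussian_rf_factor(D), asymptotic_rf_factor(D))
```

The coefficient shrinks as D grows and approaches the 1 / sqrt(D) rate from below, which is the negative correlation with dimensionality that the theorem describes.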
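The coordinate-wise MP-SVGD update can be sketched for a zero-mean Gaussian MRF with precision matrix `Lam`, where the conditional score of coordinate d is minus the d-th entry of `Lam @ y`. Several choices here are assumptions for illustration and are not specified on the poster: the function name `mp_svgd_sweep`, one shared bandwidth `h` for all local kernels, RBF local kernels, and a synchronous sweep in which every coordinate is updated from the same old particle positions.

```python
import numpy as np

def mp_svgd_sweep(X, Lam, blankets, eps, h):
    """One synchronous MP-SVGD sweep for p(x) proportional to exp(-x^T Lam x / 2).

    X: (M, D) particles; Lam: symmetric (D, D) precision matrix;
    blankets[d]: indices of the Markov blanket Gamma_d of coordinate d.
    """
    M, D = X.shape
    # G[j, d] = d/dy_d log p(y_d | y_Gamma_d) at particle j, i.e. -(Lam y)_d
    G = -X @ Lam
    X_new = X.copy()
    for d in range(D):
        S = [d] + list(blankets[d])                 # S_d = {d} union Gamma_d
        Z = X[:, S]
        sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
        K = np.exp(-sq / (2.0 * h))                 # local kernel k_d on x_{S_d}
        diff_d = X[:, None, d] - X[None, :, d]      # x_d^(i) - x_d^(j)
        # Local KSG plus local RF; grad_{y_d} k_d = (x_d - y_d) / h * k_d
        phi_d = (K @ G[:, d] + np.sum(K * diff_d, axis=1) / h) / M
        X_new[:, d] = X[:, d] + eps * phi_d
    return X_new
```

For a chain of three variables, `blankets` would be `[[1], [0, 2], [1]]` and `Lam` tridiagonal. Because each local kernel only sees the variables in S_d, the repulsive term is computed in a low-dimensional space regardless of D.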
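As a reference point for the PSNR figures in the denoising results, the standard definition can be computed as below. The peak value of 255 (8-bit images) is an assumption; the poster does not state which intensity range was used, and the function name `psnr` is chosen here for illustration.

```python
import numpy as np

def psnr(clean, estimate, peak=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak^2 / MSE)."""
    clean = np.asarray(clean, dtype=float)
    estimate = np.asarray(estimate, dtype=float)
    mse = np.mean((clean - estimate) ** 2)
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * np.log10(peak ** 2 / mse)
```

Under this assumption, an image whose mean squared error equals the noise variance for sigma_n = 10 has a PSNR of about 28.13 dB, which is close to the 28.16 reported for the noisy input in Figure 5.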