{"title": "Composing graphical models with neural networks for structured representations and fast inference", "book": "Advances in Neural Information Processing Systems", "page_first": 2946, "page_last": 2954, "abstract": "We propose a general modeling and inference framework that combines the complementary strengths of probabilistic graphical models and deep learning methods. Our model family composes latent graphical models with neural network observation likelihoods. For inference, we use recognition networks to produce local evidence potentials, then combine them with the model distribution using efficient message-passing algorithms. All components are trained simultaneously with a single stochastic variational inference objective. We illustrate this framework by automatically segmenting and categorizing mouse behavior from raw depth video, and demonstrate several other example models.", "full_text": "Composing graphical models with neural networks\nfor structured representations and fast inference\n\nMatthew James Johnson\n\nHarvard University\n\nmattjj@seas.harvard.edu\n\nDavid Duvenaud\nHarvard University\n\ndduvenaud@seas.harvard.edu\n\nAlexander B. Wiltschko\nHarvard University, Twitter\n\nawiltsch@fas.harvard.edu\n\nSandeep R. Datta\n\nHarvard Medical School\n\nsrdatta@hms.harvard.edu\n\nRyan P. Adams\n\nHarvard University, Twitter\nrpa@seas.harvard.edu\n\nAbstract\n\nWe propose a general modeling and inference framework that combines the com-\nplementary strengths of probabilistic graphical models and deep learning methods.\nOur model family composes latent graphical models with neural network obser-\nvation likelihoods. For inference, we use recognition networks to produce local\nevidence potentials, then combine them with the model distribution using ef\ufb01cient\nmessage-passing algorithms. All components are trained simultaneously with a\nsingle stochastic variational inference objective. 
We illustrate this framework by\nautomatically segmenting and categorizing mouse behavior from raw depth video,\nand demonstrate several other example models.\n\n1\n\nIntroduction\n\nModeling often has two goals: \ufb01rst, to learn a \ufb02exible representation of complex high-dimensional\ndata, such as images or speech recordings, and second, to \ufb01nd structure that is interpretable and\ngeneralizes to new tasks. Probabilistic graphical models [1, 2] provide many tools to build structured\nrepresentations, but often make rigid assumptions and may require signi\ufb01cant feature engineering.\nAlternatively, deep learning methods allow \ufb02exible data representations to be learned automatically,\nbut may not directly encode interpretable or tractable probabilistic structure. Here we develop a\ngeneral modeling and inference framework that combines these complementary strengths.\nConsider learning a generative model for video of a mouse. Learning interpretable representations\nfor such data, and comparing them as the animal\u2019s genes are edited or its brain chemistry altered,\ngives useful behavioral phenotyping tools for neuroscience and for high-throughput drug discovery\n[3]. Even though each image is encoded by hundreds of pixels, the data lie near a low-dimensional\nnonlinear manifold. A useful generative model must not only learn this manifold but also provide\nan interpretable representation of the mouse\u2019s behavioral dynamics. A natural representation from\nethology [3] is that the mouse\u2019s behavior is divided into brief, reused actions, such as darts, rears,\nand grooming bouts. Therefore an appropriate model might switch between discrete states, with\neach state representing the dynamics of a particular action. These two learning tasks \u2014 identifying\nan image manifold and a structured dynamics model \u2014 are complementary: we want to learn the\nimage manifold in terms of coordinates in which the structured dynamics \ufb01t well. 
A similar challenge arises in speech [4], where high-dimensional spectrographic data lie near a low-dimensional manifold because they are generated by a physical system with relatively few degrees of freedom [5] but also include the discrete latent dynamical structure of phonemes, words, and grammar [6].\n\n30th Conference on Neural Information Processing Systems (NIPS 2016), Barcelona, Spain.\n\nFigure 1: Comparison of generative models fit to spiral cluster data: (a) Data, (b) GMM, (c) Density net (VAE), (d) GMM SVAE. See Section 2.1.\n\nTo address these challenges, we propose a new framework to design and learn models that couple nonlinear likelihoods with structured latent variable representations. Our approach uses graphical models for representing structured probability distributions while enabling fast exact inference subroutines, and uses ideas from variational autoencoders [7, 8] for learning not only the nonlinear feature manifold but also bottom-up recognition networks to improve inference. Thus our method enables the combination of flexible deep learning feature models with structured Bayesian (and even nonparametric [9]) priors. Our approach yields a single variational inference objective in which all components of the model are learned simultaneously. Furthermore, we develop a scalable fitting algorithm that combines several advances in efficient inference, including stochastic variational inference [10], graphical model message passing [1], and backpropagation with the reparameterization trick [7]. Thus our algorithm can leverage conjugate exponential family structure where it exists to efficiently compute natural gradients with respect to some variational parameters, enabling effective second-order optimization [11], while using backpropagation to compute gradients with respect to all other parameters. 
We refer to our general approach as the structured variational autoencoder (SVAE).\n\n2 Latent graphical models with neural net observations\n\nIn this paper we propose a broad family of models. Here we develop three specific examples.\n\n2.1 Warped mixtures for arbitrary cluster shapes\n\nOne particularly natural structure used frequently in graphical models is the discrete mixture model. By fitting a discrete mixture model to data, we can discover natural clusters or units. These discrete structures are difficult to represent directly in neural network models.\nConsider the problem of modeling the data y = {yn}_{n=1}^N shown in Fig. 1a. A standard approach to finding the clusters in data is to fit a Gaussian mixture model (GMM) with a conjugate prior:\n\nπ ∼ Dir(α),   (μk, Σk) iid∼ NIW(λ),   zn | π iid∼ π,   yn | zn, {(μk, Σk)}_{k=1}^K iid∼ N(μ_{zn}, Σ_{zn}).\n\nHowever, the fit GMM does not represent the natural clustering of the data (Fig. 1b). Its inflexible Gaussian observation model limits its ability to parsimoniously fit the data and their natural semantics.\nInstead of using a GMM, a more flexible alternative would be a neural network density model:\n\nγ ∼ p(γ),   xn iid∼ N(0, I),   yn | xn, γ iid∼ N(μ(xn; γ), Σ(xn; γ)),   (1)\n\nwhere μ(xn; γ) and Σ(xn; γ) depend on xn through some smooth parametric function, such as a multilayer perceptron (MLP), and where p(γ) is a Gaussian prior [12]. This model fits the data density well (Fig. 1c) but does not explicitly represent discrete mixture components, which might provide insights into the data or natural units for generalization. See Fig. 2a for a graphical model.\nBy composing a latent GMM with nonlinear observations, we can combine the modeling strengths of both [13], learning discrete clusters along with non-Gaussian cluster shapes:\n\nπ ∼ Dir(α),   (μk, Σk) iid∼ NIW(λ),   γ ∼ p(γ),   zn | π iid∼ π,   xn | zn iid∼ N(μ(zn), Σ(zn)),   yn | xn, γ iid∼ N(μ(xn; γ), Σ(xn; γ)).\n\nThis combination of flexibility and structure is shown in Fig. 1d. See Fig. 2b for a graphical model.\n\n2.2 Latent linear dynamical systems for modeling video\n\nNow we consider a harder problem: generatively modeling video. Since a video is a sequence of image frames, a natural place to start is with a model for images. Kingma et al. [7] show that the density network of Eq. (1) can accurately represent a dataset of high-dimensional images {yn}_{n=1}^N in terms of the low-dimensional latent variables {xn}_{n=1}^N, each with independent Gaussian distributions.\n\nFigure 2: Generative graphical models discussed in Section 2: (a) Latent Gaussian, (b) Latent GMM, (c) Latent LDS, (d) Latent SLDS.\n\nTo extend this image model into a model for videos, we can introduce dependence through time between the latent Gaussian samples {xn}_{n=1}^N. For instance, we can make each latent variable xn+1 depend on the previous latent variable xn through a Gaussian linear dynamical system, writing\n\nxn+1 = A xn + B un,   un iid∼ N(0, I),   A, B ∈ R^{m×m},\n\nwhere the matrices A and B have a conjugate prior. This model has low-dimensional latent states and dynamics as well as a rich nonlinear generative model of images. In addition, the timescales of the dynamics are represented directly in the eigenvalue spectrum of A, providing both interpretability and a natural way to encode prior information. See Fig. 2c for a graphical model.\n\n2.3 Latent switching linear dynamical systems for parsing behavior from video\n\nAs a final example that combines both time series structure and discrete latent units, consider again the behavioral phenotyping problem described in Section 1. Drawing on graphical modeling tools, we can construct a latent switching linear dynamical system (SLDS) [14] to represent the data in terms of continuous latent states that evolve according to a discrete library of linear dynamics, and drawing on deep learning methods we can generate video frames with a neural network image model.\nAt each time n ∈ {1, 2, . . . , N} there is a discrete-valued latent state zn ∈ {1, 2, . . . , K} that evolves according to Markovian dynamics. The discrete state indexes a set of linear dynamical parameters, and the continuous-valued latent state xn ∈ R^m evolves according to the corresponding dynamics,\n\nzn+1 | zn, π ∼ π_{zn},   xn+1 = A_{zn} xn + B_{zn} un,   un iid∼ N(0, I),\n\nwhere π = {πk}_{k=1}^K denotes the Markov transition matrix and πk ∈ R^K_+ is its kth row. We use the same neural net observation model as in Section 2.2. This SLDS model combines both continuous and discrete latent variables with rich nonlinear observations. See Fig. 2d for a graphical model.\n\n3 Structured mean field inference and recognition networks\n\nWhy aren't such rich hybrid models used more frequently? The main difficulty with combining rich latent variable structure and flexible likelihoods is inference. The most efficient inference algorithms used in graphical models, like structured mean field and message passing, depend on conjugate exponential family likelihoods to preserve tractable structure. 
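As a concrete illustration of the conjugacy property just described (this example is not from the paper): a Gaussian likelihood with known noise variance is conjugate to a Gaussian prior on its mean, so the posterior stays Gaussian and is available in closed form. A minimal sketch, with all numbers illustrative:

```python
import numpy as np

# Conjugate Normal-Normal update: prior mu ~ N(mu0, sigma0_sq),
# observations y_i ~ N(mu, sigma_sq) with known sigma_sq.
def gaussian_mean_posterior(y, mu0, sigma0_sq, sigma_sq):
    """Closed-form posterior over the mean: precisions add, means combine."""
    n = len(y)
    precision = 1.0 / sigma0_sq + n / sigma_sq
    mean = (mu0 / sigma0_sq + np.sum(y) / sigma_sq) / precision
    return mean, 1.0 / precision  # posterior mean and variance

y = np.array([1.9, 2.1, 2.0])
mu_post, var_post = gaussian_mean_posterior(y, mu0=0.0, sigma0_sq=10.0, sigma_sq=0.1)
```

Because the update is a fixed algebraic formula rather than an iterative optimization, the same structure is what lets message passing stay exact in conjugate models.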
When the observations are more general, like neural network models, inference must either fall back to general algorithms that do not exploit the model structure or else rely on bespoke algorithms developed for one model at a time.\nIn this section, we review inference ideas from conjugate exponential family probabilistic graphical models and variational autoencoders, which we combine and generalize in the next section.\n\n3.1 Inference in graphical models with conjugacy structure\n\nGraphical models and exponential families provide many algorithmic tools for efficient inference [15]. Given an exponential family latent variable model, when the observation model is a conjugate exponential family, the conditional distributions stay in the same exponential families as in the prior and hence allow for the same efficient inference algorithms.\n\nFigure 3: Variational families and recognition networks for the VAE [7] and three SVAE examples: (a) VAE, (b) GMM SVAE, (c) LDS SVAE, (d) SLDS SVAE.\n\nFor example, consider learning a Gaussian linear dynamical system model with linear Gaussian observations. The generative model for latent states x = {xn}_{n=1}^N and observations y = {yn}_{n=1}^N is\n\nxn = A xn−1 + B un−1,   un iid∼ N(0, I),   yn = C xn + D vn,   vn iid∼ N(0, I),\n\ngiven parameters θ = (A, B, C, D) with a conjugate prior p(θ). To approximate the posterior p(θ, x | y), consider the mean field family q(θ)q(x) and the variational inference objective\n\nL[ q(θ)q(x) ] = E_{q(θ)q(x)} [ log ( p(θ) p(x | θ) p(y | x, θ) / ( q(θ)q(x) ) ) ],   (2)\n\nwhere we can optimize the variational family q(θ)q(x) to approximate the posterior p(θ, x | y) by maximizing Eq. (2). 
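The linear-Gaussian generative model in this example is straightforward to simulate forward, which is a useful sanity check when implementing it; a minimal NumPy sketch with illustrative dimensions and parameter values (none of them taken from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
m, p, N = 2, 3, 50                 # latent dim, observation dim, sequence length (toy choices)
A = 0.9 * np.eye(m)                # stable dynamics: eigenvalues inside the unit circle
B = 0.1 * np.eye(m)
C = rng.normal(size=(p, m))
D = 0.05 * np.eye(p)

# x_n = A x_{n-1} + B u_{n-1},  y_n = C x_n + D v_n,  with u, v ~ N(0, I)
x = np.zeros((N, m))
y = np.zeros((N, p))
for n in range(1, N):
    x[n] = A @ x[n - 1] + B @ rng.normal(size=m)
for n in range(N):
    y[n] = C @ x[n] + D @ rng.normal(size=p)
```

With the spectral radius of A below one, the latent trajectory stays near the origin, matching the paper's remark that timescales live in the eigenvalue spectrum of A.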
Because the observation model p(y | x, θ) is conjugate to the latent variable model p(x | θ), for any fixed q(θ) the optimal factor q*(x) ≜ arg max_{q(x)} L[ q(θ)q(x) ] is itself a Gaussian linear dynamical system with parameters that are simple functions of the expected statistics of q(θ) and the data y. As a result, for fixed q(θ) we can easily compute q*(x) and use message passing algorithms to perform exact inference in it. However, when the observation model is not conjugate to the latent variable model, these algorithmically exploitable structures break down.\n\n3.2 Recognition networks in variational autoencoders\n\nThe variational autoencoder (VAE) [7] handles general non-conjugate observation models by introducing recognition networks. For example, when a Gaussian latent variable model p(x) is paired with a general nonlinear observation model p(y | x, γ), the posterior p(x | y, γ) is non-Gaussian, and it is difficult to compute an optimal Gaussian approximation. The VAE instead learns to directly output a suboptimal Gaussian factor q(x | y) by fitting a parametric map from data y to a mean and covariance, μ(y; φ) and Σ(y; φ), such as an MLP with parameters φ. By optimizing over φ, the VAE effectively learns how to condition on non-conjugate observations y and produce a good approximating factor.\n\n4 Structured variational autoencoders\n\nWe can combine the tractability of conjugate graphical model inference with the flexibility of variational autoencoders. The main idea is to use a conditional random field (CRF) variational family. We learn recognition networks that output conjugate graphical model potentials instead of outputting the complete variational distribution's parameters directly. 
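The parametric map μ(y; φ), Σ(y; φ) described in Section 3.2 is just a small network from an observation to Gaussian parameters; a minimal sketch with a diagonal covariance and illustrative layer sizes (in practice φ is learned jointly with the model, which is omitted here):

```python
import numpy as np

rng = np.random.default_rng(1)
p, h, m = 4, 16, 2   # observed dim, hidden units, latent dim (illustrative choices)
phi = {"W1": rng.normal(0, 0.1, (h, p)), "b1": np.zeros(h),
       "W2": rng.normal(0, 0.1, (2 * m, h)), "b2": np.zeros(2 * m)}

def recognize(y, phi):
    """Map an observation y to a Gaussian factor q(x | y) = N(mu, diag(var))."""
    hidden = np.tanh(phi["W1"] @ y + phi["b1"])
    out = phi["W2"] @ hidden + phi["b2"]
    mu, raw = out[:m], out[m:]
    var = np.log1p(np.exp(raw))   # softplus keeps the variances positive
    return mu, var

mu, var = recognize(rng.normal(size=p), phi)
```

The softplus parameterization is one common way to guarantee valid (positive) variances for the output factor regardless of the network weights.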
These potentials are then used in graphical model inference algorithms in place of the non-conjugate observation likelihoods.\nThe SVAE algorithm computes stochastic gradients of a mean field variational inference objective. It can be viewed as a generalization both of the natural gradient SVI algorithm for conditionally conjugate models [10] and of the AEVB algorithm for variational autoencoders [7]. Intuitively, it proceeds by sampling a data minibatch, applying the recognition model to compute graphical model potentials, and using graphical model inference algorithms to compute the variational factor, combining the evidence from the potentials with the prior structure in the model. This variational factor is then used to compute gradients of the mean field objective. See Fig. 3 for graphical models of the variational families with recognition networks for the models developed in Section 2.\nIn this section, we outline the SVAE model class more formally, write the mean field variational inference objective, and show how to efficiently compute unbiased stochastic estimates of its gradients. The resulting algorithm for computing gradients of the mean field objective, shown in Algorithm 1, is simple and efficient and can be readily applied to a variety of learning problems and graphical model structures. See the supplementals for details and proofs.\n\nAlgorithm 1 Estimate SVAE lower bound and its gradients\nInput: Variational parameters (ηθ, ηγ, φ), data sample y\nfunction SVAEGRADIENTS(ηθ, ηγ, φ, y)\n    ψ ← r(y; φ)    ▷ Get evidence potentials\n    (x̂, t̄x, KLlocal) ← PGMINFERENCE(ηθ, ψ)    ▷ Combine evidence with prior\n    γ̂ ∼ q(γ)    ▷ Sample observation parameters\n    L ← N log p(y | x̂, γ̂) − N KLlocal − KL(q(θ)q(γ) ‖ p(θ)p(γ))    ▷ Estimate variational bound\n    ∇̃_{ηθ} L ← η⁰θ − ηθ + N (t̄x, 1) + N (∇_{ηx} log p(y | x̂, γ̂), 0)    ▷ Compute natural gradient\n    return lower bound L, natural gradient ∇̃_{ηθ} L, gradients ∇_{ηγ,φ} L\nfunction PGMINFERENCE(ηθ, ψ)\n    q*(x) ← OPTIMIZELOCALFACTORS(ηθ, ψ)    ▷ Fast message-passing inference\n    return sample x̂ ∼ q*(x), statistics t̄x = E_{q*(x)} tx(x), divergence E_{q(θ)} KL(q*(x) ‖ p(x | θ))\n\n4.1 SVAE model class\n\nTo set up notation for a general SVAE, we first define a conjugate pair of exponential family densities on global latent variables θ and local latent variables x = {xn}_{n=1}^N. Let p(x | θ) be an exponential family and let p(θ) be its corresponding natural exponential family conjugate prior, writing\n\np(θ) = exp{ ⟨η⁰θ, tθ(θ)⟩ − log Zθ(η⁰θ) },\np(x | θ) = exp{ ⟨η⁰x(θ), tx(x)⟩ − log Zx(η⁰x(θ)) } = exp{ ⟨tθ(θ), (tx(x), 1)⟩ },\n\nwhere we used exponential family conjugacy to write tθ(θ) = ( η⁰x(θ), − log Zx(η⁰x(θ)) ). The local latent variables x could have additional structure, like including both discrete and continuous latent variables or tractable graph structure, but here we keep the notation simple.\nNext, we define a general likelihood function. 
Let p(y | x, γ) be a general family of densities and let p(γ) be an exponential family prior on its parameters. For example, each observation yn may depend on the latent value xn through an MLP, as in the density network model of Section 2. This generic non-conjugate observation model provides modeling flexibility, yet the SVAE can still leverage conjugate exponential family structure in inference, as we show next.\n\n4.2 Stochastic variational inference algorithm\n\nThough the general observation model p(y | x, γ) means that conjugate updates and natural gradient SVI [10] cannot be directly applied, we show that by generalizing the recognition network idea we can still approximately optimize out the local variational factors leveraging conjugacy structure.\nFor fixed y, consider the mean field family q(θ)q(γ)q(x) and the variational inference objective\n\nL[ q(θ)q(γ)q(x) ] ≜ E_{q(θ)q(γ)q(x)} [ log ( p(θ) p(γ) p(x | θ) p(y | x, γ) / ( q(θ)q(γ)q(x) ) ) ].   (3)\n\nWithout loss of generality we can take the global factor q(θ) to be in the same exponential family as the prior p(θ), and we denote its natural parameters by ηθ. We restrict q(γ) to be in the same exponential family as p(γ) with natural parameters ηγ. Finally, we restrict q(x) to be in the same exponential family as p(x | θ), writing its natural parameter as ηx. Using these explicit variational parameters, we write the mean field variational inference objective in Eq. (3) as L(ηθ, ηγ, ηx).\nTo perform efficient optimization of the objective L(ηθ, ηγ, ηx), we consider choosing the variational parameter ηx as a function of the other parameters ηθ and ηγ. One natural choice is to set ηx to be a local partial optimizer of L. 
However, without conjugacy structure finding a local partial optimizer may be computationally expensive for general densities p(y | x, γ), and in the large data setting this expensive optimization would have to be performed for each stochastic gradient update. Instead, we choose ηx by optimizing over a surrogate objective L̂ with conjugacy structure, given by\n\nL̂(ηθ, ηx, φ) ≜ E_{q(θ)q(x)} [ log ( p(θ) p(x | θ) exp{ψ(x; y, φ)} / ( q(θ)q(x) ) ) ],   ψ(x; y, φ) ≜ ⟨r(y; φ), tx(x)⟩,\n\nwhere {r(y; φ)}_{φ∈R^m} is some parameterized class of functions that serves as the recognition model. Note that the potentials ψ(x; y, φ) have a form conjugate to the exponential family p(x | θ). We define η*x(ηθ, φ) to be a local partial optimizer of L̂ along with the corresponding factor q*(x),\n\nη*x(ηθ, φ) ≜ arg min_{ηx} L̂(ηθ, ηx, φ),   q*(x) = exp{ ⟨η*x(ηθ, φ), tx(x)⟩ − log Zx(η*x(ηθ, φ)) }.\n\nAs with the variational autoencoder of Section 3.2, the resulting variational factor q*(x) is suboptimal for the variational objective L. However, because the surrogate objective has the same form as a variational inference objective for a conjugate observation model, the factor q*(x) not only is easy to compute but also inherits exponential family and graphical model structure for tractable inference.\nGiven this choice of η*x(ηθ, φ), the SVAE objective is LSVAE(ηθ, ηγ, φ) ≜ L(ηθ, ηγ, η*x(ηθ, φ)). This objective is a lower bound for the variational inference objective Eq. (3) in the following sense.\n\nProposition 4.1 (The SVAE objective lower-bounds the mean field objective)\nThe SVAE objective function LSVAE lower-bounds the mean field objective L in the sense that\n\nmax_{q(x)} L[ q(θ)q(γ)q(x) ] ≥ max_{ηx} L(ηθ, ηγ, ηx) ≥ LSVAE(ηθ, ηγ, φ)   ∀ φ ∈ R^m,\n\nfor any parameterized function class {r(y; φ)}_{φ∈R^m}. Furthermore, if there is some φ* ∈ R^m such that ψ(x; y, φ*) = E_{q(γ)} log p(y | x, γ), then the bound can be made tight in the sense that\n\nmax_{q(x)} L[ q(θ)q(γ)q(x) ] = max_{ηx} L(ηθ, ηγ, ηx) = max_{φ} LSVAE(ηθ, ηγ, φ).\n\nThus by using gradient-based optimization to maximize LSVAE(ηθ, ηγ, φ) we are maximizing a lower bound on the model log evidence log p(y). In particular, by optimizing over φ we are effectively learning how to condition on observations so as to best approximate the posterior while maintaining conjugacy structure. Furthermore, to provide the best lower bound we may choose the recognition model function class {r(y; φ)}_{φ∈R^m} to be as rich as possible.\nChoosing η*x(ηθ, φ) to be a local partial optimizer of L̂ provides two computational advantages. First, it allows η*x(ηθ, φ) and expectations with respect to q*(x) to be computed efficiently by exploiting exponential family graphical model structure. 
Second, it provides a simple expression for an unbiased estimate of the natural gradient with respect to the latent model parameters, as we summarize next.\n\nProposition 4.2 (Natural gradient of the SVAE objective)\nThe natural gradient of the SVAE objective LSVAE with respect to ηθ is\n\n∇̃_{ηθ} LSVAE(ηθ, ηγ, φ) = ( η⁰θ + E_{q*(x)} [ (tx(x), 1) ] − ηθ ) + ( ∇_{ηx} L(ηθ, ηγ, η*x(ηθ, φ)), 0 ).   (4)\n\nNote that the first term in Eq. (4) is the same as the expression for the natural gradient in SVI for conjugate models [10], while a stochastic estimate of the second term is computed automatically as part of the backward pass for computing the gradients with respect to the other parameters, as described next. Thus we have an expression for the natural gradient with respect to the latent model's parameters that is almost as simple as the one for conjugate models and just as easy to compute. Natural gradients are invariant to smooth invertible reparameterizations of the variational family [16, 17] and provide effective second-order optimization updates [18, 11].\nThe gradients of the objective with respect to the other variational parameters, namely ∇_{ηγ} LSVAE(ηθ, ηγ, φ) and ∇_φ LSVAE(ηθ, ηγ, φ), can be computed using the reparameterization trick. To isolate the terms that require the reparameterization trick, we rearrange the objective as\n\nLSVAE(ηθ, ηγ, φ) = E_{q(γ)q*(x)} log p(y | x, γ) − KL(q(θ)q*(x) ‖ p(θ, x)) − KL(q(γ) ‖ p(γ)).\n\nThe KL divergence terms are between members of the same tractable exponential families. 
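For example, for two diagonal Gaussians (members of the same exponential family) the KL divergence has a familiar closed form; a small sketch, not taken from the paper's code:

```python
import numpy as np

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, diag(var_q)) || N(mu_p, diag(var_p)) ) in closed form."""
    return 0.5 * np.sum(
        np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )

kl = kl_diag_gaussians(np.array([0.0, 1.0]), np.array([1.0, 0.5]),
                       np.array([0.0, 0.0]), np.array([1.0, 1.0]))
```

Having such closed forms for every KL term is exactly what confines the Monte Carlo estimation to the expected log-likelihood term, as described next.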
An unbiased estimate of the first term can be computed by sampling x̂ ∼ q*(x) and γ̂ ∼ q(γ) and computing ∇_{ηγ,φ} log p(y | x̂, γ̂) with automatic differentiation. Note that the second term in Eq. (4) is automatically computed as part of the chain rule in computing ∇_φ log p(y | x̂, γ̂).\n\n5 Related work\n\nIn addition to the papers already referenced, there are several recent papers to which this work is related. The two papers closest to this work are Krishnan et al. [19] and Archer et al. [20].\n\nFigure 4: Predictions from an LDS SVAE fit to 1D dot image data at two stages of training: (a) predictions after 200 training steps, (b) predictions after 1100 training steps. The top panel shows an example sequence with time on the horizontal axis. The middle panel shows the noiseless predictions given data up to the vertical line, while the bottom panel shows the latent states.\n\nFigure 5: Experimental results from LDS SVAE models on synthetic data and real mouse data: (a) natural (blue) and standard (orange) gradient updates, (b) subspace of learned observation model.\n\nIn Krishnan et al. [19] the authors consider combining variational autoencoders with continuous state-space models, emphasizing the relationship to linear dynamical systems (also called Kalman filter models). They primarily focus on nonlinear dynamics and an RNN-based variational family, as well as allowing control inputs. However, the approach does not extend to general graphical models or discrete latent variables. It also does not leverage natural gradients or exact inference subroutines.\nIn Archer et al. [20] the authors also consider the problem of variational inference in general continuous state space models but focus on using a structured Gaussian variational family without considering parameter learning. As with Krishnan et al. 
[19], this approach does not include discrete latent variables (or any latent variables other than the continuous states). However, the method they develop could be used with an SVAE to handle inference with nonlinear dynamics.\nIn addition, both Gregor et al. [21] and Chung et al. [22] extend the variational autoencoder framework to sequential models, though they focus on RNNs rather than probabilistic graphical models.\n\n6 Experiments\n\nWe apply the SVAE to both synthetic and real data and demonstrate its ability to learn feature representations and latent structure. Code is available at github.com/mattjj/svae.\n\n6.1 LDS SVAE for modeling synthetic data\n\nConsider a sequence of 1D images representing a dot bouncing from one side of the image to the other, as shown at the top of Fig. 4. We use an LDS SVAE to find a low-dimensional latent state space representation along with a nonlinear image model. The model is able to represent the image accurately and to make long-term predictions with uncertainty. See supplementals for details.\nThis experiment also demonstrates the optimization advantages that can be provided by the natural gradient updates. In Fig. 5a we compare natural gradient updates with standard gradient updates at three different learning rates. The natural gradient algorithm not only learns much faster but also is less dependent on parameterization details: while the natural gradient update used an untuned stepsize of 0.1, the standard gradient dynamics at step sizes of both 0.1 and 0.05 resulted in some matrix parameters being updated to indefinite values.\n\nFigure 6: Predictions from an LDS SVAE fit to depth video. In each panel, the top is a sampled prediction and the bottom is real data. The model is conditioned on observations to the left of the line.\n\nFigure 7: Examples of behavior states inferred from depth video: (a) extension into running, (b) fall from rear. Each frame sequence is padded on both sides, with a square in the lower-right of a frame depicting when the state is the most probable.\n\n6.2 LDS SVAE for modeling video\n\nWe also apply an LDS SVAE to model depth video recordings of mouse behavior. We use the dataset from Wiltschko et al. [3] in which a mouse is recorded from above using a Microsoft Kinect. We used a subset consisting of 8 recordings, each of a distinct mouse, 20 minutes long at 30 frames per second, for a total of 288000 video frames downsampled to 30 × 30 pixels.\nWe use MLP observation and recognition models with two hidden layers of 200 units each and a 10D latent space. Fig. 5b shows images corresponding to a regular grid on a random 2D subspace of the latent space, illustrating that the learned image manifold accurately captures smooth variation in the mouse's body pose. Fig. 6 shows predictions from the model paired with real data.\n\n6.3 SLDS SVAE for parsing behavior\n\nFinally, because the LDS SVAE can accurately represent the depth video over short timescales, we apply the latent switching linear dynamical system (SLDS) model to discover the natural units of behavior. Fig. 7 shows some of the discrete states that arise from fitting an SLDS SVAE with 30 discrete states to the depth video data. The discrete states that emerge show a natural clustering of short-timescale patterns into behavioral units. See the supplementals for more.\n\n7 Conclusion\n\nStructured variational autoencoders provide a general framework that combines some of the strengths of probabilistic graphical models and deep learning methods. In particular, they use graphical models both to give models rich latent representations and to enable fast variational inference with CRF structured approximating distributions. 
To complement these structured representations, SVAEs use neural networks to produce not only flexible nonlinear observation models but also fast recognition networks that map observations to conjugate graphical model potentials.

References

[1] Daphne Koller and Nir Friedman. Probabilistic Graphical Models: Principles and Techniques. MIT Press, 2009.
[2] Kevin P. Murphy. Machine Learning: A Probabilistic Perspective. MIT Press, 2012.
[3] Alexander B. Wiltschko, Matthew J. Johnson, Giuliano Iurilli, Ralph E. Peterson, Jesse M. Katon, Stan L. Pashkovski, Victoria E. Abraira, Ryan P. Adams, and Sandeep Robert Datta. "Mapping Sub-Second Structure in Mouse Behavior". In: Neuron 88.6 (2015), pp. 1121–1135.
[4] Geoffrey Hinton, Li Deng, Dong Yu, George E. Dahl, Abdel-rahman Mohamed, Navdeep Jaitly, Andrew Senior, Vincent Vanhoucke, Patrick Nguyen, Tara N. Sainath, et al. "Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups". In: IEEE Signal Processing Magazine 29.6 (2012), pp. 82–97.
[5] Li Deng. "Computational models for speech production". In: Computational Models of Speech Pattern Processing. Springer, 1999, pp. 199–213.
[6] Li Deng. "Switching dynamic system models for speech articulation and acoustics". In: Mathematical Foundations of Speech and Language Processing. Springer, 2004, pp. 115–133.
[7] Diederik P. Kingma and Max Welling. "Auto-Encoding Variational Bayes". In: International Conference on Learning Representations (2014).
[8] Danilo J. Rezende, Shakir Mohamed, and Daan Wierstra. "Stochastic Backpropagation and Approximate Inference in Deep Generative Models". In: Proceedings of the 31st International Conference on Machine Learning. 2014, pp. 1278–1286.
[9] Matthew J. Johnson and Alan S. Willsky.
\u201cStochastic Variational Inference for Bayesian Time\n\nSeries Models\u201d. In: International Conference on Machine Learning. 2014.\n\n[10] Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. \u201cStochastic variational\n\ninference\u201d. In: Journal of Machine Learning Research (2013).\nJames Martens. \u201cNew insights and perspectives on the natural gradient method\u201d. In: arXiv\npreprint arXiv:1412.1193 (2015).\n\n[11]\n\n[12] David J.C. MacKay and Mark N. Gibbs. \u201cDensity networks\u201d. In: Statistics and neural networks:\n\nadvances at the interface. Oxford University Press, Oxford (1999), pp. 129\u2013144.\n\n[13] Tomoharu Iwata, David Duvenaud, and Zoubin Ghahramani. \u201cWarped Mixtures for Nonpara-\nmetric Cluster Shapes\u201d. In: 29th Conference on Uncertainty in Arti\ufb01cial Intelligence. 2013,\npp. 311\u2013319.\n\n[14] E.B. Fox, E.B. Sudderth, M.I. Jordan, and A.S. Willsky. \u201cBayesian Nonparametric Inference\nof Switching Dynamic Linear Models\u201d. In: IEEE Transactions on Signal Processing 59.4\n(2011).\n\n[15] Martin J. Wainwright and Michael I. Jordan. \u201cGraphical Models, Exponential Families, and\n\nVariational Inference\u201d. In: Foundations and Trends in Machine Learning (2008).\n\n[16] Shun-Ichi Amari. \u201cNatural gradient works ef\ufb01ciently in learning\u201d. In: Neural computation\n\n10.2 (1998), pp. 251\u2013276.\n\n[17] Shun-ichi Amari and Hiroshi Nagaoka. Methods of Information Geometry. American Mathe-\n\nmatical Society, 2007.\nJames Martens and Roger Grosse. \u201cOptimizing Neural Networks with Kronecker-factored\nApproximate Curvature\u201d. In: arXiv preprint arXiv:1503.05671 (2015).\n\n[18]\n\n[19] Rahul G Krishnan, Uri Shalit, and David Sontag. \u201cDeep Kalman Filters\u201d. In: arXiv preprint\n\narXiv:1511.05121 (2015).\n\n[20] Evan Archer, Il Memming Park, Lars Buesing, John Cunningham, and Liam Paninski. \u201cBlack\nbox variational inference for state space models\u201d. 
In: arXiv preprint arXiv:1511.07367 (2015).
[21] Karol Gregor, Ivo Danihelka, Alex Graves, and Daan Wierstra. "DRAW: A recurrent neural network for image generation". In: arXiv preprint arXiv:1502.04623 (2015).
[22] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C. Courville, and Yoshua Bengio. "A recurrent latent variable model for sequential data". In: Advances in Neural Information Processing Systems. 2015, pp. 2962–2970.