
Stein’s identity and Stein Variational Gradient Descent


If you have a target distribution p(x) that you want to model, but you can’t compute it directly (maybe because it involves an intractable partition function), then starting with an initial set of particles \left\{x_i^0\right\}_{i=1}^n, you can iteratively update them as follows:

\displaystyle x_i^{l+1} \leftarrow x_i^l + \epsilon_l\hat{\phi}^*\left(x_i^l\right)


  • \epsilon_l is a small step size at iteration l
  • \displaystyle \hat{\phi}^*\left(x\right) = \frac{1}{n}\sum_{j=1}^n \left[k\left(x_j^l, x\right)\nabla_{x_j^l}\log p\left(x_j^l\right) + \nabla_{x_j^l}k\left(x_j^l, x\right)\right] is the steepest update direction
  • k\left( \cdot , \cdot \right) is a kernel, typically an RBF kernel.

At the end, this update rule gives a set of particles \left\{x_i\right\}_{i=1}^n that approximates the target distribution.

Note that to compute the update direction, you only need to compute the derivative of the log of the (unnormalized) target distribution with respect to a set of samples \left\{x_j^l\right\}_{j=1}^n.

This result is significant because the update is effectively gradient descent on the KL divergence between the particle distribution and the target. The initial set of particles \left\{x_i^0\right\}_{i=1}^n can come from anywhere, for instance from another black-box model.

An interesting intuition is that the first term pushes the particles toward regions where p(x) is high, while the second term (the derivatives of the kernel), in the RBF case, has a “regularizing” effect: it keeps the particles away from each other and prevents them from collapsing into the same mode. This repulsion is arguably what gives the method an edge over plain MCMC.
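As an illustration, here is a minimal NumPy sketch of the update rule above on a toy 1-D problem with a standard normal target (so \nabla_x\log p(x) = -x). The function names, the fixed bandwidth h and the step size are my own choices for the example, not prescribed by the papers:

```python
import numpy as np

def rbf_kernel(X, h=1.0):
    # Pairwise RBF kernel K[j, i] = exp(-||x_j - x_i||^2 / (2 h^2))
    diff = X[:, None, :] - X[None, :, :]
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2 * h ** 2))
    # Gradient w.r.t. the first argument: dK[j, i, :] = grad_{x_j} k(x_j, x_i)
    dK = -diff / h ** 2 * K[:, :, None]
    return K, dK

def svgd_step(X, grad_log_p, eps=0.1, h=1.0):
    # One SVGD update: x_i <- x_i + eps * phi_hat(x_i)
    n = X.shape[0]
    K, dK = rbf_kernel(X, h)
    grads = grad_log_p(X)                   # (n, d): grad log p at each particle
    phi = (K @ grads + dK.sum(axis=0)) / n  # driving term + repulsive term
    return X + eps * phi

# Toy target p = N(0, 1), so grad log p(x) = -x
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=0.5, size=(100, 1))  # particles start far away
for _ in range(500):
    X = svgd_step(X, lambda x: -x, eps=0.1)
print(X.mean(), X.std())  # should approach 0 and 1
```

Starting far from the target, the particles drift toward it, and the kernel-gradient term keeps them spread out instead of collapsing onto the mode.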

The short story

Let x be a continuous random variable in \mathcal{X} \subset \mathbb{R}^d and let p\left(x\right) be the (intractable) target distribution. For a smooth vector-valued function \phi\left(x\right) = \left[\phi_1\left(x\right), ..., \phi_d\left(x\right)\right]^T, the so-called Stein’s identity says:

\displaystyle \mathbb{E}_{x \sim p}\left[\mathcal{A}_p\phi\left(x\right)\right] = 0

where \mathcal{A}_p\phi\left(x\right) = \phi\left(x\right)\nabla_x\log p\left(x\right)^T + \nabla_x\phi\left(x\right) and \mathcal{A}_p is called the Stein operator.

Let q\left(x\right) be another distribution. Now \mathbb{E}_{x \sim q}\left[\mathcal{A}_p\phi\left(x\right)\right] will no longer be zero. It turns out this can be used to define the discrepancy between two distributions p and q:

\displaystyle \mathbb{S}\left(q, p\right) = \max_\phi \left[\mathbb{E}_{x \sim q}\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]^2

Meaning we consider all possible smooth functions \phi and take the one that maximizes the violation of Stein’s identity. This maximum violation is defined to be the Stein discrepancy between q and p.

Considering all possible \phi is impractical. But it turns out that if we restrict \phi to the unit ball of a reproducing kernel Hilbert space (RKHS) \mathcal{H}^d, we obtain the kernelized Stein discrepancy:

\displaystyle \mathbb{KS}\left(q, p\right) = \max_{\phi\in\mathcal{H}^d} \left[\mathbb{E}_{x \sim q}\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]^2\quad \text{s.t.} \parallel \phi\parallel_{\mathcal{H}^d} \leq 1

and this has a closed-form solution: \displaystyle \phi_{q,p}^*\left(\cdot\right) = \mathbb{E}_{x \sim q}\left[\mathcal{A}_pk\left(x,\cdot\right)\right].
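The closed form also leads to a sample-based estimate of the discrepancy: averaging a kernel-dependent function u_p(x, x') over pairs of samples from q. Below is a 1-D sketch with an RBF kernel; the expansion of u_p is hand-derived for this case, and all names are my own:

```python
import numpy as np

def ksd_rbf(X, score, h=1.0):
    # V-statistic estimate of the (squared) kernelized Stein discrepancy
    # between samples X ~ q and a target p given only through its score
    # function score(x) = d/dx log p(x).  1-D case, RBF kernel.
    x, y = X[:, None], X[None, :]
    d = x - y
    K = np.exp(-d ** 2 / (2 * h ** 2))
    sx, sy = score(X)[:, None], score(X)[None, :]
    u = (sx * sy * K
         + sx * (d / h ** 2) * K                 # s(x) * d/dy k(x, y)
         - sy * (d / h ** 2) * K                 # s(y) * d/dx k(x, y)
         + (1 / h ** 2 - d ** 2 / h ** 4) * K)   # d^2 k / dx dy
    return u.mean()

rng = np.random.default_rng(0)
score = lambda x: -x                     # target p = N(0, 1)
near = rng.normal(0.0, 1.0, 300)         # samples from p itself
far = rng.normal(3.0, 1.0, 300)          # samples from a shifted distribution
print(ksd_rbf(near, score), ksd_rbf(far, score))  # the second is much larger
```

Note that only the score of p is needed, so the partition function never appears, which is the whole point.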

Now if we take T\left(x\right) = x + \epsilon\phi\left(x\right) and let q_{[T]} be the distribution of T\left(x\right) when x \sim q(x), then [1] shows that the derivative of the KL divergence between q_{[T]} and p has an interesting form:

\nabla_\epsilon KL\left(q_{[T]} \parallel p\right)\vert_{\epsilon = 0} = -\mathbb{E}_{x \sim q}\left[\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]

(without the square)

Relating this to all the above, the steepest direction for minimizing the KL divergence is exactly the expectation of the Stein operator applied to the kernel:

\displaystyle \phi_{q,p}^*\left(\cdot\right) = \mathbb{E}_{x \sim q}\left[k(x,\cdot)\nabla_x\log p(x) + \nabla_x k(x,\cdot)\right]

The long story

The above is a huge simplification, so don’t take it too seriously. If you are dying for more details, the following might help:

[1] introduces the kernelized Stein discrepancy
[2] presents the SVGD algorithm
[3] uses the algorithm to estimate parameters of an energy-based deep neural net.



Random notes on stream-first architecture, Flink and friends

Use-cases and Technical requirements

The main use-cases for stream processing systems are applications that need to process a large amount of data arriving at high throughput (i.e. high velocity) with very low latency. Low latency is the key requirement: for use-cases where computations don’t need to be real-time, e.g. computing the monthly cost on AWS, preparing quarterly business reports, re-indexing websites for a search engine, etc., a traditional batch processing system like MapReduce or Spark is a better fit. Time-critical applications like processing events in online games or financial applications, monitoring and predicting device failures, processing critical performance metrics, etc., however, require real-time or near real-time insights. In those cases, a stream processing system is more suitable.

That being said, when we say stream processing systems, we mean tools that allow users to perform computations on-the-fly. Related technologies that only provide message queues (without computations), especially popular amid the microservice madness, sometimes overlap with stream processing technologies, but they serve a different purpose and should be considered separately.

The pioneer in stream processing systems was Apache Storm. Storm can provide low latency, but word on the street is that it is hard to get high throughput out of it. Moreover, Storm doesn’t guarantee exactly-once semantics (more on this later).

A clever compromise to get exactly-once semantics at relatively low latency is to cut the data stream into small batches (streaming as a special case of batch). Since computations are done on batches, they can be redone when failures happen. With small enough batches, this approximates streaming pretty well and is suitable for applications that aren’t too picky about latency. However, cutting the stream into batches has inherent problems with the semantics of time, and can lead to incorrect results when events do not arrive in order (which happens quite often in distributed systems). Nonetheless, this is the approach taken by Spark Streaming. Storm Trident, an improvement over Storm, also seems to take this micro-batch approach.

Message transport systems

In order to feed data into stream processing engines, we need a message transport system, which does what it says: transport messages. Although it might sound simple, doing this at scale is quite challenging. An ideal message transport system for a streaming application must:

  • provide persistence, i.e. the ability to rewind and replay the stream (i.e. time travel). It was once thought that a message transport system can’t have both performance and persistence. Lately, it turns out that this tradeoff is not fundamental, and systems like Apache Kafka can provide high throughput delivery with persistence. This is one of the important revolutions for the streaming community.
  • decouple producers and consumers. A good transport system doesn’t send messages point-to-point, because doing so in a scalable and fault-tolerant way is difficult. One solution, which Kafka employs, is to let users create topics, to which producers send messages and from which consumers subscribe.

Having a good message transport system is critical for the stream processing engine down the road.

The notion of Time

In streaming applications, there are 2 notions of time: event time, which is the time the message was generated in the real world, often stored as a timestamp along with the message, and processing time, which is the time the message is observed by the machine processing it, normally measured by the system clock of that machine.

Mixing up these notions of time is dangerous. To perform computations, stream processing engines divide the stream into (sliding) windows, based on the time of each message. It is therefore important to specify which notion of time is being used, otherwise the results can be incorrect, as demonstrated in an interesting experiment here.

Micro-batching systems, like Spark Streaming, only accept time windows, and it is therefore tricky to get things right.

The bible in handling times in streaming applications is this VLDB paper from Google Dataflow (now Apache Beam), which was one of the inspirations of Flink. The following section was written based on this paper.

Windows, Triggers, and Watermarks

There are different kinds of windows supported by modern stream processing engines:

  • Time windows
  • Count windows
  • Session windows (or timeout windows): based on the gaps between messages, suitable for streams of user interactions on a website, for instance.

All kinds of windows are special cases of triggers – which define when the content of a window is aggregated and returned to the output stream. Good stream processing engines should allow users to specify custom trigger logic.

Another aspect of time is time travel – the ability to rewind the clock and repeat the computation from some point in the past. This is important in real-world applications where users need to replace their application with a new version, do A/B testing between different implementations of an algorithm, etc. Obviously, if we only use processing time, time travel is impossible, because the new version of the application would simply process the latest messages coming in right now. Therefore, serious stream processing applications must support event time (coupled with a message transport that provides persistence).

Event time is tricky to work with, though. Since events can arrive out-of-order, when the system sees an event that happened at time 10:01 and then one at 10:10, how does it know that there won’t be another event that happened at 10:05 still on its way? In other words, there needs to be a clock defined by the data itself, so that event-time windows can be closed correctly.

Stream processing engines deal with this problem with the concept of watermarks: special messages embedded in the stream, bearing a time marker t, notifying that there won’t be any more events that happened before t. With watermarks, event time progresses independently from processing time. If a watermark is late (in processing time), the results will still be computed correctly, albeit late in the result stream.

Stream processing engines let developers define watermarks, because doing so requires some understanding of the domain, e.g. if you know your messages can’t be more than 5 seconds late, then the watermark can be the current time minus 5 seconds. This can get as elaborate as having a Machine Learning model predict how late messages arrive in the stream.
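To see how watermarks drive event-time windows, here is a toy Python sketch (not the Flink API; all names are hypothetical): events are buffered into tumbling windows by their event time, and a window only fires once a watermark passes its end, so out-of-order events still land in the right window.

```python
from collections import defaultdict

class TumblingEventTimeWindows:
    """Toy event-time tumbling windows driven by watermarks."""
    def __init__(self, size):
        self.size = size
        self.buffers = defaultdict(list)  # window start -> buffered events

    def on_event(self, event_time, value):
        # Assign the event to its window purely by event time.
        start = (event_time // self.size) * self.size
        self.buffers[start].append(value)

    def on_watermark(self, t):
        # Fire every window whose end is at or before the watermark t.
        fired = []
        for start in sorted(self.buffers):
            if start + self.size <= t:
                fired.append((start, self.buffers.pop(start)))
        return fired

w = TumblingEventTimeWindows(size=10)
w.on_event(12, "a")
w.on_event(25, "b")
w.on_event(14, "c")          # out-of-order, but still lands in window [10, 20)
print(w.on_watermark(20))    # -> [(10, ['a', 'c'])]; window [20, 30) stays open
```

The watermark, not the wall clock, decides when a window is complete; a late watermark merely delays the output, it doesn’t corrupt it.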

Stateful computation and exactly-once semantic

Most interesting stream-based applications are stateful. If the engine doesn’t support stateful computation, application developers need to take care of state themselves (by recruiting a database, for instance). More useful engines, however, can manage the state and recompute it in case of failures. Doing this properly ensures the engine provides exactly-once semantics (the output stream is computed correctly even in case of failures), but it is normally difficult to do, both at the semantic and implementation levels.

Folks behind Flink came up with a solution described in this paper, called Asynchronous Distributed Snapshots. This algorithm allows states to be restored correctly in case of failure, even when there are multiple workers.

The core idea is simple to understand. Checkpoint markers are injected periodically into the incoming stream. When a worker processes a checkpoint marker, it dumps its state into persistent storage. When a failure happens on one of the workers, it notifies all the others, everyone resets their state to the same previous checkpoint, and the computation starts again (keep in mind that they can do time travel thanks to event time and watermarks). For this to work, the checkpoint IDs need to be synchronized across workers. Some tricks also need to be implemented at the output stage to make sure the output stream doesn’t contain duplicated entries. In the end, this checkpointing mechanism ensures exactly-once semantics with relatively little effect on overall latency.
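The flavour of the idea fits in a tiny single-worker sketch (purely illustrative, far from the actual distributed algorithm): state is snapshotted whenever a checkpoint marker is seen, and rolled back to the snapshot on failure, after which the transport would replay the messages that followed the checkpoint.

```python
class CountingWorker:
    """Toy stateful operator: counts messages, snapshots on checkpoint markers."""
    def __init__(self):
        self.count = 0
        self.snapshots = {}

    def process(self, msg):
        if isinstance(msg, tuple) and msg[0] == "checkpoint":
            # Barrier reached: dump the current state under the checkpoint ID.
            self.snapshots[msg[1]] = self.count
        else:
            self.count += 1

    def restore(self, checkpoint_id):
        # On failure, roll the state back to the chosen checkpoint.
        self.count = self.snapshots[checkpoint_id]

w = CountingWorker()
for msg in ["a", "b", ("checkpoint", 1), "c", "d"]:
    w.process(msg)
print(w.count)   # -> 4: all messages processed
w.restore(1)     # failure: roll back to checkpoint 1
print(w.count)   # -> 2: the transport would now replay "c" and "d"
```

The real algorithm has to align barriers across parallel input channels and deduplicate outputs, which is where the hard engineering lives.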

The same checkpointing mechanism allows Flink to create savepoints, at which the state of the whole system is saved and can be loaded later. This is extremely useful in practice since it allows a new version of the streaming application to pick up (reasonably) arbitrary savepoints of previous versions.

Now that we have stateful computation with exactly-once semantics, we can even allow users to query the state directly, which makes stream processors behave more or less like a traditional database. This is still ongoing research.

Stream-first architecture, Lambda architecture, Kappa architecture

Fancy words for fancy tech. They are all about employing a message transport, like Kafka, as a courier of messages, plus some form of stream processor. Sometimes the message transport also dumps data periodically into a persistence layer (like HDFS), and batch jobs are scheduled to pick up the latest data and provide exact insights. Again, though, this approach is only suitable for applications that don’t require real-time insights.

An interesting use-case of streaming is predicting failures from a stream of sensor data. A feasible, scalable approach would be a Flink application that monitors the sensor data stream and outputs a stream of statistics, stored in a Kafka topic. This stream is then picked up by another Flink application that does the Machine Learning job and raises an alert if an anomaly is detected.


Stream processors might have difficulties with batch computation (although batches can obviously be seen as special cases of streams – streams that have a definite start and end), e.g. running SQL queries on Flink is still an open effort. Doing SQL properly on streams has always been tricky.

Other computations traditionally suited to batch processors, like Machine Learning, are also tricky to do on systems designed for streaming.

Finally, like any other emerging technologies, streaming might be hyped. Picking the right solution should be based on the nature of the problems, not on the trends.

Simpson’s paradox

I learned about Simpson’s paradox fairly recently, and I found it quite disturbing, not because of the mere “paradox” itself, but mainly because I felt it was something I should have known already.

In case you haven’t heard about it, one instance of the paradox is a real-world medical study for comparing the success rate of two treatments for kidney stones (from Wikipedia):


Overall, Treatment B is better because its success rate is 83%, compared to 78% of Treatment A. However, when they split the patients into 2 groups: those with small stones and those with large stones, then Treatment A is better than Treatment B in both subgroups. Paradoxical enough?

Well, it’s not. It turns out that for severe cases (large stones), doctors tend to give the more effective Treatment A, while for milder cases with small stones, they tend to give the inferior Treatment B. Therefore the overall sums are dominated by groups 2 and 3, while the other groups contribute little. So the results can be interpreted more accurately as: because Treatment B is more frequently applied to less severe cases, it appears to be more effective overall.

Now, knowing that Treatment and Stone size are not independent, this should not come across as a paradox. In fact, we can visualize the problem as a graphical model like this:

[Figure: a graphical model in which Stone size influences both Treatment and Success]

All the numbers in the table above can be expressed as conditional probabilities like so:

  • Group 1: p\left(S=true \vert T=A, St=small\right) = 0.93
  • Group 2: p\left(S=true \vert T=B, St=small\right) = 0.87
  • Group 3: p\left(S=true \vert T=A, St=large\right) = 0.73
  • Group 4: p\left(S=true \vert T=B, St=large\right) = 0.69
  • p\left(S=true \vert T=A\right) = 0.78
  • p\left(S=true \vert T=B\right) = 0.83

For any of us who studied Probability, it is no surprise that the probabilities might turn upside-down when some conditioning variables are marginalized out. In this particular case, since S depends on both St and T, the last 2 equations do not bring any new knowledge about S.
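These numbers are easy to verify. Using the per-group counts from the Wikipedia kidney-stone example (successes, patients), a few lines of Python reproduce both the subgroup rates and the reversed overall rates:

```python
# Counts from the Wikipedia kidney-stone example: (successes, patients).
data = {
    ("A", "small"): (81, 87),
    ("B", "small"): (234, 270),
    ("A", "large"): (192, 263),
    ("B", "large"): (55, 80),
}

# Subgroup success rates: A beats B for both stone sizes.
for (t, size), (s, n) in data.items():
    print(f"p(S=true | T={t}, St={size}) = {s / n:.2f}")

# Overall success rates: the ordering flips, because B is mostly
# applied to the easy (small-stone) cases.
for t in ("A", "B"):
    s = sum(v[0] for k, v in data.items() if k[0] == t)
    n = sum(v[1] for k, v in data.items() if k[0] == t)
    print(f"p(S=true | T={t}) = {s / n:.2f}")  # A: 0.78, B: 0.83
```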

So what is this “paradox” about? Isn’t it nothing more than the problem of confounding/lurking variables, something that most people in Probability/Statistics already know? In this particular case, Stone size is the lurking variable that influences both Treatment and Success, so the scientists who designed the experiment should have taken it into account from the beginning. It is well-known among Statistics practitioners that they must try their best to identify and eliminate the effect of any lurking variables in their experiments, or at least keep them fixed, before drawing any meaningful conclusion.

From a slightly different perspective, the paradox can be understood once we understand the human bias of drawing causal relations. Humans, perhaps for the sake of survival, constantly look for causal relations and often tend to ignore rates or proportions. Once we conceive something as causal (Treatment B gives a higher success rate than Treatment A in general), which might be wrong, we continue to assume the causal relation and proceed with that assumption in mind. With this assumption, we will naturally find the success rates for the subgroups of patients highly counter-intuitive, or even paradoxical.

In fact, the connection of this paradox to human intuition is so important that Judea Pearl dedicated a whole section of his book to it. Modern statistics textbooks and curricula, however, rarely mention it by name; instead they generally present the topic along with lurking/confounding variables.

Therefore, if you haven’t heard about this, it is probably for a good reason, or perhaps you are simply too young.

Albert Einstein and random thoughts on Machine Learning


I read Einstein’s biography with as much enthusiasm as I did Stephen Hawking’s A Brief History of Time and Domingos’ The Master Algorithm. It’s not only because the book is recommended by, among others, Elon Musk, but probably more because of my childhood dream of becoming a physicist back in high school. Although I was too dumb for physics, nothing could prevent me from admiring its beauty.

The book was excellent. Walter Isaacson did a great job in depicting Albert Einstein from his complicated personality to his belief, religion, politics, and, of course, his scientific achievements.

As a human being, Einstein was a typical introvert. He was always a loner, enjoying discussing ideas more than any personal or emotional entanglements. During the most difficult periods of his life, he would rather immerse himself in science than get up and do anything about his real-life struggles. To quote Isaacson, “the stubborn patience that Einstein displayed when dealing with scientific problems was equaled by his impatience when dealing with personal entanglements”, those that put “emotional burdens” on him. Some may criticise him and regard him as “cold-hearted”, but perhaps for him, it was simply easier to use the analytical brain rather than the emotional brain to deal with daily mundane affairs. This, oftentimes, resulted in what we may consider brutal acts, like when he gave Mileva Maric a list of harsh terms in order to stay with him, or when he did almost nothing for his first child, who died in Serbia. For this aspect, though, he probably deserves more sympathy than condemnation. He was surely a complicated man, and expecting him to be well-rounded in handling personal affairs is perhaps both unreasonable and unrealistic.

Now, it is perhaps worth emphasizing that Einstein is a one-in-a-million genius who happens to have those personality traits. It does not imply those who have those traits are genius. Correlation does not imply causation 😉

Einstein made his mark on physics back in 1905 when he challenged Newton’s classical physics. He was bold and stubborn in challenging long-standing beliefs in science that were not backed by any experimental results. Unfortunately, during his 40s, quantum physics used the same rationale against him, and he just couldn’t take it, even though he had contributed a great deal to its development (to this he amusingly commented that a good joke should not be repeated too often). His quest for a unified field theory that could explain both the gravitational field and the electromagnetic field with a concise set of rules failed, and eventually quantum mechanics, with its probabilistic approach, was widely accepted. This saga tells us a lot:

  • The argument Einstein had with the rest of the physicists back in the 1910s over his breakthrough in relativity theory was almost exactly the same as the argument Niels Bohr had in the 1930s over quantum theory, except that in the 1930s, Einstein was on the conservative side. In the 1910s, people believed time was absolute; Einstein showed that was wrong. In the 1930s, Niels Bohr used probabilistic models to describe the subatomic world, while Einstein resisted, because he didn’t believe Nature was “playing dice”.
    Perhaps amusingly, one can draw some analogies in Machine Learning. Einstein’s quest to describe physics with a set of rules sounds like Machine Learners trying to build rule-based systems back in the 1980s. That effort failed, and probabilistic models have prevailed until today. The world is perhaps too complicated to be captured in a deterministic system, while, looking at it from a different angle, probability provides a neat mathematical framework to describe the uncertainties that Nature seems to carry. While it seems impossible to describe any complicated system deterministically, it is perhaps okay to describe it probabilistically, although that might not explain how the system was created in the first place.
  • During the 1930s, in a series of lonely, depressing attempts to unify field theories, Einstein sounds a lot like… Geoff Hinton who attempted to explain how the human brain works. Actually, those are perhaps not too far from each other. The brain is eventually the 3-pound universe of mankind, and completely understanding the brain is probably as hard as understanding the universe.
  • Being a theorist his whole life, Einstein’s approach to physics was quite remarkable. He never started from experimental results, but often drew insights at the abstract level, then proceeded with intuitive thought experiments, and then went on to rigorous mathematical frameworks. He would often end his papers with a series of experimental studies that could be used to confirm his theory. This top-down approach is very appealing and was widely adopted in physics for quite a long time.
    On the contrary, much research in Machine Learning is bottom-up. Even the master algorithm proposed in Domingos’ book is too bottom-up to be useful. Computer Science, after all, is an applied science in which empirical results are often over-emphasized. In particular, Machine Learning research is heavily based on experiments, and the theories that justify those experiments often come long after, if ever. To be fair, some work does come from rigorous mathematical inference, like LSTM, SELU and similar ideas, but many breakthroughs in the field are empirical, like convolutional nets, GANs and so on.
    Looking forward, drawing insights from Neuroscience is probably a promising way of designing Machine Learning systems in a top-down fashion. After all, the human brain is the only instance of general intelligence that we know of so far, and the distribution of such generally intelligent devices might be highly skewed and sparse, hence drawing insights from Neuroscience is perhaps our best hope.
  • The way Einstein became an international celebrity was quite bizarre. He officially became a celebrity after he paid visits to America for a series of fund-raising events for the Zionist cause. The world at the time was heavily divided after World War I, and the media was desperately looking for an international symbol to heal the wounds. Einstein, with his self-awareness, twinkling eyes and a good sense of humour, was more than ready to become one. The American media is surely good at this business, and the rest became history.
  • Einstein’s quest to understand the universe bears a lot of similarities to Machine Learners’ quest to build general AI systems. However, while computer scientists are meddling with our tiny, superficial simulations on computers, physicists are looking for ways to understand the universe. Putting our work alongside physicists’, we should probably feel humbled and perhaps a bit embarrassed.

It was amazing and refreshing to revisit Einstein’s scientific journey of about 100 years ago, and with a bit of creativity, one can draw many lessons that are still relevant to the research community today. Not only that, the book gives a well-informed picture of Einstein as a human, with all his flaws and weaknesses. Those flaws do not undermine his genius; on the contrary, they make us readers respect him even more. Einstein is, among other things, an exemplar of how much an introvert can contribute to humankind.

For those of us who happen to live in Berlin, any time you sit in Einstein Kaffee and sip a cup of delightful coffee, perhaps you could spare a thought for the man who lived a well-lived life, achieved so much and also missed so much (although the Kaffee itself has nothing to do with Einstein). Berlin, after all, is where Einstein spent 17 years of his life. It is where he produced the general theory of relativity – the most important work of his career – and it is the only city he considered home throughout his bohemian life.

[RL4a] Policy Optimization

I thought I would write about some theory behind Reinforcement Learning, with eligibility traces, contraction mapping, POMDP and so on, but then I realized if I go down that rabbit hole, I would probably never finish this. So here are some practical stuff that people are actually using these days.

Two main approaches in model-free RL are policy-based methods, where we build a model of the policy function, and value-based methods, where we learn the value function. For many practical problems, the structure of the (unknown) policy function might be easier to learn than the (unknown) value function, in which case it makes more sense to use policy-based methods. Moreover, with policy-based methods, we have an explicit policy, which is the ultimate goal of the learn-to-control problem in Reinforcement Learning (the other type being learn-to-predict, more on that some other time). With value-based methods, like DQN, we need an additional inference step to get the optimal policy.

The hybrid of policy-based and value-based is called Actor-Critic methods, which hopefully will be mentioned in another post.

One of the most straightforward approaches in policy-based RL is, unsurprisingly, evolutionary algorithms. In this approach, a population of policies is maintained and evolved over time. This has been shown to work pretty well, e.g. for the game of Tetris. However, due to its randomness, this approach is apparently only applicable to problems where the number of parameters of the policy function is small.

A big part of policy-based methods, however, is based on Policy Gradient, where an unbiased estimate of the gradient of the expected future reward can be computed. Roughly speaking, there is an exact formula for the gradient of the expected reward with respect to the policy parameters, which we can then estimate from samples and use to optimize the policy. Since Deep Learning people basically worship gradients (pun intended), this method suits them very well and became trendy about 2 years ago.

So, what is it all about? It is based on a very simple trick called the Likelihood Ratio.
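In its simplest form, the trick is the identity (using \nabla_\theta p_\theta = p_\theta \nabla_\theta \log p_\theta):

\displaystyle \nabla_\theta \mathbb{E}_{x \sim p_\theta}\left[f\left(x\right)\right] = \mathbb{E}_{x \sim p_\theta}\left[f\left(x\right)\nabla_\theta \log p_\theta\left(x\right)\right]

so the gradient of an expected score can be estimated from samples of the current distribution, without differentiating through f or through the sampling process itself. Applied to trajectories and rewards, this is what gives the policy gradient its sample-based form.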


Kalman filters (and how they relate to HMMs)

Kalman filters are insanely popular in many engineering fields, especially those involving sensors and motion tracking. Consider how to design a radar system to track military aircraft (or warships, submarines, … for that matter), how to track people or vehicles in a video stream, or how to predict the location of a vehicle carrying a GPS sensor… In all these cases, some (advanced) variation of the Kalman filter is probably what you need.

Learning and teaching Kalman filters is therefore quite challenging, not only because of the sheer complexity of the algorithms, but also because there are many variations of them.

Coming from a Computer Science background, I encountered Kalman filters several years ago, but never managed to put them into the global context of the field. I had a chance to look at them again recently, and rediscovered yet another way to present and explain Kalman filters. It made a lot of sense to me, and hopefully it does to you too.

Note that there are a lot of details missing from this post (if you are building a radar system to track military aircraft, look somewhere else!). I was just unhappy to see that much introductory material on Kalman filters is either too engineering-oriented or too simplified. I wanted something more Machine Learning-friendly, so this is my attempt.

Let’s say you want to track an enemy’s aircraft. All you have is a lame radar (bought from Russia, probably) which, when oriented properly, will give you a 3-tuple of range, angle and angular velocity [r \;\phi\;\dot{\phi}]^{T} of the aircraft. This vector is called the observation \mathbf{z}_k (subscript k because it depends on time). The actual position of the aircraft, though, is a vector in Cartesian coordinates \mathbf{x}_k = [x_1\;x_2\;x_3]^{T}. Since it is an enemy’s aircraft, you can only observe \mathbf{z}_k, and you want to track the state vector \mathbf{x}_k over time, every time you receive a measurement \mathbf{z}_k from the radar.

Visualised as a Bayesian network, it looks like this:

[Figure: a Bayesian network with hidden states \mathbf{x}_{k-1} \rightarrow \mathbf{x}_k and observations \mathbf{x}_k \rightarrow \mathbf{z}_k]

With all the Markov properties holding, i.e. \mathbf{x}_k only depends on \mathbf{x}_{k-1} and \mathbf{z}_k only depends on \mathbf{x}_k, does this look familiar?
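As the post title gives away, this is the structure of a Hidden Markov Model with continuous states – a state-space model. In the Kalman setting, the two conditional distributions are taken to be Gaussian; a standard textbook formulation (with F, h, Q, R as generic placeholders, not tied to the radar example) is:

\displaystyle \mathbf{x}_k = F\mathbf{x}_{k-1} + \mathbf{w}_k, \quad \mathbf{w}_k \sim \mathcal{N}\left(0, Q\right)

\displaystyle \mathbf{z}_k = h\left(\mathbf{x}_k\right) + \mathbf{v}_k, \quad \mathbf{v}_k \sim \mathcal{N}\left(0, R\right)

When h is linear, this is the classic Kalman filter; in the radar example h converts Cartesian coordinates into range and angles, so it is nonlinear and calls for variants like the extended Kalman filter.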


Variational Autoencoders 3: Training, Inference and comparison with other models

Variational Autoencoders 1: Overview
Variational Autoencoders 2: Maths
Variational Autoencoders 3: Training, Inference and comparison with other models

Recall that the backbone of VAEs is the following equation:

\log P\left(X\right) - \mathcal{D}\left[Q\left(z\vert X\right)\vert\vert P\left(z\vert X\right)\right] = E_{z\sim Q}\left[\log P\left(X\vert z\right)\right] - \mathcal{D}\left[Q\left(z\vert X\right) \vert\vert P\left(z\right)\right]

In order to use gradient descent for the right hand side, we need a tractable way to compute it:

  • The first part E_{z\sim Q}\left[\log P\left(X\vert z\right)\right] is tricky, because a good approximation of the expectation requires passing multiple samples of z through f, which is expensive. However, we can just take one sample of z, pass it through f and use the result as an estimate of E_{z\sim Q}\left[\log P\left(X\vert z\right)\right]. Eventually we are doing stochastic gradient descent over different samples X in the training set anyway.
  • The second part \mathcal{D}\left[Q\left(z\vert X\right) \vert\vert P\left(z\right)\right] is even trickier. By design, we fix P\left(z\right) to be the standard normal distribution \mathcal{N}\left(0,I\right) (read part 1 to know why). Therefore, we need a way to parameterize Q\left(z\vert X\right) so that the KL divergence is tractable.

Here comes perhaps the most important approximation of VAEs. Since P\left(z\right) is standard Gaussian, it is convenient to make Q\left(z\vert X\right) Gaussian as well. One popular way is to give Q mean \mu\left(X\right) and diagonal covariance \mathrm{diag}\left(\sigma^2\left(X\right)\right), i.e. Q\left(z\vert X\right) = \mathcal{N}\left(z;\mu\left(X\right), \mathrm{diag}\left(\sigma^2\left(X\right)\right)\right), where \mu\left(X\right) and \sigma\left(X\right) are two vectors computed by a neural network. This is the original formulation of VAEs in section 3 of this paper.

This parameterization is preferred because the KL divergence now becomes closed-form:

\displaystyle \mathcal{D}\left[\mathcal{N}\left(\mu\left(X\right), \mathrm{diag}\left(\sigma^2\left(X\right)\right)\right)\vert\vert P\left(z\right)\right] = \frac{1}{2}\left[\sigma\left(X\right)^T\sigma\left(X\right) +\mu\left(X\right)^T\mu\left(X\right) - k - \log \det \left(\mathrm{diag}\left(\sigma^2\left(X\right)\right)\right) \right]

where k is the dimensionality of z.

Although this looks like magic, it is quite natural if you apply the definition of KL divergence to two normal distributions. Doing so will teach you a bit of calculus.
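As a sanity check, the closed-form KL (taking \sigma\left(X\right) to hold the standard deviations, so the covariance is \mathrm{diag}\left(\sigma^2\left(X\right)\right)) can be compared against a brute-force Monte Carlo estimate. The numbers below are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_to_standard_normal(mu, sigma):
    """KL[N(mu, diag(sigma^2)) || N(0, I)] -- sigma holds standard deviations."""
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - np.log(sigma**2))

# Arbitrary example values for a 2-D latent.
mu = np.array([0.5, -1.0])
sigma = np.array([0.8, 1.2])
kl = kl_to_standard_normal(mu, sigma)

# Brute-force check: KL = E_{z~Q}[log Q(z) - log P(z)], estimated by sampling.
z = mu + sigma * rng.standard_normal((200_000, 2))
log_q = -0.5 * np.sum(((z - mu) / sigma) ** 2 + np.log(2 * np.pi * sigma**2), axis=1)
log_prior = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=1)
mc_kl = np.mean(log_q - log_prior)
```

The two numbers agree, and the KL is exactly zero when \mu = 0 and \sigma = 1, i.e. when Q matches the prior.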

So we have all the ingredients. We use a feedforward net to predict \mu\left(X\right) and \sigma\left(X\right) given an input sample X drawn from the training set. With those vectors, we can compute the KL divergence and \log P\left(X\vert z\right), which, in terms of optimization, translates into something similar to \Vert X - f\left(z\right)\Vert^2.

It is worth pausing here for a moment to see what we just did. Basically we used a constrained Gaussian (with diagonal covariance matrix) to parameterize Q. Moreover, by using \Vert X - f\left(z\right)\Vert^2 as one of the training criteria, we implicitly assume P\left(X\vert z\right) to be Gaussian as well. So although the maths that lead to VAEs are generic and beautiful, at the end of the day, to make things tractable, we ended up making those severe approximations. Whether they are good enough depends entirely on the practical application.

There is an important detail though. Once we have \mu\left(X\right) and \sigma\left(X\right) from the encoder, we need to sample z from a Gaussian distribution parameterized by those vectors. z is needed for the decoder to reconstruct \hat{X}, which is then optimized to be as close to X as possible via gradient descent. Unfortunately, the “sample” step is not differentiable, so we need a trick called reparameterization: instead of sampling z directly from \mathcal{N}\left(\mu\left(X\right), \mathrm{diag}\left(\sigma^2\left(X\right)\right)\right), we first sample z' from \mathcal{N}\left(0, I\right), and then compute z = \mu\left(X\right) + \sigma\left(X\right) \odot z', where \odot is the element-wise product. This makes the whole computation differentiable, and we can apply gradient descent as usual.
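A minimal numpy sketch of the reparameterization trick, with made-up encoder outputs \mu and \sigma for a 2-D latent: all randomness lives in the standard normal draw, so z is a deterministic (hence differentiable) function of \mu and \sigma.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up encoder outputs for a single input X (2-D latent).
mu = np.array([0.3, -0.7])      # mean vector mu(X)
sigma = np.array([0.5, 1.5])    # standard deviations sigma(X)

# Reparameterization: push all randomness into eps ~ N(0, I); the shift
# and scale are differentiable with respect to mu and sigma.
eps = rng.standard_normal((100_000, 2))
z = mu + sigma * eps

# Empirically, z is distributed as N(mu, diag(sigma^2)).
```

In a real implementation, gradients of the loss with respect to \mu\left(X\right) and \sigma\left(X\right) flow straight through this shift-and-scale, bypassing the sampling of eps.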

The cool thing is during inference, you won’t need the encoder to compute \mu\left(X\right) and \sigma\left(X\right) at all! Remember that during training, we try to pull Q to be close to P\left(z\right) (which is standard normal), so during inference, we can just inject \epsilon \sim \mathcal{N}\left(0, I\right) directly into the decoder and get a sample of X. This is how we can leverage the power of “generation” from VAEs.
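Generation is then just: sample from the prior, decode. A toy sketch, with a hypothetical fixed linear map standing in for the trained decoder f:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical decoder: a fixed linear map from a 2-D latent to a 3-D output.
# In a real VAE this would be the trained network f.
W = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.5, -0.5]])

def decode(z):
    return W @ z

# No encoder involved: draw epsilon from the standard normal prior and decode.
eps = rng.standard_normal(2)
x_generated = decode(eps)
```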

There are various extensions to VAEs, like Conditional VAEs and so on, but once you understand the basics, everything else is just nuts and bolts.

To sum up the series, this is the conceptual graph of VAEs during training, compared to some other models. Of course there are many details in those graphs that are left out, but you should get a rough idea about how they work.


In the case of VAEs, I added the additional cost term in blue to highlight it. The cost term for the other models, except GANs, is the usual L2 norm \Vert X - \hat{X}\Vert^2.

GSN is an extension of the Denoising Autoencoder with explicit hidden variables; however, it requires forming a fairly complicated Markov Chain. We may have another post for it.

With this diagram, hopefully you will see how lame GANs are. They are even simpler than the humble RBM. However, the simplicity of GANs makes them so powerful, while the complexity of VAEs makes them quite an effort just to understand. Moreover, VAEs make quite a few severe approximations, which might explain why samples generated from VAEs are far less realistic than those from GANs.

That’s quite enough for now. Next time we will switch to another topic I’ve been looking into recently.