The EM Algorithm is used when we would like to do maximum likelihood
(or MAP) estimation but our model has hidden variables, i.e.
variables that we cannot observe but that we believe are involved in
the generation of the data. For instance, we may believe that there is
a correlation between people having eczema and asthma, without
believing that either one causes the other.
We may instead believe that both are caused by the presence of
some allele in a person's DNA (by the way, I do not claim to know
anything about biology or genetics, so take this with a grain of salt).
The presence of this allele would be the latent variable, since we may
never observe it directly.
In this case, we have the visible variables, let's denote them as \(x\)
(the presence of eczema and asthma),
and the hidden variables, denoted by \(z\) (the allele). We would
like to maximise
the marginal likelihood over the visible variables i.e. we want to
solve the following
$$
\theta^* = \arg \max_{\theta} \log p(x|\theta) = \arg \max_{\theta}
\log \sum_{z} p(x, z|\theta).
$$
Due to the presence of the summation inside the log, this is actually a
very difficult problem to solve directly.
Instead, the usual approach is the aforementioned
EM, or Expectation Maximisation, algorithm.
Derivation
The EM algorithm can be derived in the following way. We want to
maximise the log likelihood of the visible data i.e.
$$
l(\theta) = \log \sum_{z} p(x, z|\theta),
$$
for each data point \(x\) in the data set. To do so, we use a very neat
(and, at first sight, completely unintuitive) trick: we introduce a
distribution over the hidden variables, \(q(z)\), and multiply and
divide by it inside the summation, as such
$$
\begin{align}
l(\theta) &= \log \left(\sum_{z} q(z) \frac{p(x, z|\theta)}{q(z)}\right)
\\
&\geq \sum_{z} q(z) \log\left(\frac{p(x, z|\theta)}{q(z)}\right).
\end{align}
$$
The inequality follows from Jensen's inequality, since the logarithm is concave.
Now there are two ways to view this equation. Since this will turn out
to be an iterative algorithm, we will call the parameters at iteration
\(t\), \(\theta^t\). At iteration \(t\) we then have that
$$
\begin{align}
l(\theta^t) &\geq \sum_{z} q(z)
\log\left(\frac{p(x, z|\theta^t)}{q(z)}\right)
\\
&= \sum_{z} q(z) \log \left(\frac{p(z|x, \theta^t)}{q(z)} \right)+ \log p(x|\theta^t)
\\
&= -\text{KL}\left(q(z)|| p(z|x, \theta^t) \right) + \log p(x|\theta^t).
\end{align}
$$
We know that the KL divergence is always non-negative and that it is zero
when \(q(z) = p(z|x, \theta^t)\). Therefore, we maximise this with
respect to \(q(z)\) (keeping \(\theta^t\) fixed) by setting
them equal. When this is the case, the lower bound on the
log-likelihood of the visible data is tight, i.e. we have equality,
because \(l(\theta^t) = \log p(x|\theta^t)\) by definition.
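As a quick sanity check of this decomposition, here is a tiny numerical example with made-up values of \(p(x, z)\) at a single fixed \(x\): the lower bound equals \(\log p(x)\) exactly when \(q(z)\) is the posterior, and is strictly smaller for any other \(q(z)\).

```python
import numpy as np

# Made-up values of p(x, z) for z = 1, 2, 3 at one fixed x (illustration only).
p_xz = np.array([0.10, 0.05, 0.25])
log_px = np.log(p_xz.sum())                  # log p(x) = log sum_z p(x, z)

def lower_bound(q):
    """sum_z q(z) log(p(x, z) / q(z))"""
    return np.sum(q * np.log(p_xz / q))

posterior = p_xz / p_xz.sum()                # p(z | x)
other_q = np.array([0.5, 0.3, 0.2])          # some other distribution over z

print(np.isclose(lower_bound(posterior), log_px))  # True: the bound is tight
print(lower_bound(other_q) < log_px)               # True: any other q is looser
```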
We can also view this in another way, however. Let's rewrite the above
(after setting \(q(z) = p(z|x, \theta^t)\)) as
$$
\begin{align}
l(\theta^t) &= \sum_{z} p(z | x, \theta^t) \log \left(p(x, z|\theta^t)
\right) - \sum_{z} p(z | x, \theta^t) \log \left(p(z | x, \theta^t)\right)
\\
&= \sum_{z} p(z | x, \theta^t) \log \left(p(x, z|\theta^t) \right) + H(p(z | x, \theta^t) )
\\
&= \mathbb{E}_{p(z | x, \theta^t)} \left[ \log \left( p(x, z| \theta^t) \right) \right] + H( p(z | x, \theta^t) ).
\end{align}
$$
Now, if we let the parameters \(\theta\) of \(p(x, z | \theta)\) be free,
but still keep \(q(z) = p(z | x, \theta^t)\), then we get the following
$$
\begin{align}
l(\theta) &\geq \sum_{z} p(z | x, \theta^t) \log \left(p(x, z|\theta) \right) + H(p(z | x, \theta^t) )
\\
&= \mathbb{E}_{p(z | x, \theta^t)} \left[ \log \left( p(x, z| \theta) \right) \right] + H( p(z | x, \theta^t) ).
\end{align}
$$
The inequality holds because, as we showed above, it is true for
any choice of \(q(z)\), including \(q(z) = p(z | x, \theta^t)\).
We can then maximise the right hand side with respect to \(\theta\)
(remember, we fix \(\theta^t\)). To do this, we don't really need to
take the entropy (\(H(p(z | x, \theta^t))\)) into account, since it is
not a function of \(\theta\). Let us define an auxiliary function
$$
Q(\theta, \theta^t) = \mathbb{E}_{p(z | x, \theta^t)} \left[ \log p(x, z|\theta) \right].
$$
We then maximise this with respect to \(\theta\) i.e.
$$
\theta^{t+1} = \arg \max_{\theta} Q(\theta, \theta^t)
$$
and then we repeat. So, basically, the EM algorithm iterates over two steps.
At iteration \(t\) we make the lower bound tight, which we do by setting
\(q(z) = p(z|x, \theta^t)\). This posterior is what we take the
expectation of the complete log likelihood with respect to, which is why
this is called the Expectation step. We then *maximise*
\(Q(\theta, \theta^t)\) with respect to \(\theta\), which is why this is
called the Maximisation step. We can prove that this procedure is
guaranteed never to decrease the log likelihood of the visible data, so
it keeps increasing until it converges. This is because
$$
\begin{align}
l(\theta^t) &= Q(\theta^t, \theta^t) + H(p(z | x, \theta^t))
\\
&\leq Q(\theta^{t+1}, \theta^t) + H(p(z | x, \theta^t))
\\
&=
-\text{KL}\left(p(z|x, \theta^t)|| p(z|x, \theta^{t+1}) \right)
+ \log p(x|\theta^{t+1})
\\
&\leq l(\theta^{t+1})
\end{align}
$$
which means that \(l(\theta^t)\leq l(\theta^{t+1})\) for all \(t\).
In other words, at each iteration of the EM algorithm the parameters
become better explanations of the data (if we are doing maximum
likelihood), or at least do not become worse.
Of course, if we have a data set with \(N\) data points
\(\mathcal{D} = \{x^{(n)}\}_{n=1}^N\), then we have that
$$
l(\theta) = \sum_{n=1}^N \log \sum_{z^{(n)}} p(x^{(n)}, z^{(n)}|\theta)
$$
and
$$
Q(\theta, \theta^t) = \sum_n \mathbb{E}_{p(z^{(n)} | x^{(n)}, \theta^t)}
\left[ \log p(x^{(n)}, z^{(n)}|\theta) \right],
$$
but everything else stays the same.
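To make the two-step structure concrete, here is a minimal sketch of the generic loop in Python. The `e_step`, `m_step` and `log_lik` callables are placeholders for whatever model is being fitted (concrete versions for a Mixture of Bernoulli are sketched later in this post), and the stopping rule is just one common choice.

```python
import numpy as np

def em(X, theta0, e_step, m_step, log_lik, max_iter=100, tol=1e-6):
    """Generic EM loop: alternate the E and M steps until the marginal
    log-likelihood of the visible data stops improving.

    e_step(X, theta)  -> q, the posterior p(z | x, theta^t) per data point
    m_step(X, q)      -> theta^{t+1} = argmax_theta Q(theta, theta^t)
    log_lik(X, theta) -> sum_n log p(x^(n) | theta)
    """
    theta, prev_ll = theta0, -np.inf
    for _ in range(max_iter):
        q = e_step(X, theta)       # E step: tighten the lower bound
        theta = m_step(X, q)       # M step: maximise Q(theta, theta^t)
        ll = log_lik(X, theta)     # guaranteed not to decrease
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    return theta
```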
The catch
The EM algorithm only guarantees that we will reach a local
optimum. This means that there may be better parameters with a higher
likelihood, but because the algorithm guarantees that the likelihood
never decreases, it may never reach them (to get there it would have
to temporarily decrease). For this reason, it is often useful to
run the algorithm several times starting from different initial
parameters and then keep the run that achieves the highest likelihood
(or use cross validation).
However, it works well in practice and it is used extensively in
machine learning and statistics.
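As a sketch of the restart strategy mentioned above (it reuses the generic `em` loop and the numpy import from the previous snippet; `init_fn` is a hypothetical initialisation routine that you would supply):

```python
def em_with_restarts(X, init_fn, e_step, m_step, log_lik, n_restarts=10, seed=0):
    """Run EM from several random initialisations and keep the run with
    the highest final log-likelihood."""
    rng = np.random.default_rng(seed)
    best_theta, best_ll = None, -np.inf
    for _ in range(n_restarts):
        theta = em(X, init_fn(rng), e_step, m_step, log_lik)
        ll = log_lik(X, theta)
        if ll > best_ll:
            best_theta, best_ll = theta, ll
    return best_theta
```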
Example: Mixture of Bernoulli
A lot of textbooks seem to stop at this point, or give the EM algorithm
for a Mixture of Gaussians, so I have decided to show a concrete example
with a Mixture of Bernoulli (MoB) distribution.
A MoB distribution is a model for multivariate binary data which takes
the form
$$
p(x, z) = p(x | z) p(z) = \prod_{k=1}^K
(p(z = k) p(x | z = k))^{\mathbb{I}(z = k)},
$$
where \(z \in \{1, 2, \dots, K\}\) is the latent variable, i.e. we
don't see it, and \(x \in \{0, 1\}^D\) is visible. The
generative process can be thought of as first picking a value for \(z\)
from a categorical distribution and then generating an \(x\) from
a multivariate Bernoulli distribution whose parameters depend on the
value of \(z\). Each multivariate Bernoulli distribution,
\(p(x | z = k)\), makes an independence assumption about the features
in \(x\), i.e. given the value of \(z\), we assume all the features of
\(x\) are independent, as such
$$
p(x|z=k) = \prod_{d=1}^D p(x_d = 1|z=k)^{x_d}
(1 - p(x_d = 1|z=k))^{1-x_d}.
$$
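As a small vectorised sketch of this component likelihood, assuming the data are stored in an \(N \times D\) binary array `X` and the probabilities \(p(x_d = 1 | z = k)\) in a \(K \times D\) array `mu` (the names and the clipping constant are my own choices):

```python
import numpy as np

def bernoulli_log_lik(X, mu):
    """log p(x | z = k) = sum_d [x_d log mu_{k,d} + (1 - x_d) log(1 - mu_{k,d})].

    X  : (N, D) binary data
    mu : (K, D) Bernoulli means
    Returns an (N, K) array of log-likelihoods.
    """
    eps = 1e-12                          # guard against log(0)
    mu = np.clip(mu, eps, 1 - eps)
    return X @ np.log(mu).T + (1 - X) @ np.log(1 - mu).T
```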
Now, this may seem like a very strong assumption, but the mixture can
still model interesting things. For example, suppose that we want to
model a multivariate binary variable of dimension
\(D = L^2 = 21^2 = 441\), representing an \(L\times L\) image of either
a square or a triangle of side length \(l = 15\), placed at a random
position within the grid. Some sample images can be seen below.
The pixels that are coloured dark purple are equal to 1 and the white
background pixels are equal to zero. The middle pixel is coloured in a
light purple. We can think of the generative process as first choosing
whether a triangle or a square will be generated, and then choosing
where to place the shape. The choice between the triangle and the square
can be encoded by \(z\), and once we have chosen the shape, we can
generate it by sampling from the associated Bernoullis.
Let \(z=1\) denote a square and \(z=2\) denote a triangle. If we pick a
square, we can see that the middle pixel is always going to be one
(since the square has side length 15 and the canvas has side length
only 21, the middle pixel is always covered by the square).
This means that, for the middle pixel \(d\), \(p(x_d = 1| z = 1) = 1\).
However, when \(z=2\) we can see from some of the examples below that
this is not the case, and we can also see that for any pixel \(d\) in
the top right corner \(p(x_d = 1 | z=2) = 0\) (those pixels are always
zero). Of course, it is clear that given \(z\) the dimensions are still
highly dependent, but I chose this example because it is simple and
easy to visualise.
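For completeness, here is one way such toy images could be generated; the uniform placement and the triangle orientation are my assumptions, so treat this purely as a sketch of the generative story, not as the exact code behind the figures:

```python
import numpy as np

def sample_shape_image(rng, L=21, l=15):
    """Sample (z, image): pick a shape uniformly (z = 1 square, z = 2 triangle),
    then place it at a random offset inside an L x L canvas.
    The triangle orientation here is an assumption."""
    img = np.zeros((L, L), dtype=int)
    r, c = rng.integers(0, L - l + 1, size=2)   # random top-left corner
    z = rng.integers(1, 3)                      # 1 or 2, equally likely
    if z == 1:                                  # square of side l
        img[r:r + l, c:c + l] = 1
    else:                                       # right triangle with legs of length l
        for i in range(l):
            img[r + i, c:c + i + 1] = 1
    return z, img

rng = np.random.default_rng(0)
z, img = sample_shape_image(rng)
```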
The update equations
So, now that we know what a MoB is, let's derive the equations for the
EM algorithm for it. First of all, let us define the following for
simplicity
$$
\begin{align}
\pi_k &= p(z = k),
\\
\mu_{k, d} &= p(x_d = 1 | z = k ).
\end{align}
$$
We know that the parameters to be estimated are
\(\theta = \{\pi_k, \mu_{k, d}\}_{k=1, \cdots, K, d = 1, \cdots, D}\).
So, starting with the E step, we have
$$
\begin{align}
p(z^{(n)} = k | x^{(n)}, \theta^t) &=
\frac{p(z^{(n)} = k, x^{(n)}| \theta^t)}
{\sum_j p(z^{(n)} = j, x^{(n)}| \theta^t)}
\\
&= \frac{\pi_k^t p(x^{(n)} | z^{(n)} = k, \theta^t)}
{\sum_j \pi_j^t p(x^{(n)} | z^{(n)} = j, \theta^t)}
\\
&= q^t(z^{(n)}=k).
\end{align}
$$
Note that everything in this equation can be computed once we have the
parameters. Note also that we have a different posterior
probability over the hidden variable for each visible data point. This
means that to calculate these quantities for all data points we
basically need \(\mathcal{O}(NK)\) such evaluations, since we need one
for every data point (\(N\)) and for every component (\(K\)). And that
is all the calculation that needs to be
done for the E step.
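In code, the E step might look like the following sketch (it reuses the `bernoulli_log_lik` helper from earlier, assumes numpy is imported as before, and works in log space for numerical stability):

```python
def e_step(X, pi, mu):
    """Responsibilities r[n, k] = q^t(z^(n) = k) = p(z = k | x^(n), theta^t).

    X : (N, D) binary data, pi : (K,) mixing weights, mu : (K, D) Bernoulli means.
    Returns an (N, K) array whose rows sum to one.
    """
    log_r = np.log(pi) + bernoulli_log_lik(X, mu)   # log pi_k + log p(x | z = k)
    log_r -= log_r.max(axis=1, keepdims=True)       # stabilise before exponentiating
    r = np.exp(log_r)
    return r / r.sum(axis=1, keepdims=True)         # normalise over components
```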
Now, let's derive the updates in the M step. Once we plug
\(q^t(z) =p(z | x, \theta^t)\)
into \(Q(\theta, \theta^t)\), we get
$$
\begin{align}
Q(\theta, \theta^t)
&= \sum_n \sum_k q^t(z^{(n)} = k)
\left( \log \pi_k + \sum_d \left[ x_d^{(n)}\log \mu_{k, d}
+ (1-x_d^{(n)})\log (1-\mu_{k, d}) \right] \right).
\end{align}
$$
Taking the derivative with respect to \(\mu_{k', d'}\) and setting
it to zero, we get
$$
\begin{align}
\frac{\partial Q(\theta, \theta^t)}{\partial \mu_{k', d'}}
&= \sum_n q^t(z^{(n)} = k')
\left( \frac{x_{d'}^{(n)} }{\mu_{k', d'}}
- \frac{(1-x_{d'}^{(n)})}{ (1-\mu_{k', d'})} \right) = 0.
\end{align}
$$
When we then solve for \(\mu_{k', d'}\), we get
$$
\mu_{k', d'} = \frac{\sum_n q^t(z^{(n)} = k') x_{d'}^{(n)}}
{\sum_n q^t(z^{(n)} = k')}.
$$
As we can see, we have to do this for all \(D\) dimensions and all \(K\)
components, and for each of these we are summing over \(N\)
data points; therefore, the total cost of updating all the means is
\(\mathcal{O}(NKD)\).
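A vectorised sketch of this update, with `r` the \(N \times K\) responsibility matrix returned by the E step (again assuming numpy is already imported):

```python
def m_step_mu(X, r):
    """mu[k, d] = sum_n r[n, k] x[n, d] / sum_n r[n, k]."""
    Nk = r.sum(axis=0)                 # effective counts N_k, shape (K,)
    return (r.T @ X) / Nk[:, None]     # (K, D)
```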
So, that is the update for the means of each of the Bernoulli variables.
Now, let's do the same for the priors on the components. In this case,
to get something sensible, we need to add a Lagrange multiplier
enforcing the constraint \(\sum_k \pi_k = 1\), i.e. we need to maximise
the following equation
$$
\begin{align}
\tilde{Q}(\theta, \theta^t)
&= \sum_n \sum_k q^t(z^{(n)} = k)
\left( \log \pi_k + \sum_d \left[ x_d^{(n)}\log \mu_{k, d}
+ (1-x_d^{(n)})\log (1-\mu_{k, d}) \right] \right)
+ \lambda \left( \sum_k \pi_k - 1\right).
\end{align}
$$
Once we take the derivative with respect to \(\pi_{k'}\) and set it to
zero, we get
$$
\begin{align}
\frac{\partial \tilde{Q}(\theta, \theta^t)}{\partial\pi_{k'}}
&= \sum_n \frac{q^t(z^{(n)} = k')}{\pi_{k'}} + \lambda = 0.
\end{align}
$$
Taking the derivative with respect to \(\lambda\) and setting it to
zero, we get
$$
\sum_k \pi_k = 1,
$$
and once we put these two equations together (solving the first gives
\(\pi_{k'} = -\frac{1}{\lambda}\sum_n q^t(z^{(n)} = k')\), and summing
this over \(k'\) while using the constraint gives \(\lambda = -N\)), we get
$$
\pi_{k'} = \frac{1}{N} \sum_n q^t(z^{(n)} = k').
$$
In this case, we don't actually need to do any more work than is done
in calculating the means, because the term \(\sum_n q^t(z^{(n)} = k')\)
is used as the denominator in the calculation of the means of the
Bernoullis, so it can be stored. And that is it! We have the update
equations for a MoB.
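Putting the two updates together, a sketch of the full M step is only a few lines; note that the effective counts \(N_k\) are computed once and reused, exactly as discussed above:

```python
def m_step(X, r):
    """Full M step for the mixture of Bernoullis."""
    N = X.shape[0]
    Nk = r.sum(axis=0)                 # N_k = sum_n q^t(z^(n) = k)
    pi = Nk / N                        # pi_k = N_k / N
    mu = (r.T @ X) / Nk[:, None]       # same denominator, stored once
    return pi, mu
```

With a thin wrapper that packs \(\theta = (\pi, \mu)\) into a single object, these functions, together with a marginal log-likelihood computed by a log-sum-exp over \(\log \pi_k + \log p(x | z = k)\), can be plugged into the generic `em` loop sketched earlier.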
It is worth looking at these equations a bit more. Firstly, let us look at
\(\sum_n q^t(z^{(n)} = k') \). This is basically the effective
number of data points that component \(k'\) accounts for, and it can
be useful to think of it as \(N_{k'}\). So, how can we interpret the
update of the prior, \(\pi_{k'}\)? Well, it is simply the effective
proportion of the total population of data points that the component
accounts for. This makes sense, right?
Likewise, the update of the mean of component \(k'\) for dimension
\(d'\) is a weighted sum of the data points, where each is weighted by
how likely it was to have been generated by component \(k'\), divided
by that same number, \(N_{k'}\). In other words, if many data points
have a value of 1 for their \(d'\)-th dimension, but none of them is
very likely to have been generated by component \(k'\), then
\(\mu_{k', d'}\) will be close to zero. Also, note that the maximum
value it can take is one, which it attains when every data point with
non-zero responsibility under component \(k'\) has \(x_{d'}^{(n)} = 1\).
Conclusion
So that's it! Here we have seen how the EM Algorithm works and how to
derive the necessary equations to implement it yourself. We also
derived an example for a Mixture of Bernoullis and showed that the
update equations do what we expect them to do (in hindsight, at least).