Part V: Efficient Natural-gradient Methods for Exponential Family
Warning: work in progress (incomplete)
Goal (edited: 01-Jan-23)
This blog post aims to show that we can efficiently implement natural-gradient methods in many cases.
We give an informal introduction with a focus on high-level ideas.
Exponential Family
An exponential family takes the following (canonical) form
$$
\begin{aligned}
p(\mathbf{w}|\mathbf{\eta}) = h_\eta(\mathbf{w}) \exp( \langle \mathbf{\eta} , \mathbf{T}_\eta (\mathbf{w}) \rangle - A_\eta (\mathbf{\eta}) )
\end{aligned}
$$
where $C_\eta(\eta) := \int h_\eta(\mathbf{w}) \exp( \langle \mathbf{\eta} , \mathbf{T}_\eta(\mathbf{w}) \rangle ) d \mathbf{w}$ is the normalization constant.
| | name |
|---|---|
| $A_\eta(\mathbf{\eta}):=\log C_\eta(\eta)$ | log-partition function |
| $h_\eta(\mathbf{w})$ | base measure |
| $\mathbf{T}_\eta(\mathbf{w})$ | sufficient statistics |
| $\eta$ | natural parameter |
The parameter space of $\eta$, denoted by $\Omega_\eta$, is determined so that the normalization constant is well-defined, finite, and strictly positive.
Regular natural parametrization $\eta$: the parameter space $\Omega_\eta$ is open.
In this post, we only consider regular natural parametrizations since commonly used exponential family distributions have a regular natural parametrization.
This natural parametrization is special since the inner product $\langle \mathbf{\eta} , \mathbf{T}_\eta(\mathbf{w}) \rangle$ is linear in $\eta$. As we will discuss later, this linearity is essential.
Examples of an exponential family include the Gaussian, Bernoulli, von Mises–Fisher, and more.
Readers should be aware of the following points when using an exponential family.
- The support of $\mathbf{w}$ should not depend on the parametrization $\eta$.
- The base measure and the log-partition function are only unique up to a constant, as illustrated by the following example.
  Univariate Gaussian as an exponential family (click to expand)
- The log-partition function can be differentiable w.r.t. $\eta$ [1] even when $\mathbf{w}$ is discrete.
  Bernoulli as an exponential family (click to expand)
- An invertible linear reparametrization could also be a natural parametrization. Thus, in this setting, natural-gradient descent could be better than (Euclidean) gradient descent since natural-gradient descent is linearly invariant.
  Natural parametrization is not unique (click to expand)
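As a quick sanity check of these definitions, the following minimal Python sketch (our own; it assumes the standard canonical form of the Bernoulli with $h_\eta(w)=1$, $\mathbf{T}_\eta(w)=w$, and $A_\eta(\eta)=\log(1+e^{\eta})$) verifies that the log-partition function equals the log of the normalization constant summed over the support $\{0,1\}$.

```python
import numpy as np

def log_partition(eta):
    # A(eta) = log(1 + exp(eta)) for the Bernoulli written in canonical form
    return np.log1p(np.exp(eta))

def unnormalized_mass(w, eta):
    # h(w) * exp(<eta, T(w)>) with h(w) = 1 and T(w) = w
    return np.exp(eta * w)

eta = 0.7  # any real number is a valid natural parameter
C = sum(unnormalized_mass(w, eta) for w in (0, 1))  # normalization constant
assert np.isclose(np.log(C), log_partition(eta))
print(np.log(C), log_partition(eta))
```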
Minimal Parametrizations of Exponential Family
Now, we discuss particular parametrizations of an exponential family. We can efficiently compute natural-gradients when using this class of parametrizations.
Minimal natural parametrization: the entries of the corresponding sufficient statistics $\mathbf{T}_\eta(\mathbf{w})$ are linearly independent.
A regular, minimal, and natural parametrization $\eta$ has many nice properties [1] [2].
- It is an intrinsic parametrization, as we discussed in Part I.
- The parameter space $\Omega_\eta$ is an open set in $\mathcal{R}^K$, where $K$ is the number of entries of this parameter array.
- The log-partition function $A_\eta(\eta)$ is infinitely differentiable and strictly convex in $\Omega_\eta$.
- The FIM $\mathbf{F}_\eta(\eta) = \nabla_\eta^2 A_\eta(\eta)$ is positive-definite in its domain.
We will only show the first property in this post. The remaining properties can be found in the literature.
Note that the linearity of the inner product $\langle \mathbf{\eta} , \mathbf{T}_\eta(\mathbf{w}) \rangle$ in $\eta$ plays a key role in showing these properties.
Claim:
A regular, minimal, and natural parametrization is intrinsic.
Proof of the claim (click to expand)
Now, we give an example of a regular, minimal and natural parametrization.
Minimal parametrization for multivariate Gaussian (click to expand)
The following example illustrates a non-minimal natural parametrization.
Non-minimal parametrization for Bernoulli (click to expand)
Efficient Natural-gradient Computation
In general, natural-gradient computation can be challenging due to the inverse of the Fisher matrix.
For an exponential family, natural-gradient computation can often be quite efficient without directly computing the inverse of the Fisher matrix.
We will assume $\eta$ is a regular, minimal, and natural parametrization.
Expectation Parametrization
We introduce a dual parametrization $\mathbf{m} := E_{p(w|\eta)}[ \mathbf{T}_\eta(\mathbf{w}) ] $, which is known as the expectation parametrization. This new parametrization plays a key role in efficient natural-gradient computation.
Recall that in Part IV, we used the identity that the expectation of the score function is zero. We can establish a connection between these two parametrizations via this identity.
$$
\begin{aligned}
\mathbf{0} & = E_{p(w|\eta)} [ \nabla_\eta \log p(\mathbf{w}|\eta) ] &\,\,\,( \text{expectation of the score is zero} ) \\
&=E_{p(w|\eta)} [ \mathbf{T}_\eta(\mathbf{w}) - \nabla_\eta A_\eta(\eta) ] \\
&= \mathbf{m} - \nabla_\eta A_\eta(\eta)
\end{aligned}\tag{1}\label{1}
$$
Rearranging gives $\mathbf{m} = \nabla_\eta A_\eta(\eta)$, which is a valid Legendre (dual) transformation[^1] since $\nabla_\eta^2 A_\eta(\eta)$ is positive-definite in its domain $\Omega_\eta$.
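As a small numerical illustration of this transformation (our own sketch; it assumes the canonical form of the univariate Gaussian with $h_\eta(w)=1/\sqrt{2\pi}$, $\mathbf{T}_\eta(w)=(w,w^2)$, $\eta=(\mu/\sigma^2,\,-\tfrac{1}{2\sigma^2})$, and $A_\eta(\eta)=-\frac{\eta_1^2}{4\eta_2}-\frac{1}{2}\log(-2\eta_2)$), the gradient of the log-partition function recovers $\mathbf{m}=(E[w],E[w^2])=(\mu,\mu^2+\sigma^2)$:

```python
import numpy as np

def A(eta1, eta2):
    # log-partition of the univariate Gaussian in canonical form
    # (base measure h(w) = 1/sqrt(2*pi), sufficient statistics T(w) = (w, w^2))
    return -eta1**2 / (4 * eta2) - 0.5 * np.log(-2 * eta2)

mu, sigma = 1.3, 0.8
eta1, eta2 = mu / sigma**2, -0.5 / sigma**2

eps = 1e-6  # central finite differences approximate grad_eta A
dA_deta1 = (A(eta1 + eps, eta2) - A(eta1 - eps, eta2)) / (2 * eps)
dA_deta2 = (A(eta1, eta2 + eps) - A(eta1, eta2 - eps)) / (2 * eps)

print(dA_deta1, mu)                 # E[w]   = mu
print(dA_deta2, mu**2 + sigma**2)   # E[w^2] = mu^2 + sigma^2
```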
Expectation Parameter Space
We can view the expectation parameter as the output of a continuous map
$$
\begin{aligned}
\mathbf{m}(\eta):=\nabla_\eta A_\eta(\eta),
\end{aligned}\tag{2}\label{2}
$$
where the input space of this map is $\Omega_\eta$.
We define the expectation parameter space as the output space of the map
$$
\begin{aligned}
\Omega_m :=\{\mathbf{m}(\eta) | \eta \in \Omega_\eta \}.
\end{aligned}
$$
Since $\nabla_\eta^2 A_\eta(\eta)$ is positive-definite in the open set $\Omega_\eta$, we can show that there exists a one-to-one relationship between the natural parameter $\eta$ and the expectation parameter $\mathbf{m}$, which implies that the map $\mathbf{m}(\cdot)$ is injective.
Since $\Omega_\eta$ is open in $\mathcal{R}^K$, we can show that the expectation parameter space $\Omega_m$ is also open in $\mathcal{R}^K$ due to the invariance of domain.
Natural-gradient Computation
Note that the FIM of the exponential family under parametrization $\eta$
is
$$
\begin{aligned}
\mathbf{F}_\eta(\eta)=\nabla_\eta^2 A_\eta(\eta) = \nabla_\eta \mathbf{m}
\end{aligned}\tag{3}\label{3}
$$
which means this FIM is the Jacobian matrix $\mathbf{J}=\nabla_\eta \mathbf{m}(\eta)$ of the map defined in $\eqref{2}$.
As we discussed in Part II, the natural-gradient of $f(\eta)$ can be computed as below, where $\mathbf{g}_\eta=\nabla_\eta f(\eta)$ is the Euclidean gradient w.r.t. the natural parameter $\eta$.
$$
\begin{aligned}
\hat{\mathbf{g}}_\eta & = \mathbf{F}^{-1}_\eta(\eta) \mathbf{g}_\eta \\
&= (\nabla_\eta \mathbf{m} )^{-1} [ \nabla_\eta f(\eta) ] \\
&= [\nabla_{m} \eta ] [ \nabla_\eta f(\eta) ] \\
&= \nabla_{m} f(\eta)
\end{aligned}\tag{4}\label{4}
$$
where $\nabla_{m} f( \eta )$ is the Euclidean gradient w.r.t. the expectation parameter $\mathbf{m}$ and $\eta=\eta( \mathbf{m} )$ can be viewed as a function of $\mathbf{m}$.
Therefore, we can efficiently compute natural-gradients w.r.t. the natural parameter $\eta$ [3] if we can easily compute Euclidean gradients w.r.t. its expectation parameter $\mathbf{m}$.
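For instance, in the Bernoulli case with expectation parameter $m=\sigma(\eta)$ (a sketch with our own toy loss $\ell(\eta)=(\sigma(\eta)-c)^2$ for a fixed target $c$), the natural-gradient $\mathbf{F}^{-1}_\eta(\eta)\,\nabla_\eta \ell$ and the Euclidean gradient $\nabla_m \ell$ coincide, exactly as $\eqref{4}$ predicts:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

c = 0.2            # a fixed target mean (an arbitrary choice for illustration)
eta = 1.5          # natural parameter of a Bernoulli
m = sigmoid(eta)   # expectation parameter m = E[T(w)] = sigmoid(eta)

# loss written in the natural parametrization: ell(eta) = (sigmoid(eta) - c)^2
grad_eta = 2 * (m - c) * m * (1 - m)    # Euclidean gradient w.r.t. eta (chain rule)
fisher = m * (1 - m)                    # F(eta) = A''(eta) = Var[w]
natural_grad = grad_eta / fisher        # F^{-1} * grad_eta

grad_m = 2 * (m - c)                    # Euclidean gradient w.r.t. m
print(natural_grad, grad_m)             # identical, as predicted by Eq. (4)
```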
Efficient NGD for multivariate Gaussians
Given a $d$-dimensional multivariate Gaussian with a full covariance structure, the naive way to compute natural-gradients has $O( (d^2)^3 )=O(d^6)$ iteration cost since the covariance matrix has $d^2$ entries. Now, we show how to efficiently compute natural-gradients in this case.
Claim:
We can efficiently compute natural-gradients with $O( d^3 )$ iteration cost in this case, using $O(d^2)$ parameters.
Proof of the claim (Click to expand)
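To make the claim concrete, here is a minimal numpy sketch of one such update (our own illustration, not the hidden proof). It assumes the natural parametrization $\eta^{(1)}=\Sigma^{-1}\mu$, $\eta^{(2)}=-\frac{1}{2}\Sigma^{-1}$ and the expectation parametrization $\mathbf{m}^{(1)}=\mu$, $\mathbf{m}^{(2)}=\Sigma+\mu\mu^T$, and that Euclidean gradients of the loss w.r.t. $(\mu,\Sigma)$ are available, in the spirit of [3]. The chain rule converts these into gradients w.r.t. $\mathbf{m}$, which by $\eqref{4}$ are exactly the natural-gradients w.r.t. $\eta$; the dominant cost is a single $d\times d$ inversion.

```python
import numpy as np

def gaussian_natural_gradient_step(eta1, eta2, grad_mu, grad_Sigma, alpha):
    """One natural-gradient step for a full-covariance Gaussian (a sketch).

    eta1 = Sigma^{-1} mu and eta2 = -0.5 * Sigma^{-1} are the natural parameters;
    grad_mu, grad_Sigma are Euclidean gradients of the loss w.r.t. (mu, Sigma),
    with grad_Sigma assumed symmetric. Cost is dominated by one d x d inverse, O(d^3).
    """
    # recover (mu, Sigma) from the natural parameters
    Sigma = np.linalg.inv(-2.0 * eta2)
    mu = Sigma @ eta1

    # chain rule: gradients w.r.t. the expectation parameters
    # m1 = mu and m2 = Sigma + mu mu^T
    grad_m1 = grad_mu - 2.0 * grad_Sigma @ mu
    grad_m2 = grad_Sigma

    # by Eq. (4), these are the natural-gradients w.r.t. (eta1, eta2)
    return eta1 - alpha * grad_m1, eta2 - alpha * grad_m2
```

Note that the step-size has to be small enough that the updated $\eta^{(2)}$ remains negative-definite (i.e., the implied covariance stays positive-definite); we do not address that issue here.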
Natural-gradient Descent as Unconstrained Mirror Descent
We assume natural parametrization $\eta$
is both regular and minimal.
In exponential family cases, we will show that natural-gradient descent can be viewed as a mirror descent update when the natural parameter space $\Omega_\eta$ is unconstrained.
Since mirror descent is defined by using a Bregman divergence, we first introduce the Bregman divergence.
Bregman Divergence
Given a strictly convex function $\Phi(\cdot)$
in its domain, a Bregman divergence equipped with $\Phi(\cdot)$
is defined as
$$
\begin{aligned}
\mathrm{B}_\Phi(\mathbf{x},\mathbf{y}):= \Phi(\mathbf{x})- \Phi(\mathbf{y}) - \langle \nabla \Phi(\mathbf{y}), (\mathbf{x}-\mathbf{y}) \rangle
\end{aligned}
$$
In particular, the Kullback–Leibler (KL) divergence is a Bregman divergence under natural parametrization $\eta$
:
$$
\begin{aligned}
\, & \,\mathrm{KL} [p(\mathbf{w}| \eta_1 ) || p(\mathbf{w}|{\color{red}\eta_2})]\\
=& \,E_{p(w|\eta_1)} [ \log \frac{p(\mathbf{w}|\eta_1)} {p(\mathbf{w}|\eta_2)} ] \\
=& \,E_{p(w|\eta_1)} [ \langle \eta_1-\eta_2 , \mathbf{T}_\eta (\mathbf{w}) \rangle - A_\eta(\eta_1) + A_\eta(\eta_2) ] & ( p(\mathbf{w}|\eta) \text{ is an exponential family}) \\
=& \,A_\eta(\eta_2) - A_\eta(\eta_1) - E_{p(w|\eta_1)} [ \langle \eta_2-\eta_1, \mathbf{T}_\eta (\mathbf{w}) \rangle ] \\
=& \,A_\eta(\eta_2) - A_\eta(\eta_1) - \langle \eta_2-\eta_1, \underbrace{ E_{p(w|\eta_1)} [ \mathbf{T}_\eta (\mathbf{w}) ] }_{ \nabla_\eta A_\eta(\eta_1) } \rangle \\
=& \, \mathrm{B}_{A_\eta}( {\color{red} \mathbf{\eta}_2}, \mathbf{\eta}_1 ) & ( A_\eta(\eta) \text{ is strictly convex})
\end{aligned}
$$
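The following small Python check (ours, for the Bernoulli with $A_\eta(\eta)=\log(1+e^{\eta})$) confirms the identity numerically: the KL divergence computed from its definition agrees with the Bregman divergence $\mathrm{B}_{A_\eta}(\eta_2,\eta_1)$.

```python
import numpy as np

def A(eta):                      # Bernoulli log-partition function
    return np.log1p(np.exp(eta))

def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))

def kl_bernoulli(p, q):          # KL[Bern(p) || Bern(q)] computed directly
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

eta1, eta2 = 0.3, -1.1
m1, m2 = sigmoid(eta1), sigmoid(eta2)

bregman = A(eta2) - A(eta1) - sigmoid(eta1) * (eta2 - eta1)  # B_A(eta_2, eta_1)
print(bregman, kl_bernoulli(m1, m2))                          # they agree
```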
We denote the expectation parameter as $\mathbf{m}$. Recall that by the Legendre transformation, we have $\mathbf{m}=\nabla_\eta A_\eta(\eta)$, where $\Omega_m$ has been defined here.
Note that we assume natural parameter $\eta$
is minimal. In other words, $\nabla_\eta^2 A_\eta(\eta)$
is
positive-definite and $A_\eta(\eta)$
is strictly convex in $\Omega_\eta$
.
We define $A_\eta(\mathbf{x}):=+\infty$ when $\mathbf{x} \not \in \Omega_\eta$.
The convex conjugate of the log-partition function $A_\eta$
is defined as
$$
\begin{aligned}
A^*_\eta( \mathbf{m}) &:= \sup_{x} \{ \langle \mathbf{x},\mathbf{m} \rangle - A_\eta(\mathbf{x}) \} \\
&= \langle \mathbf{\eta},\mathbf{m} \rangle - A_\eta(\mathbf{\eta}) \,\,\,\, (\text{the supremum attains at } \mathbf{x}=\eta \in \Omega_\eta )\\
\end{aligned}\tag{5}\label{5}
$$
where the domain of $A^*_\eta(\mathbf{m})$ is $\Omega_m$, and $\eta=\eta( \mathbf{m} )$ should be viewed as a function of $\mathbf{m}$.
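As a concrete (and standard) instance, consider the Bernoulli case with $A_\eta(\eta)=\log(1+e^{\eta})$ and $m=\sigma(\eta)$, so that $\eta=\log\frac{m}{1-m}$. The conjugate is then the negative entropy:
$$
\begin{aligned}
A^*_\eta( \mathbf{m}) = \langle \mathbf{\eta},\mathbf{m} \rangle - A_\eta(\mathbf{\eta}) = m \log\frac{m}{1-m} + \log(1-m) = m\log m + (1-m)\log(1-m),
\end{aligned}
$$
and indeed $\nabla_{\mathbf{m}} A^*_\eta(\mathbf{m}) = \log\frac{m}{1-m} = \eta$, which matches the identity derived next.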
When $\mathbf{m} \in \Omega_m$
, we have the following identity, which is indeed another Legendre transformation.
$$
\begin{aligned}
\nabla_{\mathbf{m}} A^*_\eta( \mathbf{m})
&= \mathbf{\eta} + \langle \nabla_{\mathbf{m}} \mathbf{\eta},\mathbf{m} \rangle - \nabla_{\mathbf{m}} A_\eta(\mathbf{\eta}) \\
&= \mathbf{\eta} + \langle \nabla_{\mathbf{m}} \mathbf{\eta},\mathbf{m} \rangle - [\nabla_{\mathbf{m}} \eta] \underbrace{ [\nabla_\eta A_\eta(\mathbf{\eta})] }_{ = \mathbf{m}}\\
&= \mathbf{\eta} ,
\end{aligned}\tag{6}\label{6}
$$
where $\eta \in \Omega_\eta$ due to $\eqref{5}$.
The convex conjugate $A^*_\eta( \mathbf{m})$
is strictly convex w.r.t. $\mathbf{m}$
since the Hessian $\nabla_m^2 A^*_\eta( \mathbf{m})$
is positive-definite as shown below.
$$
\begin{aligned}
\nabla_{\mathbf{m}}^2 A^*_\eta( \mathbf{m})
&= \nabla_{\mathbf{m}} \mathbf{\eta}
\end{aligned}
$$
Note that due to $\eqref{3}$, the FIM $\mathbf{F}_\eta(\eta)$ under natural parameter $\mathbf{\eta}$ is the Jacobian matrix $\mathbf{J}= \nabla_{\eta} \mathbf{m}$ and is positive-definite.
Therefore, it is easy to see that $\nabla_{\mathbf{m}}^2 A^*_\eta( \mathbf{m}) = \mathbf{F}^{-1}_\eta(\eta)$, which is positive-definite, and therefore $A^*_\eta(\mathbf{m})$ is indeed strictly convex.
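Continuing the Bernoulli illustration from above (our own example), where $A^*_\eta(m)=m\log m+(1-m)\log(1-m)$ and $m=\sigma(\eta)$, we have
$$
\begin{aligned}
\nabla_{\mathbf{m}}^2 A^*_\eta( \mathbf{m}) = \frac{1}{m(1-m)} = \frac{1}{\sigma(\eta)\,(1-\sigma(\eta))} = \mathbf{F}^{-1}_\eta(\eta),
\end{aligned}
$$
since $\mathbf{F}_\eta(\eta)=\nabla_\eta^2 A_\eta(\eta)=\sigma(\eta)(1-\sigma(\eta))$ for the Bernoulli.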
By the transformation rule of the FIM, we have the following relationship.
$$
\begin{aligned}
\mathbf{F}_{\eta} (\eta) & = \mathbf{J}^T \mathbf{F}_{m}(\mathbf{m}) \mathbf{J} \\
&= \mathbf{F}^T_{\eta} (\eta) \mathbf{F}_{m}(\mathbf{m})\mathbf{F}_{\eta} (\eta) \\
&= \mathbf{F}_{\eta} (\eta) \mathbf{F}_{m}(\mathbf{m})\mathbf{F}_{\eta} (\eta) & (\text{the FIM is symmetric})
\end{aligned}
$$
which implies that the FIM under expectation parameter $\mathbf{m}$ is $\mathbf{F}_m(\mathbf{m})=\mathbf{F}^{-1}_\eta(\eta) = \nabla_{\mathbf{m}}^2 A^*_\eta( \mathbf{m})= \nabla_{\mathbf{m}} \mathbf{\eta}$.
Moreover, we have the following identity since, by $\eqref{5}$, $A_\eta(\eta)=\langle \mathbf{\eta},\mathbf{m} \rangle- A^*_\eta( \mathbf{m})$.
$$
\begin{aligned}
\mathrm{B}_{A_\eta}(\mathbf{\eta}_2, {\color{red}\mathbf{\eta}_1 })
&= A_\eta(\eta_2) - A_\eta(\eta_1) - \langle \eta_2-\eta_1, \overbrace{ \nabla_\eta A_\eta(\eta_1) }^{= \mathbf{m}_1} \rangle \\
&= [ \langle \mathbf{\eta}_2,\mathbf{m}_2 \rangle- A^*_\eta( \mathbf{m}_2) ]
-[ \langle \mathbf{\eta}_1,\mathbf{m}_1 \rangle- A^*_\eta( \mathbf{m}_1) ]
-\langle \eta_2-\eta_1, \mathbf{m}_1 \rangle \\
&= A^*_\eta( \mathbf{m}_1) - A^*_\eta( \mathbf{m}_2) -
\langle \mathbf{m}_1-\mathbf{m}_2, \underbrace{ \eta_2}_{ = \nabla_{\mathbf{m}} A^*_\eta( \mathbf{m}_2)} \rangle\\
&= \mathrm{B}_{A^*_\eta}( {\color{red} \mathbf{m}_1 },\mathbf{m}_2) & (\text{the order is changed})
\end{aligned}
$$
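We can also verify this duality numerically in the Bernoulli case (a small sketch of ours, reusing $A_\eta(\eta)=\log(1+e^{\eta})$ and its conjugate $A^*_\eta(m)=m\log m+(1-m)\log(1-m)$):

```python
import numpy as np

def A(eta):                          # Bernoulli log-partition function
    return np.log1p(np.exp(eta))

def A_star(m):                       # its convex conjugate (negative entropy)
    return m * np.log(m) + (1 - m) * np.log(1 - m)

def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))

eta1, eta2 = 0.3, -1.1
m1, m2 = sigmoid(eta1), sigmoid(eta2)

lhs = A(eta2) - A(eta1) - m1 * (eta2 - eta1)        # B_A(eta_2, eta_1)
rhs = A_star(m1) - A_star(m2) - eta2 * (m1 - m2)    # B_{A*}(m_1, m_2)
print(lhs, rhs)                                      # identical up to rounding
```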
Mirror Descent
Now, we give the definition of mirror descent.
Consider the following optimization problem over a convex domain denoted by $\Omega_\theta$.
$$
\begin{aligned}
\min_{\theta \in \Omega_\theta} \ell_\theta(\mathbf{\theta})
\end{aligned}
$$
Given a strictly convex function $\Phi(\mathbf{\theta})$ in the domain $\Omega_\theta$, mirror descent with step-size $\alpha$ is defined as
$$
\begin{aligned}
\mathbf{\theta}_{k+1} \leftarrow \arg \min_{x \in \Omega_\theta}\{ \langle \nabla_\theta \ell_\theta(\mathbf{\theta}_k), \mathbf{x}-\mathbf{\theta}_k \rangle + \frac{1}{\alpha} \mathrm{B}_{\Phi}(\mathbf{x},\mathbf{\theta}_k) \}
\end{aligned}
$$
where $\mathrm{B}_\Phi(\cdot,\cdot)$
is a Bregman divergence equipped with the strictly convex function $\Phi(\cdot)$
.
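As a familiar special case, if we take $\Phi(\mathbf{\theta})=\frac{1}{2}\|\mathbf{\theta}\|^2$ and let the domain be unconstrained, the Bregman divergence becomes $\mathrm{B}_\Phi(\mathbf{x},\mathbf{\theta}_k)=\frac{1}{2}\|\mathbf{x}-\mathbf{\theta}_k\|^2$ and mirror descent reduces to standard (Euclidean) gradient descent:
$$
\begin{aligned}
\mathbf{\theta}_{k+1} = \arg \min_{\mathbf{x}}\{ \langle \nabla_\theta \ell_\theta(\mathbf{\theta}_k), \mathbf{x}-\mathbf{\theta}_k \rangle + \frac{1}{2\alpha} \|\mathbf{x}-\mathbf{\theta}_k\|^2 \} = \mathbf{\theta}_k - \alpha \nabla_\theta \ell_\theta(\mathbf{\theta}_k).
\end{aligned}
$$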
Natural-gradient Descent as Mirror Descent
To show natural-gradient descent as a mirror descent update, we have to make the following assumption.
Additional assumption:
The natural parameter space $\Omega_\eta$ is unconstrained ($\Omega_\eta=\mathcal{R}^K$), where $K$ is the number of entries of the parameter array $\eta$.
The following example illustrates that the expectation space $\Omega_m$
is constrained even when
$\Omega_\eta$
is unconstrained.
Example: $\Omega_m$ is constrained while $\Omega_\eta$ is unconstrained (click to expand)
Now, consider the following mirror descent in the expectation parameter space $\Omega_m$, where $\alpha>0$.
$$
\begin{aligned}
\mathbf{m}_{k+1} \leftarrow \arg \min_{x \in \Omega_m}\{ \langle \nabla_m \ell_m(\mathbf{m}_k), \mathbf{x}-\mathbf{m}_k \rangle + \frac{1}{\alpha} \mathrm{B}_{A^*_\eta}(\mathbf{x},\mathbf{m}_k) \}
\end{aligned}\tag{7}\label{7}
$$
where $\mathbf{m}_{k} \in \Omega_m$
, $\nabla_m \ell_m(\mathbf{m}_k):= \nabla_m \ell_\eta (\eta(\mathbf{m}_k))$
and the Bregman divergence $\mathrm{B}_{A^*_\eta}(\cdot,\cdot)$
is well-defined
since $A^*_\eta$
is strictly convex in $\Omega_m$
.
Recall that $\Omega_m$
can still be constrained.
Claim:
When $\Omega_\eta = \mathcal{R}^K$
, the solution of $\eqref{7}$
is equivalent to $\eta_{k+1} \leftarrow \eta_k - \alpha \nabla_m \ell_m(\mathbf{m}_k)$
.
Proof of the claim (click to expand)
By the claim,
mirror descent of $\eqref{7}$
in expectation parameter space $\Omega_m$
is equivalent to
the following update
$$
\begin{aligned}
\eta_{k+1} \leftarrow \eta_k - \alpha \nabla_m \ell_m(\mathbf{m}_k)
= \eta_k - \alpha\nabla_m \ell_\eta( \underbrace{ \eta(\mathbf{m}_k) }_{= \eta_k}),
\end{aligned}\tag{8}\label{8}
$$
which is exactly natural-gradient descent in the natural parameter space $\Omega_\eta=\mathcal{R}^K$ since, by $\eqref{4}$, we have $\nabla_m \ell_m(\mathbf{m}_k) = \nabla_m \ell_\eta( \eta_k)= \mathbf{F}_\eta^{-1} (\eta_k) \nabla_\eta \ell_\eta(\eta_k)$.
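As a final numerical check (our own sketch for the Bernoulli with the toy loss $\ell_m(m)=(m-c)^2$ and conjugate $A^*_\eta(m)=m\log m+(1-m)\log(1-m)$), solving the mirror descent step $\eqref{7}$ by brute force over $\Omega_m=(0,1)$ produces the same iterate as the natural-gradient step $\eqref{8}$ performed in $\Omega_\eta=\mathcal{R}$:

```python
import numpy as np

def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))

def A_star(m):          # convex conjugate of the Bernoulli log-partition function
    return m * np.log(m) + (1 - m) * np.log(1 - m)

c, alpha = 0.2, 0.5     # target mean and step-size (arbitrary choices)
eta_k = 1.5
m_k = sigmoid(eta_k)
grad_m = 2 * (m_k - c)  # gradient of ell_m(m) = (m - c)^2 at m_k

# mirror descent step (7) in the constrained space Omega_m = (0, 1), by brute force
grid = np.linspace(1e-4, 1 - 1e-4, 200001)
bregman = A_star(grid) - A_star(m_k) - np.log(m_k / (1 - m_k)) * (grid - m_k)
m_next = grid[np.argmin(grad_m * (grid - m_k) + bregman / alpha)]

# natural-gradient step (8) in the unconstrained space Omega_eta = R
eta_next = eta_k - alpha * grad_m

# the two agree up to the grid resolution
print(np.log(m_next / (1 - m_next)), eta_next)
```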
References
[1] S. Johansen, Introduction to the theory of regular exponential families (Institute of Mathematical Statistics, University of Copenhagen, 1979).
[2] M. J. Wainwright & M. I. Jordan, Graphical models, exponential families, and variational inference (Now Publishers Inc, 2008).
[3] M. Khan & W. Lin, "Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models," Artificial Intelligence and Statistics (PMLR, 2017), pp. 878–887.
Footnotes:

[^1]: When the natural parameter $\eta$ is minimal, this Legendre transformation is diffeomorphic since $\nabla_\eta^2 A_\eta(\eta)$ is positive-definite in its domain. In other words, the expectation parameter $\mathbf{m}$ is also intrinsic when the natural parameter $\eta$ is minimal.