Variance of reparameterization trick and score function

For an expectation $\mathbf E_{z\sim q_\phi(z|x)}[f(z)]$ (assuming $f$ is continuous), where $q_\phi$ is a Gaussian distribution, there are two ways to compute its gradient with respect to $\phi$.

  1. compute the score function estimator:

    $$\begin{align}
    \nabla_\phi\mathbf E_{z\sim q_\phi(z|x)}[f(z)]&=\nabla_\phi \int f(z)\,q_\phi(z|x)\,dz\\
    &=\int f(z)\,\nabla_\phi q_\phi(z|x)\,dz\\
    &=\int {q_\phi(z|x)\over q_\phi(z|x)}f(z)\,\nabla_\phi q_\phi(z|x)\,dz\\
    &=\int q_\phi(z|x)\,f(z)\,\nabla_\phi \log q_\phi(z|x)\,dz\\
    &=\mathbf E_{z\sim q_\phi(z|x)}[f(z)\,\nabla_\phi\log q_\phi(z|x)]\tag{1}
    \end{align}$$

  2. use the reparameterization trick: let $z=\mu_\phi(x)+\epsilon\,\sigma_\phi(x)$, where $\epsilon\sim\mathcal N(0,1)$; differentiating the objective then gives

    $$\nabla_\phi\mathbf E_{z\sim q_\phi(z|x)}[f(z)]=\mathbf E_{\epsilon\sim\mathcal N(0,1)}[\nabla_\phi f(\mu_\phi(x)+\epsilon\,\sigma_\phi(x))]\tag{2}$$
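A minimal numerical sketch of Eq. $(1)$ and Eq. $(2)$, under assumptions that are not from the video: $q_\phi(z|x) = \mathcal N(\mu, \sigma^2)$ with $\phi = (\mu, \sigma)$ taken directly as the parameters, and the test function $f(z) = z^2$, for which $\mathbf E[f(z)] = \mu^2 + \sigma^2$ and the exact gradient is $(2\mu, 2\sigma)$. The function names are hypothetical; both averages should approach the exact gradient as the number of samples grows.

```python
import numpy as np

# Assumed test case (not from the video): f(z) = z**2, q = N(mu, sigma**2),
# so E[f(z)] = mu**2 + sigma**2 and the exact gradient is (2*mu, 2*sigma).
def f(z):
    return z ** 2

def df_dz(z):
    return 2 * z

def score_function_grad(mu, sigma, eps):
    """Eq. (1): sample mean of f(z) * grad_{(mu, sigma)} log q(z)."""
    z = mu + sigma * eps
    dlogq_dmu = (z - mu) / sigma ** 2
    dlogq_dsigma = ((z - mu) ** 2 - sigma ** 2) / sigma ** 3
    return np.array([np.mean(f(z) * dlogq_dmu), np.mean(f(z) * dlogq_dsigma)])

def reparam_grad(mu, sigma, eps):
    """Eq. (2): sample mean of grad_{(mu, sigma)} f(mu + eps * sigma), via the chain rule."""
    z = mu + sigma * eps
    # dz/dmu = 1 and dz/dsigma = eps
    return np.array([np.mean(df_dz(z)), np.mean(df_dz(z) * eps)])

rng = np.random.default_rng(0)
eps = rng.standard_normal(200_000)
mu, sigma = 1.0, 1.0
print("exact:      ", np.array([2 * mu, 2 * sigma]))
print("score (1):  ", score_function_grad(mu, sigma, eps))
print("reparam (2):", reparam_grad(mu, sigma, eps))
```

Reusing the same draws of $\epsilon$ for both estimators only makes the side-by-side comparison less noisy; it does not change either estimator's expectation.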

According to this video, at around 58 min the instructor explains that computing the gradient with the reparameterization trick generally has lower variance than the score function estimator. Here is my understanding of the instructor's explanation, though I'm not sure I have it right; please point out any misunderstanding 🙂

Eq. $(1)$ has high variance because $f(z)$ is computed from samples, whose variance is unbounded. Multiplying it by $\nabla_\phi \log q_\phi(z|x)$ therefore gives a gradient estimate with high variance. On the other hand, the coefficient in Eq. $(2)$ is fixed apart from $\epsilon$, which has variance $1$. As a result, Eq. $(1)$ has higher variance than Eq. $(2)$.
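As a concrete check of this reasoning, here is a worked calculation for one assumed case that is not taken from the video: $f(z)=z^2$, $q_\phi = \mathcal N(\mu,\sigma^2)$ with $\mu=\sigma=1$, looking only at the $\mu$ component of the gradient. Writing $z = 1+\epsilon$ and using $\mathbb E[\epsilon^2]=1$, $\mathbb E[\epsilon^4]=3$, $\mathbb E[\epsilon^6]=15$ (odd moments vanish), the single-sample terms are:

$$\begin{aligned}
\text{Eq. }(1):\quad & f(z)\,\frac{\partial}{\partial\mu}\log q(z) = (1+\epsilon)^2\,\epsilon = \epsilon + 2\epsilon^2 + \epsilon^3, \\
& \mathbb E[\cdot] = 2, \qquad \mathbb E[(\cdot)^2] = \mathbb E[\epsilon^2] + 6\,\mathbb E[\epsilon^4] + \mathbb E[\epsilon^6] = 1 + 18 + 15 = 34, \qquad \operatorname{Var} = 34 - 2^2 = 30, \\[6pt]
\text{Eq. }(2):\quad & \frac{\partial}{\partial\mu} f(\mu+\epsilon\sigma) = 2(1+\epsilon), \qquad \mathbb E[\cdot] = 2, \qquad \operatorname{Var} = 4.
\end{aligned}$$

Both terms are unbiased for $\partial_\mu(\mu^2+\sigma^2) = 2$, but in this example the score-function term has 7.5 times the variance. The extra variance enters through the higher powers of $\epsilon$ created by multiplying $f(z)$ by the score $(z-\mu)/\sigma^2$; this is one example, not a general proof.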

The notation here is far more complicated than it needs to be, and I suspect this is contributing to the issue of understanding this method. To clarify the problem, I'm going to re-frame this in standard notation. I'm also going to remove reference to $x$, because the entire analysis is conditional on this value, so it adds nothing to the problem beyond complicating the notation.

You have a problem with a Gaussian random variable $Z \sim \text{N}(\mu(\phi), \sigma(\phi)^2)$, where the mean and variance depend on a parameter $\phi$. You can also define the error term $\epsilon \equiv (Z - \mu(\phi))/\sigma(\phi)$, which measures the number of standard deviations from the mean. Now, you want to compute the gradient of the expected value:

$$\begin{equation} \begin{aligned} J(\phi) \equiv \mathbb{E}(r(Z)) &= \int\limits_\mathbb{R} r(z) \cdot \text{N}(z|\mu(\phi), \sigma(\phi)^2) \, dz \\[6pt] &= \int\limits_\mathbb{R} r(\mu(\phi) + \epsilon \cdot \sigma(\phi)) \cdot \text{N}(\epsilon|0,1) \, d\epsilon. \end{aligned} \end{equation}$$

(The equivalence of these two integral expressions is a consequence of the change-of-variable formula for integrals.) Differentiating these expressions gives the two equivalent forms:

$$\begin{equation} \begin{aligned} \nabla_\phi J(\phi) &= \int\limits_\mathbb{R} r(z) \bigg( \nabla_\phi \ln \text{N}(z|\mu(\phi), \sigma(\phi)^2) \bigg) \cdot \text{N}(z|\mu(\phi), \sigma(\phi)^2) \, dz \\[6pt] &= \int\limits_\mathbb{R} \bigg( \nabla_\phi r(\mu(\phi) + \epsilon \cdot \sigma(\phi)) \bigg) \cdot \text{N}(\epsilon|0,1) \, d\epsilon. \end{aligned} \end{equation}$$
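Two steps are used implicitly above. Assuming $\sigma(\phi) > 0$ and that $\mu$ and $\sigma$ are differentiable in $\phi$, the change of variables $z = \mu(\phi) + \epsilon\cdot\sigma(\phi)$ and the gradient of the Gaussian log-density can be written out as:

$$\begin{aligned}
dz &= \sigma(\phi)\, d\epsilon, \\[4pt]
\text{N}(z|\mu(\phi),\sigma(\phi)^2) &= \frac{1}{\sqrt{2\pi}\,\sigma(\phi)} \exp\!\Big(-\frac{(z-\mu(\phi))^2}{2\sigma(\phi)^2}\Big) = \frac{1}{\sigma(\phi)}\,\text{N}(\epsilon|0,1), \\[4pt]
\text{N}(z|\mu(\phi),\sigma(\phi)^2)\, dz &= \text{N}(\epsilon|0,1)\, d\epsilon, \\[4pt]
\nabla_\phi \ln \text{N}(z|\mu(\phi),\sigma(\phi)^2) &= \frac{z-\mu(\phi)}{\sigma(\phi)^2}\,\nabla_\phi\mu(\phi) + \frac{(z-\mu(\phi))^2 - \sigma(\phi)^2}{\sigma(\phi)^3}\,\nabla_\phi\sigma(\phi).
\end{aligned}$$

The first three lines justify integrating over $\epsilon$ in the second form; the last line is the factor that has to be evaluated at each sampled $z$ in the first form.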

Both of these expressions are valid expressions for the gradient of interest, and both can be approximated by finite averages over simulated values of the random variables they contain. To do this, we can generate $\epsilon_1,\ldots,\epsilon_M \sim \text{IID N}(0,1)$ and form the corresponding values $z_j = \mu(\phi) + \epsilon_j \cdot \sigma(\phi)$ for $j = 1,\ldots,M$. Then we can use one of the following estimators:

$$\begin{equation} \begin{aligned} \nabla_\phi J(\phi) \approx \hat{E}_1(\phi) &\equiv \frac{1}{M} \sum_{j=1}^M r(z_j) \bigg( \nabla_\phi \ln \text{N}(z_j|\mu(\phi), \sigma(\phi)^2) \bigg), \\[10pt] \nabla_\phi J(\phi) \approx \hat{E}_2(\phi) &\equiv \frac{1}{M} \sum_{j=1}^M \nabla_\phi r(\mu(\phi) + \epsilon_j \cdot \sigma(\phi)). \end{aligned} \end{equation}$$
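A minimal sketch of $\hat{E}_1$ and $\hat{E}_2$ under a hypothetical parameterization that is not in the question: scalar $\phi$, $\mu(\phi)=\phi$, $\sigma(\phi)=e^{\phi}$ and $r(z)=z^2$, so that $J(\phi)=\phi^2+e^{2\phi}$ and $\nabla_\phi J(\phi) = 2\phi + 2e^{2\phi}$. Each estimator is repeated over many independent batches of size $M$ so that the spread of the estimates, not just their mean, can be inspected.

```python
import numpy as np

# Hypothetical choices (not from the post): mu(phi) = phi, sigma(phi) = exp(phi), r(z) = z**2.
def mu(phi):
    return phi

def dmu(phi):
    return 1.0

def sigma(phi):
    return np.exp(phi)

def dsigma(phi):
    return np.exp(phi)

def r(z):
    return z ** 2

def dr(z):
    return 2 * z

def E1(phi, eps):
    """Score-function estimator: mean of r(z_j) * d/dphi ln N(z_j | mu, sigma^2)."""
    m, s = mu(phi), sigma(phi)
    z = m + eps * s
    # Gaussian score, chained through mu(phi) and sigma(phi)
    score = (z - m) / s ** 2 * dmu(phi) + ((z - m) ** 2 - s ** 2) / s ** 3 * dsigma(phi)
    return np.mean(r(z) * score)

def E2(phi, eps):
    """Reparameterization estimator: mean of d/dphi r(mu(phi) + eps_j * sigma(phi))."""
    z = mu(phi) + eps * sigma(phi)
    return np.mean(dr(z) * (dmu(phi) + eps * dsigma(phi)))

phi, M = 0.0, 100
exact = 2 * phi + 2 * np.exp(2 * phi)
rng = np.random.default_rng(1)
# 2000 independent batches, each of M standard normal draws
runs = np.array([[E1(phi, e), E2(phi, e)] for e in rng.standard_normal((2000, M))])
print("exact gradient:          ", exact)
print("mean of (E1, E2) over runs:", runs.mean(axis=0))
print("std  of (E1, E2) over runs:", runs.std(axis=0))
```

Both columns should average close to the exact gradient; the standard deviations show how noisy a single batch of size $M$ is for each estimator in this particular setup.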

The speaker asserts (but does not demonstrate) that the variance of the second estimator is lower than the variance of the first. He claims that this is because the second estimator uses direct information about the gradient of $r$, whereas the first estimator uses information about the gradient of the log-density of the normal distribution. Personally, without more knowledge of the nature of $r$, I find this an unsatisfying explanation, and I can see why you are confused by it. I doubt that this result holds for all functions $r$, but perhaps within the context of that field, the function $r$ tends to have a gradient that is fairly insensitive to changes in the argument value.
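To illustrate how the ranking can depend on $r$, here is a sketch with an assumed, deliberately oscillating test function (not from the video): $r(z)=\sin(kz)$, $\mu(\phi)=\phi$ and a fixed $\sigma=1$, with the gradient taken at $\phi=0$. The reparameterization term is $k\cos(kz)$, whose variance grows like $k^2/2$ for large $k$, while the score-function term $\sin(kz)\,(z-\mu)/\sigma^2$ is bounded in magnitude by $|z-\mu|/\sigma^2$ and so has variance at most $1$ here. The helper name is hypothetical.

```python
import numpy as np

# Assumed test case (not from the post): r(z) = sin(k z), mu(phi) = phi, sigma = 1 fixed,
# gradient w.r.t. phi at phi = 0. Per-sample terms of each estimator:
#   score function:     r(z) * (z - mu) / sigma**2
#   reparameterization: r'(z) * dmu/dphi = k * cos(k z)
def per_sample_terms(k, phi=0.0, sigma=1.0, n=200_000, seed=0):
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n)
    z = phi + sigma * eps
    score_terms = np.sin(k * z) * (z - phi) / sigma ** 2
    reparam_terms = k * np.cos(k * z)
    return score_terms, reparam_terms

for k in (0.5, 10.0):
    s, g = per_sample_terms(k)
    # Both means estimate the exact gradient k * exp(-k**2 / 2)
    print(f"k={k}: exact={k * np.exp(-k ** 2 / 2):.3f}, "
          f"mean score={s.mean():.3f}, mean reparam={g.mean():.3f}, "
          f"var score={s.var():.3f}, var reparam={g.var():.3f}")
```

With $k=0.5$ the reparameterization terms have the smaller variance (about $0.006$ versus about $0.3$); with $k=10$ the order reverses (about $50$ versus about $0.5$), which is consistent with the doubt above that the result holds for every $r$.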
