VI Drives me NUTS

Feel free to be a bit wary.

Vincent Warmerdam koaning.io
2018-11-09

Variational Inference is “hip” and I can’t say that I am a huge fan. I decided to give it a try and it immediately blew up in my face. In this document I hope to quickly demonstrate a potential failure scenario.

The Model

Here is the code for the model. It models the weight gain of chickens that are each given one of four diets.


import pandas as pd
import pymc3 as pm

df = pd.read_csv("http://koaning.io/theme/data/chickweight.csv",
                 skiprows=1,
                 names=["r", "weight", "time", "chick", "diet"])

# one-hot encode the diet column: one row per observation, one column per diet
dummy_rows = pd.get_dummies(df.diet).values

with pm.Model() as mod:
    intercept = pm.Normal("intercept", 0, 2)
    time_effect = pm.Normal("time_weight_effect", 0, 2, shape=(4,))
    diet = pm.Categorical("diet", p=[0.25, 0.25, 0.25, 0.25], shape=(4,),
                          observed=dummy_rows)
    sigma = pm.HalfNormal("sigma", 2)
    sigma_time_effect = pm.HalfNormal("time_sigma_effect", 2, shape=(4,))
    weight = pm.Normal("weight",
                       mu=intercept + time_effect.dot(diet.T) * df.time,
                       sd=sigma + sigma_time_effect.dot(diet.T) * df.time,
                       observed=df.weight)
    trace = pm.sample(5000, chains=1)

Next I’ll show how the traceplots differ across inference methods.

NUTS sampling results

I took 5500 samples with NUTS. It took about 7 seconds and this is the output:
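
The traceplot itself can be reproduced from the trace object defined above; a minimal sketch:

# visual check of the posterior samples per variable
pm.traceplot(trace)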

Metropolis sampling results

I took 20000 samples with Metropolis. It took about 14 seconds and this is the output:
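
The exact call isn’t shown in the post, but assuming the same model context, a Metropolis run looks roughly like this:

with mod:
    # swap the NUTS sampler for a Metropolis-Hastings step method
    step = pm.Metropolis()
    trace_mh = pm.sample(20000, step=step, chains=1)

# visual check of the Metropolis samples
pm.traceplot(trace_mh)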

VI results

I used the fullrank_advi setting. Here’s a traceplot of the samples I drew from the approximated posterior.
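
A sketch of what such a run looks like, reusing the model above (the iteration count of 20000 is an assumption; the post only names the fullrank_advi setting):

with mod:
    # fit a full-rank multivariate normal approximation to the posterior
    approx = pm.fit(20000, method="fullrank_advi")
    # draw samples from the fitted approximation so we can plot a trace
    trace_vi = approx.sample(5000)

pm.traceplot(trace_vi)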

Fix?

The interesting thing is that if I change the model slightly, VI suddenly has no issues (this was pointed out to me by a colleague, Mathijs).


n_diets = df.diet.nunique()

with pm.Model() as model:
    # global intercept: expected weight at time zero
    mu_intercept = pm.Normal('mu_intercept', mu=40, sd=5)
    # one positive slope per diet
    mu_slope = pm.HalfNormal('mu_slope', 10, shape=(n_diets,))
    # select the slope for each row by indexing with the diet column
    mu = mu_intercept + mu_slope[df.diet - 1] * df.time
    sigma_intercept = pm.HalfNormal('sigma_intercept', sd=2)
    sigma_slope = pm.HalfNormal('sigma_slope', sd=2, shape=(n_diets,))
    # the noise also grows over time, at a diet-specific rate
    sigma = sigma_intercept + sigma_slope[df.diet - 1] * df.time
    weight = pm.Normal('weight', mu=mu, sd=sigma, observed=df.weight)
    approx = pm.fit(20000, random_seed=42, method="fullrank_advi")

The main difference is that I am no longer using pm.Categorical: the diet now indexes directly into the slope vectors instead of entering the model as a one-hot encoded observed variable.

With that out of the way, the estimates suddenly look a whole lot better.
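
To check this yourself, you can draw from the fitted approximation and summarize it along these lines (the name trace_vi and the 5000 draws are illustrative):

# draw from the approximated posterior of the reparametrized model
trace_vi = approx.sample(5000)

pm.traceplot(trace_vi)
# numerical summary of the posterior estimates
pm.summary(trace_vi)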

Conclusion

Be careful when using variational inference. It might be faster, but it is only faster because it approximates. I’m not the only person who is a bit skeptical of variational inference.

The alternative, NUTS sampling, still amazes me, even though it isn’t perfect.