Exclude certain free RVs from gradient calculation during ADVI?

Update: my last approach did the trick. Just in case someone else runs into this problem, here’s the idea:

  1. Extract the mu and rho of the posterior factors you want to save for later use (e.g. the global RVs) by inspecting approx.bij.ordering.vmap (or approx.gbij.ordering.vmap if on PyMC3 v3.1) and slicing approx.params[0] and approx.params[1] accordingly. Here, approx is an instance of a variational approximation, e.g. what you get from pm.fit(). (A sketch of this step follows after the optimizer code below.)

  2. To build a predictive model that uses only part of the learned posteriors and infers the rest (for example, a GMM that takes the Gaussian centroids from a previous analysis as given but learns the responsibilities of new data points from scratch), do the following:

    a) instantiate your model as before,
    b) instantiate your approximation (e.g. by calling advi = pm.ADVI(); approx = advi.approx in the model context).
    c) update the mu and rho of the global RVs in approx from the saved results,
    d) write a custom objective optimizer that leaves the mu and rho of the global RVs intact. Here is what a structured version of adamax would look like:

    # helper utilities used by PyMC3's own optimizers (pymc3.variational.updates)
    from collections import OrderedDict
    from functools import partial

    import numpy as np
    import theano as th
    import theano.tensor as tt

    import pymc3 as pm
    from pymc3.variational.updates import _get_call_kwargs, get_or_compute_grads


    def structured_adamax(loss_or_grads=None,
                          params=None,
                          update_indices=None,
                          learning_rate=0.002, beta1=0.9,
                          beta2=0.999, epsilon=1e-8):
        # deferred-call pattern used by PyMC3's optimizers: if called without
        # loss/params, return a partial that advi.fit() will call later
        if loss_or_grads is None and params is None:
            return partial(structured_adamax, **_get_call_kwargs(locals()))
        elif loss_or_grads is None or params is None:
            raise ValueError(
                'Please provide both `loss_or_grads` and `params` to get updates')
        assert update_indices is not None
        update_indices_tensor = th.shared(
            np.asarray(update_indices, dtype='int64'))

        all_grads = get_or_compute_grads(loss_or_grads, params)
        t_prev = th.shared(pm.theanof.floatX(0.))
        updates = OrderedDict()

        # number of entries (degrees of freedom) that actually get updated
        num_dof = len(update_indices)

        # Using theano constant to prevent upcasting of float32
        one = tt.constant(1)

        t = t_prev + 1
        a_t = learning_rate / (one - beta1**t)

        for param, g_t in zip(params, all_grads):
            # only the gradient entries selected by update_indices are used
            g_t_slice = g_t[update_indices_tensor]
            m_prev = th.shared(np.zeros((num_dof,), dtype=th.config.floatX),
                               broadcastable=(False,))
            u_prev = th.shared(np.zeros((num_dof,), dtype=th.config.floatX),
                               broadcastable=(False,))

            m_t = beta1 * m_prev + (one - beta1) * g_t_slice
            u_t = tt.maximum(beta2 * u_prev, abs(g_t_slice))
            step = a_t * m_t / (u_t + epsilon)
            # write the step back only into the selected entries; everything
            # else (e.g. the global RVs' mu and rho) stays untouched
            new_param = tt.inc_subtensor(param[update_indices_tensor], -step)

            updates[m_prev] = m_t
            updates[u_prev] = u_t
            updates[param] = new_param

        updates[t_prev] = t
        return updates
    

    Here, update_indices is a list of integers corresponding to the flat posterior parameter indices you want to update (i.e. everything except the global RVs). Finally, call advi.fit(obj_optimizer=structured_adamax(update_indices=...)).
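
For completeness, here is a rough end-to-end sketch of steps 1 and 2, assuming a mean-field approximation where approx.params[0] holds mu and approx.params[1] holds rho, and that the vmap entries expose .var (variable name) and .slc (flat slice) as in PyMC3's ArrayOrdering (on v3.1, use approx.gbij instead of approx.bij). The helper names get_posterior_slices and restore_and_freeze, as well as the RV name 'centroids', are just placeholders for illustration:

    import numpy as np

    def get_posterior_slices(approx, var_names):
        """Save the mu/rho segments of selected free RVs: {name: (mu, rho)}."""
        mu = approx.params[0].get_value()
        rho = approx.params[1].get_value()
        saved = {}
        for vm in approx.bij.ordering.vmap:  # one entry per free RV
            if vm.var in var_names:
                saved[vm.var] = (mu[vm.slc].copy(), rho[vm.slc].copy())
        return saved

    def restore_and_freeze(approx, saved):
        """Write saved mu/rho back and return the flat indices left to update."""
        mu = approx.params[0].get_value()
        rho = approx.params[1].get_value()
        frozen = np.zeros(mu.shape[0], dtype=bool)
        for vm in approx.bij.ordering.vmap:
            if vm.var in saved:
                mu[vm.slc], rho[vm.slc] = saved[vm.var]
                frozen[vm.slc] = True
        approx.params[0].set_value(mu)
        approx.params[1].set_value(rho)
        return np.flatnonzero(~frozen).tolist()

    # first analysis:
    # with model_1:
    #     approx_1 = pm.fit(method='advi')
    # saved = get_posterior_slices(approx_1, ['centroids'])  # global RV names

    # predictive model:
    # with model_2:
    #     advi = pm.ADVI()
    #     update_indices = restore_and_freeze(advi.approx, saved)
    #     advi.fit(obj_optimizer=structured_adamax(update_indices=update_indices))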

I would be happy to add this functionality properly to PyMC3 and make a PR if other people care for it. Again, there’s a chance that this feature already exists in PyMC3 (in particular with the latest OPVI refactorings) but I am unable to find it. @ferrine, could you advise?

Best,
Mehrtash