PyMC 4.0 and variational inference

Is it possible to call pm.fit('advi') and specify that I want to use the Adam optimizer for faster inference? I assume that stochastic gradient descent is used by default. Thanks.

I suspect that you are looking for something like this:

pm.fit(n=1000, obj_optimizer=adam())

The various update methods can be found here.

Thanks.
I would like to explore my options with PyMC version 4. I searched for information on the variational API and found the following for PyMC3: Variational API quickstart — PyMC3 3.11.5 documentation

I have not found an updated version for PyMC version 4. I am wondering whether the examples demonstrating the use of callbacks will have changed. Thanks.

I tried your suggestion and got an error message stating that name 'adam' is not defined. Here are the details.

Code:

def run_and_plot(model, seed, nb_iter=10000):
    # with model:
    #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
        
    with model:
        vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, obj_optimizer=adam())
        
    trace5 = vi_fit2.sample(10000)
    pm.plot_trace(trace5);
    
    fig, ax = plt.subplots(figsize=(8, 6))
    plot_w = np.arange(K) + 1 
    ax.bar(plot_w - 0.5, trace5.posterior['w'].squeeze().mean(axis=0), width=1., lw=1, ec='w');
    ax.set_xlim(0.5, K);
    ax.set_xlabel('Component');
    ax.set_ylabel('Posterior expected mixture weight');
    plt.savefig("figure.png")
    
    mean_w = np.mean(trace5.posterior['w'].squeeze(), axis=0)
    nonzero_component = np.where(mean_w > 0.05)[0]

    mean_theta = np.mean(trace5.posterior['theta'].squeeze(), axis=0)
    print("mean_theta[nonzero_component]:\n", mean_theta[nonzero_component, :])
    print("theta_actual:\n", theta_actual)

run_and_plot(model, seed=3437, nb_iter=10000)

Error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [215], in <cell line: 1>()
----> 1 run_and_plot(model, seed=3437, nb_iter=10000)

Input In [214], in run_and_plot(model, seed, nb_iter)
      1 def run_and_plot(model, seed, nb_iter=10000):
      2     # with model:
      3     #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
      5     with model:
----> 6         vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, obj_optimizer=adam())
      8     trace5 = vi_fit2.sample(10000)
      9     pm.plot_trace(trace5);

NameError: name 'adam' is not defined

A working example here or in the documentation would be nice! Thank you!

I tried a callback in the simplest way possible and got an error related to the number of arguments expected by __call__. I am at a loss!

Code:

class Callback:
    def __call__(self):
        raise NotImplementedError
        
class Writeout(Callback):
    def __init__(self):
        pass
    
    def __call__(self):
        print("gordon")
        
callback = Writeout()

def run_and_plot(model, seed, nb_iter=10000):
    # with model:
    #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
        
    with model:
        # the `callbacks` argument must be iterable (not the case in pymc3)
        vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, callbacks=[callback])
                         # obj_optimizer=adam())
        
    trace5 = vi_fit2.sample(10000)
    pm.plot_trace(trace5);
    
    fig, ax = plt.subplots(figsize=(8, 6))
    plot_w = np.arange(K) + 1 
    ax.bar(plot_w - 0.5, trace5.posterior['w'].squeeze().mean(axis=0), width=1., lw=1, ec='w');
    ax.set_xlim(0.5, K);
    ax.set_xlabel('Component');
    ax.set_ylabel('Posterior expected mixture weight');
    plt.savefig("figure.png")
    
    mean_w = np.mean(trace5.posterior['w'].squeeze(), axis=0)
    nonzero_component = np.where(mean_w > 0.05)[0]

    mean_theta = np.mean(trace5.posterior['theta'].squeeze(), axis=0)
    print("mean_theta[nonzero_component]:\n", mean_theta[nonzero_component, :])
    print("theta_actual:\n", theta_actual)

run_and_plot(model, seed=3437, nb_iter=100)

And the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [251], in <cell line: 1>()
----> 1 run_and_plot(model, seed=3437, nb_iter=10000)

Input In [250], in run_and_plot(model, seed, nb_iter)
      1 def run_and_plot(model, seed, nb_iter=10000):
      2     # with model:
      3     #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
      5     with model:
----> 6         vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, callbacks=[callback])
      7                          # obj_optimizer=adam())
      9     trace5 = vi_fit2.sample(10000)

File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:765, in fit(n, local_rv, method, model, random_seed, start, inf_kwargs, **kwargs)
    763 else:
    764     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 765 return inference.fit(n, **kwargs)

File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:240, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    238                 progress.comment = f"Average Loss = {avg_loss:,.5g}"
    239         for callback in callbacks:
--> 240             callback(self.approx, scores[: i + 1], i + s + 1)
    241 except (KeyboardInterrupt, StopIteration) as e:  # pragma: no cover
    242     # do not print log on the same line
    243     scores = scores[:i]

TypeError: __call__() takes 1 positional argument but 4 were given

Apologies, my code snippet was written quickly; it should have been this:

pm.fit(n=1000, obj_optimizer=pm.adam())

The signature for the callback functions is this:

callbacks: list[function: (Approximation, losses, i) -> None]
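So a callback simply needs to accept those three positional arguments. As a minimal sketch (the class name and the printed format are my own, not part of the pymc API), the Writeout class from the earlier post would work if rewritten like this:

```python
class Writeout:
    """Minimal callback sketch matching pymc's expected call signature.

    pymc invokes each callback as callback(approx, losses, i), where
    `losses` is the loss history up to the current iteration `i`.
    """

    def __call__(self, approx, losses, i):
        # Print a short progress line; any side effect is allowed here.
        print(f"iteration {i}: last loss = {losses[-1]}")
```

Passing an instance via `pm.fit(..., callbacks=[Writeout()])` should then no longer raise the TypeError, since `__call__` now takes the three arguments pymc supplies.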

The source of the callbacks can be found here, but it’s likely you just want pm.callbacks.CheckParametersConvergence(). So something like this:

with pm.Model() as model:
    x = pm.Normal('x', mu=0, sigma=1)
    y = pm.Normal('y', mu=x, sigma=1)

    vi_fit2 = pm.fit(method='advi',
                     n=1000,
                     callbacks=[pm.callbacks.CheckParametersConvergence()],
                     obj_optimizer=pm.adam()
                    )
        
    trace = vi_fit2.sample(10000)

I got it to work. However, now I have to learn to use the output of the callbacks. Will report back later. :slight_smile:

Now that I have callbacks working, I would like to access their output. I cannot figure out the logic. Is there information out there explaining how it works (aside from reading the source code)? Thanks.

I cannot figure out how to access the data stored in the CheckParametersConvergence() callback. Any help would be appreciated! Thanks.

I think the quickstart illustrates how to interrogate the convergence.

I went through the Quickstart notebook in detail and tried things out before writing on discourse. Consider the following line:

with model:
    mean_field = pm.fit(method="advi", callbacks=[CheckParametersConvergence()])
plt.plot(mean_field.hist);

I tried this (with pymc 4). What I would like to do is update an array with my own diagnostics at a specified frequency. So the question is: how do I retrieve my array from the callback? I guess I must create my own callback? I might experiment with that.
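One way to do this (a sketch of my own, not an official pattern; the names `EveryN`, `every`, and `history` are invented) is a custom callback that appends to a list it owns, at whatever frequency you choose:

```python
class EveryN:
    """Hypothetical callback: record a diagnostic every `every` iterations."""

    def __init__(self, every=100):
        self.every = every
        self.history = []  # list of (iteration, last loss) pairs

    def __call__(self, approx, losses, i):
        # pymc calls this as callback(approx, losses, i) at each iteration.
        if i % self.every == 0:
            self.history.append((i, losses[-1]))
```

Since you construct the callback object yourself before passing it in via `callbacks=[...]`, its state survives the fit: after `pm.fit` returns, just read `cb.history` from the same instance.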

It would be nice to check results every n iterations. Is that done by running pm.sample at a lower frequency and simply restarting the simulation? I have not tried this yet. Thanks!

I created my own class, copying CheckParametersConvergence, and was able to retrieve the results I wanted. I will now experiment with Tracker.
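For reference, Tracker's mechanic is that you pass it named zero-argument callables, and it evaluates and stores each one on every iteration, making the histories available by name afterwards. A minimal re-implementation of that idea (my own sketch, not pymc's source) looks like:

```python
from collections import defaultdict


class MiniTracker:
    """Sketch of the pm.callbacks.Tracker mechanic: named zero-argument
    callables are evaluated and recorded on every call."""

    def __init__(self, **fns):
        self._fns = fns               # name -> zero-argument callable
        self.hist = defaultdict(list)  # name -> list of recorded values

    def __call__(self, approx, losses, i):
        # Invoked by the fit loop as callback(approx, losses, i).
        for name, fn in self._fns.items():
            self.hist[name].append(fn())

    def __getitem__(self, name):
        return self.hist[name]
```

In the variational quickstart, the real class is used along these lines: `tracker = pm.callbacks.Tracker(mean=advi.approx.mean.eval, std=advi.approx.std.eval)`, then `advi.fit(n, callbacks=[tracker])`, after which `tracker['mean']` holds the recorded values.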