PyMC 4.0 and variational inference

Thanks.
I would like to explore my options with PyMC version 4. I searched for information on the variational API and found the "Variational API quickstart" page of the PyMC3 3.11.5 documentation.

I have not found an updated version of that page for PyMC version 4, and I am wondering whether the examples demonstrating the use of callbacks have changed. Thanks.

I tried your suggestion and got an error stating that the name 'adam' is not defined. Here are the details:

Code:

def run_and_plot(model, seed, nb_iter=10000):
    # with model:
    #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
        
    with model:
        vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, obj_optimizer=adam())
        
    trace5 = vi_fit2.sample(10000)
    pm.plot_trace(trace5);
    
    fig, ax = plt.subplots(figsize=(8, 6))
    plot_w = np.arange(K) + 1 
    ax.bar(plot_w - 0.5, trace5.posterior['w'].squeeze().mean(axis=0), width=1., lw=1, ec='w');
    ax.set_xlim(0.5, K);
    ax.set_xlabel('Component');
    ax.set_ylabel('Posterior expected mixture weight');
    plt.savefig("figure.png")
    
    mean_w = np.mean(trace5.posterior['w'].squeeze(), axis=0)
    nonzero_component = np.where(mean_w > 0.05)[0]

    mean_theta = np.mean(trace5.posterior['theta'].squeeze(), axis=0)
    print("mean_theta[nonzero_component]:\n", mean_theta[nonzero_component, :])
    print("theta_actual:\n", theta_actual)

run_and_plot(model, seed=3437, nb_iter=10000)

Error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [215], in <cell line: 1>()
----> 1 run_and_plot(model, seed=3437, nb_iter=10000)

Input In [214], in run_and_plot(model, seed, nb_iter)
      1 def run_and_plot(model, seed, nb_iter=10000):
      2     # with model:
      3     #     vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
      5     with model:
----> 6         vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, obj_optimizer=adam())
      8     trace5 = vi_fit2.sample(10000)
      9     pm.plot_trace(trace5);

NameError: name 'adam' is not defined

A working example here or in the documentation would be nice! Thank you!