PyMC 4.0 and variational inference

Apologies, my code snippet was written quickly and should have been this:

pm.fit(n=1000, obj_optimizer=pm.adam())

pm.fit calls each function in callbacks after every iteration step, and each one must have this signature:

callbacks: list[function: (Approximation, losses, i) -> None]
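As an illustration of that signature, here is a minimal sketch of a custom callback. StopOnLossPlateau and its parameters are hypothetical (not part of PyMC). It assumes, as PyMC's built-in convergence callbacks do, that raising StopIteration from a callback ends fitting early. Because ADVI losses are noisy, it compares averages over two windows rather than single loss values.

```python
class StopOnLossPlateau:
    """Hypothetical callback (not part of PyMC) with the
    (approx, losses, i) -> None signature expected by pm.fit."""

    def __init__(self, every=100, window=100, rel_tol=1e-3):
        self.every = every      # how often (in iterations) to run the check
        self.window = window    # number of recent losses to average
        self.rel_tol = rel_tol  # relative change treated as "no progress"
        self.history = []       # (iteration, mean recent loss) at each check

    def __call__(self, approx, losses, i):
        # approx is unused here; losses holds the loss values recorded so far
        if i % self.every != 0 or len(losses) < 2 * self.window:
            return
        recent = losses[-self.window:]
        previous = losses[-2 * self.window:-self.window]
        recent_mean = sum(recent) / len(recent)
        previous_mean = sum(previous) / len(previous)
        self.history.append((i, recent_mean))
        rel_change = abs(recent_mean - previous_mean) / max(abs(previous_mean), 1e-12)
        if rel_change < self.rel_tol:
            # Assumed stopping mechanism: pm.fit ends early on StopIteration
            raise StopIteration(f"Loss plateaued at iteration {i}")
```

You would pass an instance the same way as a built-in callback, e.g. pm.fit(n=10000, callbacks=[StopOnLossPlateau()]).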

The built-in callbacks live in pymc.variational.callbacks (source here), but you most likely just want pm.callbacks.CheckParametersConvergence(), which stops fitting early once the variational parameters stop changing. Something like this:

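import pymc as pm

with pm.Model() as model:
    x = pm.Normal('x', mu=0, sigma=1)
    y = pm.Normal('y', mu=x, sigma=1)

    # ADVI with the Adam optimizer; the callback may stop fitting before n iterations
    vi_fit2 = pm.fit(
        method='advi',
        n=1000,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        obj_optimizer=pm.adam(),
    )

    # Draw 10,000 samples from the fitted approximation
    trace = vi_fit2.sample(10000)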
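If you want to know what CheckParametersConvergence actually tests: every `every` iterations (default 100) it compares the flattened variational parameters with their values at the previous check and stops once the norm of the change falls below `tolerance` (defaults: relative change, infinity norm, tolerance 1e-3). The numpy sketch below mirrors that test as I read the source; relative_diff and has_converged are illustrative names, and the exact epsilon handling is an assumption.

```python
import numpy as np


def relative_diff(current, prev, eps=1e-6):
    # Elementwise relative change; eps guards against division by zero
    return (np.abs(current - prev) + eps) / (np.abs(prev) + eps)


def has_converged(current, prev, tolerance=1e-3, ord=np.inf):
    # With ord=np.inf the norm is the largest relative change of any parameter
    current = np.asarray(current, dtype=float)
    prev = np.asarray(prev, dtype=float)
    delta = relative_diff(current, prev)
    return bool(np.linalg.norm(delta, ord) < tolerance)
```

So with the defaults, fitting stops only when no single parameter moved by more than about 0.1% (relative) between two checks.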