Apologies, my code snippet was written quickly and should have been this:
pm.fit(n=1000, obj_optimizer=pm.adam())
The signature for the callback functions is this:
callbacks: list[function: (Approximation, losses, i) -> None]
The source of the callbacks can be found here, but you most likely just want pm.callbacks.CheckParametersConvergence(), which stops the fit early once the parameters stop changing. So something like this:
with pm.Model() as model:
    x = pm.Normal('x', mu=0, sigma=1)
    y = pm.Normal('y', mu=x, sigma=1)
    # Fit with ADVI, stopping early if the parameters converge
    vi_fit2 = pm.fit(
        method='advi',
        n=1000,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        obj_optimizer=pm.adam(),
    )
    # Draw samples from the fitted approximation
    trace = vi_fit2.sample(10000)
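
If you'd rather write your own callback, any callable matching the (Approximation, losses, i) signature above should work. Here's a rough sketch of one that records the loss every so often. LossLogger and its `every` argument are names I made up for illustration, not part of PyMC, and I'm assuming `losses` holds the losses so far (it may be None if the fit isn't tracking a loss):

```python
class LossLogger:
    """Hypothetical callback: record the latest loss every `every` iterations."""

    def __init__(self, every=100):
        self.every = every
        self.history = []  # list of (iteration, loss) pairs

    def __call__(self, approx, losses, i):
        # Assumption: `losses` is a sequence of losses so far, or None/empty
        if losses is None or len(losses) == 0:
            return
        if i % self.every == 0:
            self.history.append((i, float(losses[-1])))
```

You would then pass an instance alongside the built-in one, e.g. callbacks=[pm.callbacks.CheckParametersConvergence(), LossLogger(every=100)], and inspect its history attribute after the fit.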