I tried a callback in the simplest way possible and got an error related to the number of argument expected in __call__. I am at a loss!
Code:
class Callback:
def __call__(self):
raise NotImplementedError
class Writeout(Callback):
def __init__(self):
pass
def __call__(self):
print("gordon")
callback = Writeout()
def run_and_plot(model, seed, nb_iter=10000):
# with model:
# vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
with model:
# the `callbacks` argument must be iterable (not the case in pymc3)
vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, callbacks=[callback])
# obj_optimizer=adam())
trace5 = vi_fit2.sample(10000)
pm.plot_trace(trace5);
fig, ax = plt.subplots(figsize=(8, 6))
plot_w = np.arange(K) + 1
ax.bar(plot_w - 0.5, trace5.posterior['w'].squeeze().mean(axis=0), width=1., lw=1, ec='w');
ax.set_xlim(0.5, K);
ax.set_xlabel('Component');
ax.set_ylabel('Posterior expected mixture weight');
plt.savefig("figure.png")
mean_w = np.mean(trace5.posterior['w'].squeeze(), axis=0)
nonzero_component = np.where(mean_w > 0.05)[0]
mean_theta = np.mean(trace5.posterior['theta'].squeeze(), axis=0)
print("mean_theta[nonzero_component]:\n", mean_theta[nonzero_component, :])
print("theta_actual:\n", theta_actual)
run_and_plot(model, seed=3437, nb_iter=100)
And the error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [251], in <cell line: 1>()
----> 1 run_and_plot(model, seed=3437, nb_iter=10000)
Input In [250], in run_and_plot(model, seed, nb_iter)
1 def run_and_plot(model, seed, nb_iter=10000):
2 # with model:
3 # vi_fit2 = pm.fit(method='svgd', n=nb_iter, random_seed=seed)
5 with model:
----> 6 vi_fit2 = pm.fit(method='advi', n=nb_iter, random_seed=seed, callbacks=[callback])
7 # obj_optimizer=adam())
9 trace5 = vi_fit2.sample(10000)
File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:765, in fit(n, local_rv, method, model, random_seed, start, inf_kwargs, **kwargs)
763 else:
764 raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 765 return inference.fit(n, **kwargs)
File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
142 progress = range(n)
143 if score:
--> 144 state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
145 else:
146 state = self._iterate_without_loss(0, n, step_func, progress, callbacks)
File ~/opt/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/variational/inference.py:240, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
238 progress.comment = f"Average Loss = {avg_loss:,.5g}"
239 for callback in callbacks:
--> 240 callback(self.approx, scores[: i + 1], i + s + 1)
241 except (KeyboardInterrupt, StopIteration) as e: # pragma: no cover
242 # do not print log on the same line
243 scores = scores[:i]
TypeError: __call__() takes 1 positional argument but 4 were given