approx = pm.fit(n=10000, method='advi', model=model,
                obj_optimizer=pm.adagrad_window(learning_rate=2e-4),
                total_grad_norm_constraint=10,  # other values work here too; 10 is not special
                )
Do we have an option for passing that into pm.sample for the advi initialization?
And is there a way for the advi optimizer to ignore a couple of nans/infs near the start of the optimization?
Yes, a start point is supported with the start kwarg. NaNs can't be ignored, as they go into the updates and break everything. Infs can be ignored, since they don't turn the updates into NaN.
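To illustrate the start kwarg route, here is a minimal sketch reusing the pm.fit call from above: fit ADVI manually with the desired optimizer settings, then hand a draw from the fitted approximation to pm.sample as the start point (drawing a single point with approx.sample is just one way to get one).

approx = pm.fit(n=10000, method='advi', model=model,
                obj_optimizer=pm.adagrad_window(learning_rate=2e-4),
                total_grad_norm_constraint=10)
start = approx.sample(draws=1)[0]   # one point drawn from the fitted approximation
trace = pm.sample(2000, init=None, start=start, model=model)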
Sorry, my question wasn't very precise. What I meant was a way to pass parameters like the learning rate to pm.sample, something analogous to nuts_kwargs/step_kwargs.
And about the nans: wouldn't it be possible to always store the previous state before a step, and then, if we encounter a nan, go back one step and decrease the learning rate? I don't really know the literature on those optimizers well, so I hope I'm not asking something stupid here.
Is it difficult from an algorithmic point of view or because of the implementation?
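For illustration, the rollback idea could be sketched like this in plain NumPy. This is only a sketch of the scheme, not how the advi optimizers are actually implemented; grad and params are hypothetical stand-ins.

import numpy as np

def sgd_with_nan_rollback(grad, params, lr=1e-3, n_steps=1000, shrink=0.5):
    params = np.asarray(params, dtype=float)
    for _ in range(n_steps):
        prev = params.copy()                    # remember the last good state
        params = params - lr * grad(params)
        if not np.all(np.isfinite(params)):
            params = prev                       # roll back the bad step
            lr *= shrink                        # and retry with a smaller learning rate
    return params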
I think the “NaN occurred in optimization” errors are a major pain point for using NUTS right now. I've seen a couple of models that sampled well when previous chains were used to initialize the scaling, but that don't work with the advi initialization because of some stray nan near the beginning. As a workaround I've used something like this:
with model:
    # repeatedly run short chains and re-estimate the diagonal scaling from them
    stds = np.ones(model.ndim)
    for _ in range(5):
        args = {'scaling': stds ** 2, 'is_cov': True}
        trace = pm.sample(100, tune=100, init=None, nuts_kwargs=args)
        samples = [model.dict_to_array(p) for p in trace]
        stds = np.array(samples).std(axis=0)

    # run the final chains with the estimated scaling, each starting
    # from a different point of the last short run
    traces = []
    for i in range(2):
        step = pm.NUTS(scaling=stds ** 2, is_cov=True, target_accept=0.9)
        start = trace[-10 * i]
        trace_ = pm.sample(1000, tune=800, chain=i, init=None, step=step, start=start)
        traces.append(trace_)
    trace = pm.backends.base.merge_traces(traces)
I guess we should put that somewhere in NUTS itself; this is similar to what Stan does for initialization.
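As a rough sketch of what a reusable version of the workaround above could look like (the helper name and signature here are made up for illustration, not an existing PyMC3 API):

def init_nuts_from_short_runs(model, n_windows=5, window_draws=100, target_accept=0.9):
    """Estimate a diagonal scaling for NUTS from a few short adaptation runs."""
    with model:
        stds = np.ones(model.ndim)
        trace = None
        for _ in range(n_windows):
            trace = pm.sample(window_draws, tune=window_draws, init=None,
                              nuts_kwargs={'scaling': stds ** 2, 'is_cov': True})
            samples = [model.dict_to_array(p) for p in trace]
            stds = np.array(samples).std(axis=0)
        step = pm.NUTS(scaling=stds ** 2, is_cov=True, target_accept=target_accept)
    return step, trace[-1]

One would then call step, start = init_nuts_from_short_runs(model) and pass both into pm.sample with init=None.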