Hi all,
I fitted a mean-field (ADVI) approximation using:

```python
with model:
    mean_field = pm.fit(50000, obj_optimizer=pm.adam(learning_rate=0.01),
                        callbacks=[pm.callbacks.CheckParametersConvergence(diff='absolute')])
```
Now I want to run NUTS using this approximation as the starting point:
```python
n_sample = 100
with model:
    step = pm.NUTS(target_accept=0.9)
    # 100 draws from the fitted approximation (returned as a MultiTrace)
    start = mean_field.sample(100)
    MM_trace = pm.sample(draws=n_sample,
                         step=step,
                         cores=4,
                         tune=1000,
                         start=start)
```
However, I got the following error:

Number of seeds and start_points must be 4.

Is there any way to tile the starting point so that there is one per chain?
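A minimal sketch of one possible workaround, assuming PyMC3 and assuming `pm.sample` accepts `start` as a list with one point dict per chain. The error seems to come from passing the whole `MultiTrace` of 100 draws (length 100) while 4 chains are run (with `cores=4` and no `chains` argument, PyMC3 appears to default to 4 chains). The helper `tile_start` is hypothetical and not part of PyMC3; it either copies a single point once per chain or keeps the first `chains` points from a list.

```python
import numpy as np


def tile_start(points, chains):
    """Return a list of exactly `chains` start-point dicts (hypothetical helper).

    `points` is either a single dict (variable name -> value), which is copied
    once per chain, or a sequence of dicts, of which the first `chains` are kept.
    """
    if isinstance(points, dict):
        # One point: give every chain its own copy of each value.
        return [{k: np.array(v, copy=True) for k, v in points.items()}
                for _ in range(chains)]
    points = list(points)
    if len(points) < chains:
        raise ValueError(f"need at least {chains} points, got {len(points)}")
    # Several points: one distinct draw per chain (shallow copies of the dicts).
    return [dict(p) for p in points[:chains]]
```

With PyMC3 this would presumably be used as `approx_trace = mean_field.sample(4)` followed by `start = tile_start([approx_trace.point(i) for i in range(4)], chains=4)`, then `pm.sample(..., chains=4, cores=4, start=start)`. Here `MultiTrace.point(i)` is assumed to return the i-th draw as a dict; I have not checked this against a specific PyMC3 version. Using distinct draws per chain rather than one repeated point keeps the chains from all starting at the same location.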
Thanks,
Mahdi