Feeding pm.fit results to the NUTS sampler

Hi all,
I obtained a mean-field approximation with:

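with model:
    # mean-field ADVI (pm.fit's default method), stopped early once parameters stop changing
    mean_field = pm.fit(50000, obj_optimizer=pm.adam(learning_rate=0.01),
                        callbacks=[pm.callbacks.CheckParametersConvergence(diff='absolute')])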
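For readers unfamiliar with the callback: as we understand it, `CheckParametersConvergence` compares snapshots of the flattened variational parameters every few iterations and stops the fit once the change, measured with a norm, falls below a tolerance. `diff='absolute'` uses the raw difference, while `'relative'` scales it by the parameter values. The sketch below only illustrates that idea with NumPy. It is not PyMC's implementation, and the helper name `params_converged` and its defaults are assumptions.

```python
import numpy as np


def params_converged(prev, curr, tolerance=1e-3, diff="absolute"):
    """Hypothetical helper: is the largest parameter change below `tolerance`?

    `prev` and `curr` are two snapshots of the flattened variational
    parameters (e.g. means and log-scales) taken some iterations apart.
    """
    prev = np.asarray(prev, dtype=float)
    curr = np.asarray(curr, dtype=float)
    delta = curr - prev
    if diff == "relative":
        # Scale each change by the current value (assumes no parameter is exactly 0).
        delta = delta / np.abs(curr)
    elif diff != "absolute":
        raise ValueError("diff must be 'absolute' or 'relative'")
    # Infinity norm: the single largest absolute change decides convergence.
    return bool(np.linalg.norm(delta, ord=np.inf) < tolerance)


before = [0.500, -1.2000, 3.1000]
after = [0.5004, -1.2002, 3.1001]
print(params_converged(before, after))             # True: largest change is 4e-4
print(params_converged(before, [0.6, -1.2, 3.1]))  # False: a change of 0.1 is too large
```

With `diff='absolute'`, the same tolerance applies to every parameter regardless of its scale. This makes sense when all parameters live on comparable scales.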

Now I want to run NUTS sampling with this mean-field approximation as the starting point.

n_sample = 100
with model:
    step = pm.NUTS(target_accept=0.9)
    start = mean_field.sample(100)
    MM_trace = pm.sample(draws=n_sample,
                         step=step,
                         cores=4,
                         tune=1000,
                         start=start)

However, I got the following error. Is there a way to tile the starting point to match the number of chains?

Number of seeds and start_points must be 4.
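The error says that `pm.sample` expected as many start points as chains (4 here). Presumably it fails because `mean_field.sample(100)` returns a trace of 100 draws, which gets counted as 100 start points. The sketch below shows the "one start point per chain" rule in plain Python. The helper `starts_per_chain` and its error messages are hypothetical and are not PyMC's code.

```python
from typing import Dict, List, Sequence, Union

import numpy as np

Point = Dict[str, np.ndarray]


def starts_per_chain(start: Union[Point, Sequence[Point]], chains: int) -> List[Point]:
    """Hypothetical helper: return exactly one start dict per chain."""
    if isinstance(start, dict):
        # A single point is tiled; shallow copies keep the chains' dicts independent.
        return [dict(start) for _ in range(chains)]
    starts = list(start)
    if not all(isinstance(s, dict) for s in starts):
        raise TypeError("start must be a dict or a sequence of dicts")
    if len(starts) != chains:
        raise ValueError(f"Number of start points must be {chains}, got {len(starts)}")
    return starts


point = {"mu": np.zeros(3), "sigma": np.array(1.0)}
print(len(starts_per_chain(point, chains=4)))  # 4
try:
    starts_per_chain([point] * 100, chains=4)  # like passing 100 draws as start points
except ValueError as err:
    print(err)  # Number of start points must be 4, got 100
```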

Thanks,
Mahdi

You probably need to pass the start dict once for each chain, so try:
pm.sample(..., start=[start, start, start, start])

I tried this and got the following error:

TypeError: start argument must be a dict or an array-like of dicts
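The message asks for a dict or an array-like of dicts. A list of traces is neither, because each element is a whole trace rather than a single {varname: value} point. If a *different* starting point per chain is wanted, one option is to pick one draw of the approximation for each chain. The sketch below assumes the draws have already been pulled into a dict of NumPy arrays with draws on the first axis (e.g. `{v: tr[v] for v in tr.varnames}`). The helper `draws_to_chain_starts` is hypothetical, and we have not verified it against a specific PyMC version.

```python
import numpy as np


def draws_to_chain_starts(samples, chains, rng=None):
    """Hypothetical helper: build one start dict per chain from stacked draws.

    `samples` maps each variable name to an array whose first axis indexes draws.
    """
    rng = np.random.default_rng(rng)
    n_draws = len(next(iter(samples.values())))
    if n_draws < chains:
        raise ValueError("need at least one draw per chain")
    # Pick distinct draw indices so that the chains start from different points.
    idx = rng.choice(n_draws, size=chains, replace=False)
    return [{name: np.asarray(vals)[i] for name, vals in samples.items()} for i in idx]


rng = np.random.default_rng(0)
samples = {"mu": rng.normal(size=(100, 3)), "log_sigma": np.zeros(100)}
starts = draws_to_chain_starts(samples, chains=4, rng=1)
print(len(starts), starts[0]["mu"].shape)  # 4 (3,)
```

The resulting list holds plain dicts, which matches the "array-like of dicts" form the error message names. Spreading chains over several draws can also make convergence diagnostics across chains more informative than starting them all at one point.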

Sorry, I misread your original post. You need to supply a single point as start, i.e. a dict of the form {varname: value}.

Something like this:

tr = mean_field.sample(100)
start = {var_name: tr[var_name].mean(axis=0)  # average over draws, keep each variable's shape
         for var_name in tr.varnames}
with model:
    trace = pm.sample(start=start)
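The averaging step collapses each variable's draws into one value. Assuming the trace returns each variable as an array with draws on the first axis, averaging over `axis=0` keeps a vector-valued variable's shape, whereas a bare `.mean()` averages over every element and returns a scalar. The sketch below uses plain NumPy arrays in place of a PyMC trace so it runs on its own. The helper name `mean_point` and the variable names are made up for illustration.

```python
import numpy as np


def mean_point(samples):
    """Hypothetical helper: collapse stacked draws into one start point.

    `samples` maps each variable name to an array with draws on axis 0.
    """
    return {name: np.asarray(vals).mean(axis=0) for name, vals in samples.items()}


rng = np.random.default_rng(0)
samples = {"beta": rng.normal(size=(100, 3)), "sigma": rng.gamma(2.0, size=100)}
start = mean_point(samples)
print(start["beta"].shape)           # (3,): one value per coefficient
print(samples["beta"].mean().shape)  # (): a bare .mean() collapses everything
```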

Thank you very much.