Issue using pm.Interpolated() to update priors in online updating loop

After upgrading pymc3 to 3.5, I am having trouble with online updating of a model prior from incoming data. The scheme starts with a sampled prior model trace. In each iteration of the online updating, it then uses pm.Interpolated() with a KDE, similar to the Updating Priors example. This technique worked well in 3.4.1, so I’m wondering whether the API has changed or whether I’ve introduced a bug.
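For reference, the KDE step of that scheme can be sketched with numpy alone. This approximates the from_posterior helper in the Updating Priors example and is not its exact code. kde_grid is a hypothetical name. The bandwidth uses Scott’s rule, which I believe matches scipy.stats.gaussian_kde’s default in 1-D. The returned arrays are the x_points and pdf_points that pm.Interpolated takes.

```python
import numpy as np


def kde_grid(samples, n_points=100, pad=3.0):
    """Tabulate a 1-D Gaussian KDE of posterior samples on a padded grid.

    Hypothetical helper; returns (x_points, pdf_points) for pm.Interpolated.
    """
    samples = np.asarray(samples, dtype=float).ravel()
    n = samples.size
    smin, smax = samples.min(), samples.max()
    width = smax - smin
    x = np.linspace(smin, smax, n_points)
    # Scott's rule bandwidth (assumed to match scipy.stats.gaussian_kde's 1-D default)
    bw = samples.std(ddof=1) * n ** (-1.0 / 5.0)
    # Sum of Gaussian kernels centred on each sample, evaluated at every grid point
    z = (x[:, None] - samples[None, :]) / bw
    y = np.exp(-0.5 * z ** 2).sum(axis=1) / (n * bw * np.sqrt(2.0 * np.pi))
    # Add zero-density end points well outside the sampled range so that values
    # never sampled get a small, linearly interpolated density instead of zero
    x = np.concatenate([[x[0] - pad * width], x, [x[-1] + pad * width]])
    y = np.concatenate([[0.0], y, [0.0]])
    return x, y
```

Inside the update model this would be used as eta = pm.Interpolated('eta', *kde_grid(online_trace.get_values('eta'))).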

See the gist below for a minimal reproducible example:

I’m using SMC sampling steps here. Perhaps the problem is the different backends of the traces? When I use pm.Metropolis() as the online update step instead, it appears to work.

Environment:
pymc3 3.5
python 3.6.5
GCC 7.2.0 on linux

Here is the stack trace:

(pymc3env) root@369fa5b87409:~/src/tconstrmodel/bayesian# python demo_problem.py
3.5
Generating prior model
Sequential sampling (1 chains in 1 job)
Metropolis: [eta]
100%|##############################################################################################################################| 1500/1500 [00:01<00:00, 1111.38it/s]
Only one chain was sampled, this makes it impossible to run some convergence checks
Loading prior model from: /tmp/tmpetssr7q9/prior.pkl
online update: 1
Adding model likelihood to RVs!
Init new trace!
Sample initial stage: ...
Beta: 0.000000 Stage: 0
Initializing chain traces ...
Sampling ...
  0%|          | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "demo_problem.py", line 111, in <module>
    step=pm.SMC()
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/sampling.py", line 340, in sample
    **kwargs)
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/step_methods/smc.py", line 518, in sample_smc
    _iter_parallel_chains(**sample_args)
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/step_methods/smc.py", line 694, in _iter_parallel_chains
    for _ in p:
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/tqdm/_tqdm.py", line 930, in __iter__
    for obj in iterable:
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/backends/smc_text.py", line 58, in paripool
    yield function(work_item)
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/step_methods/smc.py", line 655, in _work_chain
    return _sample(*work)
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/step_methods/smc.py", line 603, in _sample
    for strace in sampling:
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/step_methods/smc.py", line 636, in _iter_sample
    point, out_list = step.step(point)
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/backends/smc_text.py", line 103, in step
    apoint, alist = self.astep(self.bij.map(point))
  File "/root/miniconda/envs/pymc3env/lib/python3.6/site-packages/pymc3/blocking.py", line 75, in map
    apt[slc] = dpt[var].ravel()
KeyError: 'eta_interval__'

It’s clear that pm.Interpolated() adds transformed _interval__ variables (here eta_interval__). However, this did not cause any issues previously when starting from a trace that lacks the transformed variables.
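As I understand it, pm.Interpolated is bounded by its first and last x_points. If so, eta is sampled on an interval-transformed (log-odds) scale, and the start points handed to SMC need eta_interval__ on that scale. Below is a numpy sketch of that transform and its inverse. It assumes pymc3’s interval transform is log(x - a) - log(b - x), with a and b being the first and last x_points. The function names are hypothetical and are not pymc3 API.

```python
import numpy as np


def interval_forward(x, a, b):
    """Map values in (a, b) to the real line: log(x - a) - log(b - x).

    Assumed to mirror pymc3's interval transform; hypothetical helper.
    """
    x = np.asarray(x, dtype=float)
    return np.log(x - a) - np.log(b - x)


def interval_backward(y, a, b):
    """Inverse map from the real line back to (a, b): a + (b - a) * sigmoid(y)."""
    y = np.asarray(y, dtype=float)
    return a + (b - a) / (1.0 + np.exp(-y))
```

For example, with a, b = x_points[0], x_points[-1], the transformed start values would be interval_forward(eta_draws, a, b). Draws exactly on a bound map to -inf or +inf.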

Any insight would be much appreciated, as this has me quite stumped! Also, if anyone knows of a good way to resample a trace to a desired length, please let me know.
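On resampling, one simple option outside the pymc3 API is to bootstrap draws with replacement. Using the same row indices for every variable keeps joint draws together. resample_draws is a hypothetical helper that works on a plain dict of arrays, such as {v: trace.get_values(v) for v in trace.varnames}.

```python
import numpy as np


def resample_draws(samples, length, seed=None):
    """Resample a dict of per-variable draws to `length` rows, with replacement.

    Hypothetical helper: every variable is indexed with the same random rows,
    so values that were drawn together stay together.
    """
    n = len(next(iter(samples.values())))
    rng = np.random.RandomState(seed)
    idx = rng.randint(0, n, size=length)
    return {name: np.asarray(vals)[idx] for name, vals in samples.items()}
```

Bootstrapping duplicates draws when length exceeds the original number of samples, so the result carries no more information than the source trace.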

Thank you,
Tom

I found that if I naively add a reference to "eta_interval__" in the online_trace._straces[0].samples dict during the first iteration (with the ndarray backend), it runs from then on. It’s not very elegant. Between this and how I am resampling the prior trace, I feel like I’m not using the API correctly at all. Comments and advice welcome :stuck_out_tongue:

        with pm.Model() as update_model:

            # using the pm.Interpolated distribution causes problem here
            # it adds '_interval__' variables which are not in previous trace
            eta = from_posterior('eta', online_trace.get_values('eta'))
            theta = X_1*eta
            obs = pm.Normal('obs', theta, observed=Y_1)
            # naively adding a reference for "_interval__" variables here in first iteration
            for varname in online_trace.varnames:
                if ("_interval__" not in varname) and (
                    varname+"_interval__" not in online_trace.varnames) and (
                        varname != "l_like__"):
                    online_trace._straces[0].samples.update(
                        {varname+"_interval__":
                            online_trace._straces[0].samples[varname]})

            online_trace = pm.sample(
                draws=num_samples,
                start=online_trace,
                chains=num_samples,
                cores=1,
                step=pm.SMC()
            )
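One caveat about this workaround, assuming pymc3 reads eta_interval__ on the transformed (log-odds) scale: aliasing eta’s draws as eta_interval__ puts untransformed values into that slot. The chains would then start from different points than intended. Below is a sketch that builds the missing keys from the transformed values on a plain dict instead of writing into the private _straces[0].samples. add_interval_keys and its bounds argument are hypothetical. I haven’t checked exactly how SMC in 3.5 consumes such a start.

```python
import numpy as np


def add_interval_keys(samples, bounds):
    """Return a copy of `samples` with '<name>_interval__' entries added.

    samples: dict mapping variable name -> 1-D array of untransformed draws
    bounds:  dict mapping variable name -> (a, b), e.g. the first and last
             x_points passed to pm.Interpolated for that variable
    """
    out = dict(samples)
    for name, (a, b) in bounds.items():
        key = name + "_interval__"
        if key in out:
            continue  # already present, e.g. in a trace from a previous SMC update
        x = np.asarray(samples[name], dtype=float)
        # Interval (log-odds) transform, assumed to match pymc3's: log(x - a) - log(b - x)
        out[key] = np.log(x - a) - np.log(b - x)
    return out
```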