Saving a model object with minimal size for later sampling

Hi

My idea was to save a model object to a pickle file but with dummy data held in the model. So the dimensions of the data are correct but the number of rows is just 1.
Then later on I would load the pickled model, use pm.set_data() to input all of the training data, sample to generate the trace, then use pm.set_data() to input the training data and finally do posterior predictive sampling.

However when I do the sampling after inputting the real training data it fails with error message:

Only 8 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Traceback (most recent call last):
  File "model_template.py", line 45, in <module>
    main()
  File "model_template.py", line 31, in main
    trace = pm.sample(8, tune=0, chains=1, cores=1)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 368, in sample
    progressbar=progressbar, **kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 1579, in init_nuts
    step = pm.NUTS(potential=potential, model=model, **kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 148, in __init__
    super().__init__(vars, **kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 72, in __init__
    super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 228, in __init__
    vars, dtype=dtype, **theano_kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/model.py", line 733, in logp_dlogp_function
    return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/model.py", line 466, in __init__
    grad = tt.grad(self._cost_joined, self._vars_joined)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 605, in grad
    grad_dict, wrt, cost_name)
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1371, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1371, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1237, in access_term_cache
    "of shape %s" % (node.op, t_shape, i, i_shape))
ValueError: Elemwise{sub,no_inplace}.grad returned object of shape (1,) as gradient term on input 0 of shape (64,)

I know from previous projects that I can sample with training data if I build the model using it originally and then change the number of rows in the testing data using pm.set_data() and pm.sample_posterior_predictive() works fine.
Having pm.set_data() before pm.sample() doesn’t seem to “reset” the theano graph to allow a different number of rows, but before pm.sample_posterior_predictive() it is fine.

My code below:

import pymc3 as pm
import numpy as np

def main():
    X_pre = np.random.rand(1,2)
    y_pre = np.random.rand(1)
    
    with pm.Model() as model:
        X = pm.Data('X', X_pre)
        y = pm.Data('y', y_pre)
        
        coef = pm.Normal('coef', 0, 1, shape=2)
        intercept = pm.Normal('intercept', 0, 1)
        
        mu = intercept + X.dot(coef)
        
        obs = pm.Normal('obs', mu=mu, sigma=1, observed=y)
        #Sampling here works if I uncomment out, as expected
        #trace = pm.sample(8, tune=0, chains=1, cores=1)
        
    ### In reality I want to load the template model from a pickle object
    N = 64
    X_new = np.random.rand(N, 2)
    y_new = np.random.rand(N)
    
    if True:
        with model:
            pm.set_data({'X':X_new})
            pm.set_data({'y':y_new})
            #This line fails
            trace = pm.sample(8, tune=0, chains=1, cores=1)
    
    else:
        with model:
            trace = pm.sample(8, tune=0, chains=1, cores=1)
            
            pm.set_data({'X':X_new})
            pm.set_data({'y':y_new})
            ppc = pm.sample_posterior_predictive(trace)
        print( ppc['obs'].shape )
    
if __name__== "__main__":
    main()