Hi
My idea was to save a model object to a pickle file but with dummy data held in the model. So the dimensions of the data are correct but the number of rows is just 1.
Then, later on, I would load the pickled model, use pm.set_data() to input all of the training data, sample to generate the trace, then use pm.set_data() again to input the testing data, and finally do posterior predictive sampling.
However, when I run the sampling after inputting the real training data, it fails with the following error message:
Only 8 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Traceback (most recent call last):
File "model_template.py", line 45, in <module>
main()
File "model_template.py", line 31, in main
trace = pm.sample(8, tune=0, chains=1, cores=1)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 368, in sample
progressbar=progressbar, **kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 1579, in init_nuts
step = pm.NUTS(potential=potential, model=model, **kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 148, in __init__
super().__init__(vars, **kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 72, in __init__
super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 228, in __init__
vars, dtype=dtype, **theano_kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/model.py", line 733, in logp_dlogp_function
return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/pymc3/model.py", line 466, in __init__
grad = tt.grad(self._cost_joined, self._vars_joined)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 605, in grad
grad_dict, wrt, cost_name)
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1371, in _populate_grad_dict
rval = [access_grad_cache(elem) for elem in wrt]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1371, in <listcomp>
rval = [access_grad_cache(elem) for elem in wrt]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
term = access_term_cache(node)[idx]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
term = access_term_cache(node)[idx]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
term = access_term_cache(node)[idx]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
term = access_term_cache(node)[idx]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in access_term_cache
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1021, in <listcomp>
output_grads = [access_grad_cache(var) for var in node.outputs]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1326, in access_grad_cache
term = access_term_cache(node)[idx]
File "/home/n_kaimcaudle/miniconda3/lib/python3.6/site-packages/theano/gradient.py", line 1237, in access_term_cache
"of shape %s" % (node.op, t_shape, i, i_shape))
ValueError: Elemwise{sub,no_inplace}.grad returned object of shape (1,) as gradient term on input 0 of shape (64,)
I know from previous projects that if I build the model with the training data in the first place, I can then change the number of rows for the testing data using pm.set_data(), and pm.sample_posterior_predictive() works fine.
Calling pm.set_data() before pm.sample() doesn’t seem to “reset” the Theano graph to allow a different number of rows, whereas calling it before pm.sample_posterior_predictive() works fine.
My code below:
import pymc3 as pm
import numpy as np
def main():
    """Reproduce the failure of pm.sample() after pm.set_data() changes the row count.

    A linear model (Normal priors on two coefficients and an intercept, Normal
    likelihood with sigma=1) is built on single-row placeholder data held in
    pm.Data containers. The containers are then refilled with N=64 rows and the
    model is sampled.

    With the hard-coded ``if True`` branch, pm.sample() is called after
    pm.set_data(); per the accompanying traceback this raises
    ``ValueError: Elemwise{sub,no_inplace}.grad returned object of shape (1,)
    as gradient term on input 0 of shape (64,)``. The ``else`` branch (sample
    first, then set_data, then posterior predictive) is the ordering the author
    reports as working; it is unreachable as written.
    """
    # Placeholder data: correct number of columns (2) but only a single row.
    X_pre = np.random.rand(1,2)
    y_pre = np.random.rand(1)
    with pm.Model() as model:
        # pm.Data containers so the values can be swapped later via pm.set_data().
        X = pm.Data('X', X_pre)
        y = pm.Data('y', y_pre)
        coef = pm.Normal('coef', 0, 1, shape=2)
        intercept = pm.Normal('intercept', 0, 1)
        mu = intercept + X.dot(coef)
        # `obs` is not referenced again by name below; the variable is looked up
        # later through the string key 'obs' in the posterior predictive result.
        obs = pm.Normal('obs', mu=mu, sigma=1, observed=y)
        # Per the author: sampling at this point (on the 1-row placeholder data)
        # works as expected when the next line is uncommented.
        #trace = pm.sample(8, tune=0, chains=1, cores=1)
    # In the real workflow the template model above would be loaded from a
    # pickle file rather than built here.
    N = 64
    X_new = np.random.rand(N, 2)
    y_new = np.random.rand(N)
    # Hard-coded switch between the failing ordering (True) and the working one.
    if True:
        with model:
            pm.set_data({'X':X_new})
            pm.set_data({'y':y_new})
            # Reported to fail with the ValueError quoted in the docstring.
            # NOTE(review): the (1,) in the error matches the placeholder row
            # count, which suggests something shape-dependent from build time is
            # still in use -- cause not confirmed from this script alone.
            trace = pm.sample(8, tune=0, chains=1, cores=1)
    else:
        with model:
            # Sample on the 1-row placeholder data, then swap in the 64-row data
            # only for posterior predictive sampling.
            trace = pm.sample(8, tune=0, chains=1, cores=1)
            pm.set_data({'X':X_new})
            pm.set_data({'y':y_new})
            # NOTE(review): indentation was lost in the paste; these two lines
            # are assumed to belong inside this `with model:` block -- confirm.
            ppc = pm.sample_posterior_predictive(trace)
            print( ppc['obs'].shape )
# Run the reproduction only when this file is executed as a script.
if __name__== "__main__":
    main()