Hi - I’m trying to use the NUTS sampler with a custom likelihood and gradient function. I followed the steps in “Using a ‘black box’ likelihood function” (PyMC3 3.11.5 documentation), defining two tt.Op objects, one for the log-likelihood and one for its gradient.

Here is my code. self.likelihood is the function cal_logp, which wraps a PyTorch network. It takes an array of shape (10,) as input and returns both logp and logp_grad at the same time.
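cal_logp itself involves the network and is not shown; roughly, it behaves like the following sketch, where a placeholder quadratic stands in for the real PyTorch model, so only the call signature and the return types match my actual function:

import torch

def cal_logp(para_value, data, z_list, mcmc_para):
    # Simplified stand-in for the real likelihood: takes a (10,) parameter
    # array and returns the log-likelihood plus its gradient w.r.t. the parameters.
    theta = torch.as_tensor(para_value, dtype=torch.float64).requires_grad_(True)
    # The real code evaluates a trained network at (theta, z_list, mcmc_para);
    # a simple quadratic stands in for it here.
    prediction = (theta ** 2).sum()
    residual = prediction - torch.as_tensor(data, dtype=torch.float64).sum()
    logp = -0.5 * residual ** 2
    logp_grad, = torch.autograd.grad(logp, theta)
    return logp, logp_grad  # both torch tensors, as consumed by the Ops below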
class loglikelihood(tt.Op):
    itypes = [tt.dvector]
    otypes = [tt.dscalar]

    def __init__(self, loglike, data, mcmc_para, z_list=torch.linspace(2., 4., 9)):
        self.data = data
        self.likelihood = loglike
        self.z_list = z_list
        self.mcmc_para = mcmc_para
        self.logpgrad = Loglikegrad(self.likelihood, self.data, self.mcmc_para, self.z_list)

    def perform(self, node, inputs, outputs):
        (para_value,) = inputs
        logp, logp_grad = self.likelihood(para_value, self.data, self.z_list, self.mcmc_para)
        outputs[0][0] = logp.detach().cpu().numpy()

    def grad(self, inputs, g):
        (para_value,) = inputs
        grad_value = self.logpgrad(para_value)
        return [grad_value]
class Loglikegrad(tt.Op):
    itypes = [tt.dvector]
    otypes = [tt.dscalar]

    def __init__(self, log_likelihood, data, mcmc_para, z_list):
        self.likelihood = log_likelihood
        self.data = data
        self.mcmc_para = mcmc_para
        self.z_list = z_list

    def perform(self, node, inputs, outputs):
        (para_value,) = inputs
        logp, logp_grad = self.likelihood(para_value, self.data, self.z_list, self.mcmc_para)
        grad_value = logp_grad.detach().cpu().numpy()
        outputs[0][0] = grad_value
logp1 = loglikelihood(cal_logp, test_data_global, mcmc_para, z_list=torch.linspace(2., 4., 9))

with pm.Model() as pymodel:
    As = pm.Uniform('As', lower=0, upper=1, testval=0.5)
    ns = pm.Uniform('ns', lower=0, upper=1, testval=0.5)
    ln_sigT_kms_0 = pm.Uniform('ln_sigT_kms_0', lower=0, upper=1, testval=0.5)
    ln_sigT_kms_1 = pm.Uniform('ln_sigT_kms_1', lower=0, upper=1, testval=0.5)
    ln_gamma_0 = pm.Uniform('ln_gamma_0', lower=0, upper=1, testval=0.5)
    ln_gamma_1 = pm.Uniform('ln_gamma_1', lower=0, upper=1, testval=0.5)
    ln_tau_0 = pm.Uniform('ln_tau_0', lower=0, upper=1, testval=0.5)
    ln_tau_1 = pm.Uniform('ln_tau_1', lower=0, upper=1, testval=0.5)
    ln_kF_0 = pm.Uniform('ln_kF_0', lower=0, upper=1, testval=0.5)
    ln_kF_1 = pm.Uniform('ln_kF_1', lower=0, upper=1, testval=0.5)

    para_value = tt.as_tensor_variable([As, ns, ln_sigT_kms_0, ln_sigT_kms_1, ln_gamma_0,
                                        ln_gamma_1, ln_tau_0, ln_tau_1, ln_kF_0, ln_kF_1])
    para_value = para_value.reshape((10,))

    pm.DensityDist('likelihood', lambda v: logp1(v), observed={'v': para_value})

    trace = pm.sample(ndraws, tune=nburn, discard_tuned_samples=True)
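For reference, the Op can be exercised directly on a test parameter vector outside the model, e.g. as below (this only goes through perform, not the gradient path; test_data_global and mcmc_para are assumed to be defined as above):

import numpy as np

test_theta = np.full(10, 0.5)  # a point inside all the Uniform(0, 1) priors
logp_value = logp1(tt.as_tensor_variable(test_theta)).eval()
print(logp_value)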
When I run pm.sample, I get the following error:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_44569/3281932712.py in <module>
24 pm.DensityDist('likelihood', lambda v: logp1(v), observed={'v': para_value})
25
---> 26 trace = pm.sample(ndraws, tune=nburn, discard_tuned_samples=True,chains=1)
~/.conda/envs/ML/lib/python3.8/site-packages/deprecat/classic.py in wrapper_function(wrapped_, instance_, args_, kwargs_)
213 else:
214 warnings.warn(message, category=category, stacklevel=_routine_stacklevel)
--> 215 return wrapped_(*args_, **kwargs_)
216
217 return wrapper_function(wrapped)
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, start, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
510 # By default, try to use NUTS
511 _log.info("Auto-assigning NUTS sampler...")
--> 512 start_, step = init_nuts(
513 init=init,
514 chains=chains,
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, **kwargs)
2204 raise ValueError(f"Unknown initializer: {init}.")
2205
-> 2206 step = pm.NUTS(potential=potential, model=model, **kwargs)
2207
2208 return start, step
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
166 `pm.sample` to the desired number of tuning steps.
167 """
--> 168 super().__init__(vars, **kwargs)
169
170 self.max_treedepth = max_treedepth
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
86 vars = inputvars(vars)
87
---> 88 super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
89
90 self.adapt_step_size = adapt_step_size
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
252
253 if logp_dlogp_func is None:
--> 254 func = model.logp_dlogp_function(vars, dtype=dtype, **theano_kwargs)
255 else:
256 func = logp_dlogp_func
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/model.py in logp_dlogp_function(self, grad_vars, tempered, **kwargs)
1002 varnames = [var.name for var in grad_vars]
1003 extra_vars = [var for var in self.free_RVs if var.name not in varnames]
-> 1004 return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
1005
1006 @property
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/model.py in __init__(self, costs, grad_vars, extra_vars, dtype, casting, compute_grads, **kwargs)
689
690 if compute_grads:
--> 691 grad = tt.grad(self._cost_joined, self._vars_joined)
692 grad.name = "__grad"
693 outputs = [self._cost_joined, grad]
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
637 assert g.type.dtype in theano.tensor.float_dtypes
638
--> 639 rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
640
641 for i in range(len(rval)):
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1438 return grad_dict[var]
1439
-> 1440 rval = [access_grad_cache(elem) for elem in wrt]
1441
1442 return rval
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in <listcomp>(.0)
1438 return grad_dict[var]
1439
-> 1440 rval = [access_grad_cache(elem) for elem in wrt]
1441
1442 return rval
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in access_grad_cache(var)
1391 for idx in node_to_idx[node]:
1392
-> 1393 term = access_term_cache(node)[idx]
1394
1395 if not isinstance(term, Variable):
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in access_term_cache(node)
1059 inputs = node.inputs
1060
-> 1061 output_grads = [access_grad_cache(var) for var in node.outputs]
1062
1063 # list of bools indicating if each output is connected to the cost
[... the frames <listcomp>, access_grad_cache and access_term_cache in theano/gradient.py (lines 1059-1393) repeat several more times while Theano traverses the gradient graph ...]
~/.conda/envs/ML/lib/python3.8/site-packages/theano/gradient.py in access_term_cache(node)
1218 )
1219
-> 1220 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1221
1222 if input_grads is None:
~/.conda/envs/ML/lib/python3.8/site-packages/theano/graph/op.py in L_op(self, inputs, outputs, output_grads)
323
324 """
--> 325 return self.grad(inputs, output_grads)
326
327 def R_op(
/tmp/ipykernel_44569/3888861307.py in grad(self, inputs, g)
28 #print('finish log1')
29 #print('debug...grad:',log1(para_value))
---> 30 grad_value = self.logpgrad(para_value)
31 #print('debug...grad_value', grad_value)
32 return [grad_value]
~/.conda/envs/ML/lib/python3.8/site-packages/theano/graph/op.py in __call__(self, *inputs, **kwargs)
251
252 if config.compute_test_value != "off":
--> 253 compute_test_value(node)
254
255 if self.default_output is not None:
~/.conda/envs/ML/lib/python3.8/site-packages/theano/graph/op.py in compute_test_value(node)
137 # Add 'test_value' to output tag, so that downstream `Op`s can use
138 # these numerical values as test values
--> 139 output.tag.test_value = storage_map[output][0]
140
141
~/.conda/envs/ML/lib/python3.8/site-packages/theano/graph/utils.py in __setattr__(self, attr, obj)
264
265 if getattr(self, "attr", None) == attr:
--> 266 obj = self.attr_filter(obj)
267
268 return object.__setattr__(self, attr, obj)
~/.conda/envs/ML/lib/python3.8/site-packages/theano/tensor/type.py in filter(self, data, strict, allow_downcast)
179
180 if self.ndim != data.ndim:
--> 181 raise TypeError(
182 f"Wrong number of dimensions: expected {self.ndim},"
183 f" got {data.ndim} with shape {data.shape}."
TypeError: Wrong number of dimensions: expected 0, got 1 with shape (10,).