I am trying to use the following linear transformation to some parameters in my model.
class Linear(pm.distributions.transforms.ElemwiseTransform):
name = "linear"
def __init__(self, alpha, beta):
super(pm.distributions.transforms.ElemwiseTransform, self).__init__()
self.alpha = alpha
self.beta = beta
def forward(self, x):
return self.alpha * x + self.beta
def forward_val(self, x, point=None):
return self.alpha * x + self.beta
def backward(self, x):
return (1 / self.alpha) * (x - self.beta)
def jacobian_det(self, x):
return -tt.log(self.alpha)
When I ran SVGD with this transformation it worked fine but however if I try to sample using NUTS it is throwing the following error. What am I doing wrong?
ValueError Traceback (most recent call last)
<ipython-input-7-8b609db3e25e> in <module>()
2 # optmzr = pm.SVGD(n_particles=4, jitter=0.1, temperature=1, model=vep_mdl.model, start=params_init)
3 # post_approx = optmzr.fit(niters)
----> 4 trace = pm.sample(draws=500, start=params_init, tune=500, model=vep_mdl.model, chains=1)# nuts_kwargs={'target_accept':0.95, 'max_treedepth':15})
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
457 _log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
458 _print_step_hierarchy(step)
--> 459 trace = _sample_many(**sample_args)
461 discard = tune if discard_tuned_samples else 0
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
503 for i in range(chains):
504 trace = _sample(draws=draws, chain=chain + i, start=start[i],
--> 505 step=step, random_seed=random_seed[i], **kwargs)
506 if trace is None:
507 if len(traces) == 0:
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _sample(chain, progressbar, random_seed, start, draws, step, trace, tune, model, live_plot, live_plot_kwargs, **kwargs)
547 try:
548 strace = None
--> 549 for it, strace in enumerate(sampling):
550 if live_plot:
551 if live_plot_kwargs is None:
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/tqdm/_tqdm.py in __iter__(self)
1020 """), fp_write=getattr(self.fp, 'write', sys.stderr.write))
-> 1022 for obj in iterable:
1023 yield obj
1024 # Update and possibly print the progressbar.
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
643 step = stop_tuning(step)
644 if step.generates_stats:
--> 645 point, states = step.step(point)
646 if strace.supports_sampler_stats:
647 strace.record(point, states)
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py in step(self, point)
242 def step(self, point):
243 self._logp_dlogp_func.set_extra_values(point)
--> 244 array = self._logp_dlogp_func.dict_to_array(point)
246 if self.generates_stats:
~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/model.py in dict_to_array(self, point)
493 array = np.empty(self.size, dtype=self.dtype)
494 for varmap in self._ordering.vmap:
--> 495 array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
496 return array
ValueError: setting an array element with a sequence.