Transformed variables causing errors with NUTS but not with SVGD

#1

Hi,

I am trying to use the following linear transformation to some parameters in my model.

class Linear(pm.distributions.transforms.ElemwiseTransform):
    name = "linear"

    def __init__(self, alpha, beta):
        super(pm.distributions.transforms.ElemwiseTransform, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        return self.alpha * x + self.beta

    def forward_val(self, x, point=None):
        return self.alpha * x + self.beta

    def backward(self, x):
        return (1 / self.alpha) * (x - self.beta)

    def jacobian_det(self, x):
        return -tt.log(self.alpha)

When I ran SVGD with this transformation it worked fine but however if I try to sample using NUTS it is throwing the following error. What am I doing wrong?

ValueError                                Traceback (most recent call last)
<ipython-input-7-8b609db3e25e> in <module>()
      2 # optmzr = pm.SVGD(n_particles=4, jitter=0.1, temperature=1, model=vep_mdl.model, start=params_init)
      3 # post_approx = optmzr.fit(niters)
----> 4 trace = pm.sample(draws=500, start=params_init, tune=500, model=vep_mdl.model, chains=1)# nuts_kwargs={'target_accept':0.95, 'max_treedepth':15})

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    457                 _log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
    458                 _print_step_hierarchy(step)
--> 459                 trace = _sample_many(**sample_args)
    460 
    461         discard = tune if discard_tuned_samples else 0

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
    503     for i in range(chains):
    504         trace = _sample(draws=draws, chain=chain + i, start=start[i],
--> 505                         step=step, random_seed=random_seed[i], **kwargs)
    506         if trace is None:
    507             if len(traces) == 0:

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _sample(chain, progressbar, random_seed, start, draws, step, trace, tune, model, live_plot, live_plot_kwargs, **kwargs)
    547     try:
    548         strace = None
--> 549         for it, strace in enumerate(sampling):
    550             if live_plot:
    551                 if live_plot_kwargs is None:

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/tqdm/_tqdm.py in __iter__(self)
   1020                 """), fp_write=getattr(self.fp, 'write', sys.stderr.write))
   1021 
-> 1022             for obj in iterable:
   1023                 yield obj
   1024                 # Update and possibly print the progressbar.

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/sampling.py in _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
    643                 step = stop_tuning(step)
    644             if step.generates_stats:
--> 645                 point, states = step.step(point)
    646                 if strace.supports_sampler_stats:
    647                     strace.record(point, states)

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py in step(self, point)
    242     def step(self, point):
    243         self._logp_dlogp_func.set_extra_values(point)
--> 244         array = self._logp_dlogp_func.dict_to_array(point)
    245 
    246         if self.generates_stats:

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/model.py in dict_to_array(self, point)
    493         array = np.empty(self.size, dtype=self.dtype)
    494         for varmap in self._ordering.vmap:
--> 495             array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
    496         return array
    497 

ValueError: setting an array element with a sequence.

0 Likes