Problem with sampling - pickling

Hi,

I have been revisiting a model that worked a couple of months ago, but since updating to PyMC v3.11 I am getting a Recursion Error when trying to sample.

The model looks like this:

def pow_law_mod(x,p1,p2,p3):
    return p1*tt.pow(x,-p2)+p3

def likelihood(model, theta1, theta2, theta3, nu):
    '''
    model:    - function
                functional form of curve
    theta1-3: - float
                parameters
    nu:       - int 
                constant
    '''
    def logp_(value):
        mod_ = model(value[0],theta1,theta2, theta3) 
        return (-nu*((value[1]/mod_)+tt.log(mod_)+(2/nu-1)*tt.log(value[1]))).sum()
    return logp_

with pm.Model() as model:
    
    #prior
    slope = pm.Normal('slope', 1, 1 )
    amp = pm.HalfNormal('amp', 1 )
    const = pm.Normal('const', 1e3, 100)
    
    #likelihood
    like = pm.DensityDist('like', likelihood(pow_law_mod, amp, slope, const,2), observed=[x,test_data])

and I get a very long error message that looks like the following (I have truncated):

---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-7-af36fa849c26> in <module>
      2 
      3     
----> 4     trace = pm.sample(1000, tune = 2500, target_accept=0.85)
      5 
      6     

~/opt/anaconda3/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    556         _print_step_hierarchy(step)
    557         try:
--> 558             trace = _mp_sample(**sample_args, **parallel_args)
    559         except pickle.PickleError:
    560             _log.warning("Could not pickle model, sampling singlethreaded.")

~/opt/anaconda3/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1458         traces.append(strace)
   1459 
-> 1460     sampler = ps.ParallelSampler(
   1461         draws,
   1462         tune,

~/opt/anaconda3/lib/python3.8/site-packages/pymc3/parallel_sampling.py in __init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar, mp_ctx, pickle_backend)
    421         if mp_ctx.get_start_method() != "fork":
    422             if pickle_backend == "pickle":
--> 423                 step_method_pickled = pickle.dumps(step_method, protocol=-1)
    424             elif pickle_backend == "dill":
    425                 try:

~/opt/anaconda3/lib/python3.8/site-packages/pymc3/distributions/distribution.py in __getstate__(self)
    559                 ) from err
    560             else:
--> 561                 raise err
    562         vals = self.__dict__.copy()
    563         vals["logp"] = logp

~/opt/anaconda3/lib/python3.8/site-packages/pymc3/distributions/distribution.py in __getstate__(self)
    552         # Fix https://github.com/pymc-devs/pymc3/issues/3844
    553         try:
--> 554             logp = dill.dumps(self.logp)
    555         except RecursionError as err:
    556             if type(self.logp) == types.MethodType:

~/opt/anaconda3/lib/python3.8/site-packages/dill/_dill.py in dumps(obj, protocol, byref, fmode, recurse, **kwds)
    271     """pickle an object to a string"""
    272     file = StringIO()
--> 273     dump(obj, file, protocol, byref, fmode, recurse, **kwds)#, strictio)
    274     return file.getvalue()
    275 

~/opt/anaconda3/lib/python3.8/site-packages/dill/_dill.py in dump(obj, file, protocol, byref, fmode, recurse, **kwds)
    265     _kwds = kwds.copy()
    266     _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
--> 267     Pickler(file, protocol, **_kwds).dump(obj)
    268     return
    269 

~/opt/anaconda3/lib/python3.8/site-packages/dill/_dill.py in dump(self, obj)
    452             raise PicklingError(msg)
    453         else:
--> 454             StockPickler.dump(self, obj)
    455         stack.clear()  # clear record of 'recursion-sensitive' pickled objects
    456         return

~/opt/anaconda3/lib/python3.8/pickle.py in dump(self, obj)
    485         if self.proto >= 4:
    486             self.framer.start_framing()
--> 487         self.save(obj)
    488         self.write(STOP)
    489         self.framer.end_framing()

~/opt/anaconda3/lib/python3.8/pickle.py in save(self, obj, save_persistent_id)
    558             f = self.dispatch.get(t)
    559             if f is not None:
--> 560                 f(self, obj)  # Call unbound method with explicit self
    561                 return
    562 

~/opt/anaconda3/lib/python3.8/site-packages/dill/_dill.py in save_function(pickler, obj)
   1442             if _memo: pickler._recurse = False
   1443             fkwdefaults = getattr(obj, '__kwdefaults__', None)
-> 1444             pickler.save_reduce(_create_function, (obj.__code__,
   1445                                 globs, obj.__name__,
   1446                                 obj.__defaults__, obj.__closure__,

~/opt/anaconda3/lib/python3.8/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, state_setter, obj)
    690         else:
    691             save(func)
--> 692             save(args)
    693             write(REDUCE)
    694 

~/opt/anaconda3/lib/python3.8/pickle.py in save(self, obj, save_persistent_id)
    558             f = self.dispatch.get(t)
    559             if f is not None:
--> 560                 f(self, obj)  # Call unbound method with explicit self
    561                 return
    562 

~/opt/anaconda3/lib/python3.8/pickle.py in save_tuple(self, obj)
    899         write(MARK)
    900         for element in obj:
--> 901             save(element)
    902 
    903         if id(obj) in memo:

~/opt/anaconda3/lib/python3.8/pickle.py in save(self, obj, save_persistent_id)
    558             f = self.dispatch.get(t)
    559             if f is not None:
--> 560                 f(self, obj)  # Call unbound method with explicit self
    561                 return
    562 

~/opt/anaconda3/lib/python3.8/pickle.py in save_tuple(self, obj)
    899         write(MARK)
    900         for element in obj:
--> 901             save(element)
    902 
    903         if id(obj) in memo:

~/opt/anaconda3/lib/python3.8/pickle.py in save(self, obj, save_persistent_id)
    558             f = self.dispatch.get(t)
    559             if f is not None:
--> 560                 f(self, obj)  # Call unbound method with explicit self
    561                 return
    562 

~/opt/anaconda3/lib/python3.8/site-packages/dill/_dill.py in save_cell(pickler, obj)
   1176     log.info("Ce: %s" % obj)
   1177     f = obj.cell_contents
-> 1178     pickler.save_reduce(_create_cell, (f,), obj=obj)
   1179     log.info("# Ce")
   1180     return
+
+ 
+
+

... last 132 frames repeated, from the frame below ...

~/opt/anaconda3/lib/python3.8/pickle.py in _batch_setitems(self, items)
    995                 for k, v in tmp:
    996                     save(k)
--> 997                     save(v)
    998                 write(SETITEMS)
    999             elif n:

RecursionError: maximum recursion depth exceeded while calling a Python object

Anybody have any ideas what might have happened? It may not be a PyMC thing but any help would be appreciated.

Thanks in advance.