Reinforcement learning - help building a model

Hi all!
I have an aversive learning task in which the participant watches two stimuli (one at a time). One is paired with a shock on 30% of its presentations; the other is never paired with a shock. We measure SCR (skin conductance response) to assess whether aversive learning has occurred.
I am now trying to build a Bayesian model to estimate the learning rate (the Rescorla-Wagner alpha). To make sure the model is reasonable, I have created simulated data based on the same assumptions.
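For reference, this is the Rescorla-Wagner rule as it is implemented in `simulateSCR` below (my reading of that code, not an additional assumption of the task): s_t is the stimulus on trial t (1 = CS+, 0 = CS-), r_t is the shock (1 or 0), and V holds one expected value per stimulus, starting at 0.5. Only the presented stimulus is updated, and the simulated SCR is its post-update value plus unit Gaussian noise, scaled by the subject's slope (the intercept is not used).

```latex
\begin{aligned}
\delta_t &= r_t - V_t(s_t), \\
V_{t+1}(s_t) &= V_t(s_t) + \alpha\,\delta_t, \qquad V_{t+1}(s) = V_t(s) \ \text{for } s \neq s_t, \qquad V_1(\cdot) = 0.5, \\
\mathrm{SCR}_t &= \text{slope}\,\bigl(V_{t+1}(s_t) + \varepsilon_t\bigr), \qquad \varepsilon_t \sim \mathcal{N}(0, 1).
\end{aligned}
```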
Code for creating the simulated data:

import numpy as np

# scrVec and scrTwo hold the real SCR data and trial table (loaded earlier)
shockVec = np.zeros(len(scrVec), dtype=np.int32)  # shock on each trial (1 = yes, 0 = no)
stimVec = np.zeros(len(scrVec), dtype=np.int32)   # stimulus on each trial (1 = CS+, 0 = CS-)
# build the stimulus and shock vectors from the real trial order
for i, cond in enumerate(scrTwo['Condition'].values):
    if cond == 'CSplusUS1':
        # reinforced CS+ trial
        shockVec[i] = 1
        stimVec[i] = 1
    else:
        shockVec[i] = 0
        if cond == 'CSminus1':
            stimVec[i] = 0
        else:
            # any other condition is treated as an unreinforced CS+ trial
            stimVec[i] = 1
print(shockVec.shape)
print(stimVec.shape)

def simulateSCR(alpha, stimVec, shockVec, intercept, slope):
    """Simulate one subject's SCR with a Rescorla-Wagner learner.

    Note: `intercept` is accepted but not used below.
    """
    scrSim = np.zeros(len(stimVec))
    scrCSp = 0.5  # initial expectation for CS+
    scrCSm = 0.5  # initial expectation for CS-
    for i, (s, t) in enumerate(zip(stimVec, shockVec)):
        if s == 1:
            pe = t - scrCSp              # prediction error
            scrCSp = scrCSp + alpha * pe
            scrSim[i] = scrCSp           # value after this trial's update
        if s == 0:
            pe = t - scrCSm              # prediction error
            scrCSm = scrCSm + alpha * pe
            scrSim[i] = scrCSm
        # add Gaussian noise, then scale by the slope
        scrSim[i] = scrSim[i] + np.random.normal(0, 1)
        scrSim[i] = slope * scrSim[i]

    return scrSim

# generate 10 subjects with different alphas
n_subj = 10
alphalist = []
interceptList = []
slopeList = []
subjects = np.empty([len(stimVec), n_subj])  # empty matrix of trials x subjects
for i in np.arange(n_subj):
    alpha = np.random.beta(a=1, b=1)
    intercept = np.random.normal(0, 1)
    slope = np.random.normal(0, 1)
    subjects[:, i] = simulateSCR(alpha, stimVec, shockVec, intercept, slope)
    alphalist.append(alpha)
    interceptList.append(intercept)
    slopeList.append(slope)

import theano.tensor as tt

n_trials = 30
trials, subj = np.meshgrid(range(n_trials), range(n_subj))
trials = tt.as_tensor_variable(trials.T)  # (n_trials, n_subj) trial indices
subj = tt.as_tensor_variable(subj.T)      # (n_trials, n_subj) subject indices

stim = np.reshape([stimVec] * n_subj, (n_subj, n_trials)).T  # matrix of trials x subjects
shock = np.reshape([shockVec] * n_subj, (n_subj, n_trials)).T
# convert to tensors
stim = tt.as_tensor_variable(stim)
shock = tt.as_tensor_variable(shock)

Now, of course, I know the alpha of each subject.
Thanks to @Maria and @ricardoV94 (see here and here) I was able to start building the model.
Because the SCR vector includes responses to both the reinforced and the non-reinforced stimulus, I have tweaked Ricardo’s code a bit.

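To check what the scan below is supposed to compute, a plain NumPy version of the same multi-subject update can serve as ground truth. This is a sketch under my own assumptions: `rw_expectations` is a hypothetical helper, stimuli are coded 1 = CS+ and 0 = CS- (column 1 and column 0 of the Q table), both values start at 0.5, and by default it returns the value of the presented stimulus after that trial's update (as `simulateSCR` does); `pre_update=True` returns the expectation before the outcome instead.

```python
import numpy as np

def rw_expectations(stim, shock, alpha, v0=0.5, pre_update=False):
    """Rescorla-Wagner value of the presented stimulus for several subjects.

    stim, shock : int arrays of shape (n_trials, n_subj); stim is 1 for CS+, 0 for CS-
    alpha       : array of shape (n_subj,) with each subject's learning rate
    Returns an array of shape (n_trials, n_subj).
    """
    stim = np.asarray(stim)
    shock = np.asarray(shock, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    n_trials, n_subj = stim.shape
    subj_idx = np.arange(n_subj)
    Q = np.full((n_subj, 2), v0)            # column 0 = CS-, column 1 = CS+
    out = np.empty((n_trials, n_subj))
    for t in range(n_trials):
        s = stim[t]
        if pre_update:
            out[t] = Q[subj_idx, s]         # expectation before this trial's outcome
        pe = shock[t] - Q[subj_idx, s]      # prediction error of the presented stimulus
        Q[subj_idx, s] += alpha * pe        # only the presented stimulus is updated
        if not pre_update:
            out[t] = Q[subj_idx, s]         # value after the update, as in simulateSCR
    return out

# two subjects, two trials
print(rw_expectations([[1, 0], [0, 0]], [[1, 0], [0, 0]], [0.5, 0.2]))
```

Comparing this array with the scanned expectations evaluated at the same alphas is a quick way to catch indexing or shape mistakes.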
The code of the model is as follows:

import theano
import pymc3 as pm

def update_Q(stim, shock,
             Qs, vec,
             alpha, n_subj):
    """
    Update the Q table according to the RL (Rescorla-Wagner) update rule.
    It is called by theano.scan to do so recursively, given the observed data and the alpha parameter.
    The Q-table update on its own could also be written as a lambda expression in the theano.scan fn argument:
        fn=lambda action, reward, Qs, alpha: tt.set_subtensor(Qs[action], Qs[action] + alpha * (reward - Qs[action]))
    """

    PE = shock - Qs[tt.arange(n_subj), stim]
    Qs = tt.set_subtensor(Qs[tt.arange(n_subj), stim], Qs[tt.arange(n_subj), stim] + alpha * PE)

    # to get a vector of expected outcomes (dependent on the stimulus presented [CS+, CS-])
    # we use an if statement (tt.switch in theano)
    vec = tt.set_subtensor(vec[tt.arange(n_subj), 0],
                           tt.switch(tt.eq(stim, 1),
                                     Qs[tt.arange(n_subj), 1], Qs[tt.arange(n_subj), 0]))

    return Qs, vec

with pm.Model() as mB:
    

    alpha = pm.Beta('alpha', 1,1, shape=n_subj)
    beta = pm.Normal('beta',0, 1, shape=n_subj)
    eps = pm.HalfNormal('eps', 10)
    
    #betas = tt.tile(tt.repeat(beta,1), n_trials).reshape([n_trials, n_subj])    # Q_sub.shape
    
    Qs = 0.5 * tt.ones((n_subj,2), dtype='float64') # initial values for both stimuli (CS+, CS-)
    vec0 = 0.5 * tt.ones((n_subj,1), dtype='float64') # vector storing the expectation for the presented stimulus
    
    [Qs,vec], updates = theano.scan(
        fn=update_Q,
        sequences=[stim, shock],
        outputs_info=[Qs, vec0],
        non_sequences=[alpha, n_subj])
   
    vec = tt.concatenate([[vec0], vec[:-1]], axis=0) # add first value, remove last
    
    vec_ = vec[trials,subj,0]* beta
    # change to vector
    vec_reshape = vec_.T.reshape([n_trials * n_subj])#vec_.T.reshape([30*n_subj])
    scrVec = subjects.T.flatten()
    scrs = pm.Normal('scrs',vec_reshape, eps, observed=scrVec) 
    
    trB = pm.sample(target_accept=.9, chains=4, cores=8, return_inferencedata=True)

Unfortunately, the model is far from recovering the true alpha of each participant. I suspect I’ve messed up the shapes, but I haven’t been able to figure it out (Theano is hard…).

Any help would be much appreciated.

Thanks!

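One way to rule out the shape problem suspected above is to replay the indexing and flattening steps in NumPy on a toy array, where every intermediate can be inspected. The sizes and values here are arbitrary stand-ins, not the real data; the toy `vec` only mimics the scan output shape (n_trials, n_subj, 1).

```python
import numpy as np

n_trials, n_subj = 4, 3
# toy stand-in for the scan output `vec`, shape (n_trials, n_subj, 1)
vec = np.arange(n_trials * n_subj, dtype=float).reshape(n_trials, n_subj, 1)

# index matrices built as in the post
trials, subj = np.meshgrid(range(n_trials), range(n_subj))
trials, subj = trials.T, subj.T                  # both (n_trials, n_subj)

picked = vec[trials, subj, 0]                    # (n_trials, n_subj)
assert np.array_equal(picked, vec[:, :, 0])      # equivalent to plain slicing

beta = np.array([1.0, 2.0, 3.0])                 # one beta per subject
scaled = picked * beta                           # broadcasts over the subject axis
assert np.array_equal(scaled, picked * beta[subj])

# flattening: both sides must use the same (here subject-major) order
pred_flat = scaled.T.reshape(n_trials * n_subj)  # all trials of subject 0 first
obs_flat_ok = scaled.T.flatten()                 # same transpose, same order
obs_flat_bad = scaled.flatten()                  # no transpose: trial-major order
print(np.array_equal(pred_flat, obs_flat_ok), np.array_equal(pred_flat, obs_flat_bad))
```

Passing the (n_trials, n_subj) mean matrix together with an observed matrix of the same shape avoids the flattening step entirely.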
I would suggest checking whether the logp of your model matches a pure Python/NumPy implementation.

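A minimal sketch of such a reference implementation for one subject, assuming the Normal likelihood used in the model (mean beta * V, standard deviation eps) and the post-update values produced by `simulateSCR`. The helper names `rw_values` and `loglik_normal` are hypothetical. Evaluated at fixed parameter values, it should match the model's log-likelihood term for the observed SCRs at the same values (priors excluded).

```python
import numpy as np
from scipy import stats

def rw_values(stim, shock, alpha, v0=0.5):
    """Post-update Rescorla-Wagner value of the presented stimulus (one subject)."""
    V = np.array([v0, v0], dtype=float)   # index 0 = CS-, index 1 = CS+
    out = np.empty(len(stim))
    for i, (s, r) in enumerate(zip(stim, shock)):
        V[s] += alpha * (r - V[s])        # update only the presented stimulus
        out[i] = V[s]
    return out

def loglik_normal(alpha, beta, eps, stim, shock, scr):
    """Sum over trials of log Normal(scr | beta * V, eps)."""
    mu = beta * rw_values(stim, shock, alpha)
    return stats.norm.logpdf(scr, loc=mu, scale=eps).sum()

stim = np.array([1, 0, 1])
shock = np.array([1, 0, 0])
print(rw_values(stim, shock, 0.5))
print(loglik_normal(0.5, 2.0, 1.0, stim, shock, np.array([1.5, 0.5, 0.75])))
```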
Is the normal likelihood appropriate for your data?

Perhaps 30 trials are too few to inform your posterior?


Thanks for the reply.
I was actually wondering whether I need to test the logp (as in your code) instead of estimating the SCR response directly. However, I’m having a bit of trouble adapting your code to multiple subjects (in Theano).

Perhaps it’s better to start with a single subject.


Thanks for the pointers!
I have used the following code, which uses MLE to recover the parameters. I found that it recovers them relatively accurately only when using the Student-t distribution.

import numpy as np
import scipy.optimize
import scipy.stats

def llik_td(x, *args):
    # Extract the arguments as they are passed by scipy.optimize.minimize
    alpha, beta = x
    stim, shock, scr = args

    scrSim = np.zeros(len(stim))
    scrCSp = 0.5
    scrCSm = 0.5
    for i, (s, t) in enumerate(zip(stim, shock)):
        # note: the prediction errors use the global shockVec, not the `shock` argument
        if s == 1:
            pe = shockVec[i] - scrCSp   # prediction error
            scrCSp = scrCSp + alpha * pe
            scrSim[i] = scrCSp
        if s == 0:
            pe = shockVec[i] - scrCSm   # prediction error
            scrCSm = scrCSm + alpha * pe
            scrSim[i] = scrCSm
        # scale by beta
        scrSim[i] = beta * scrSim[i]

    # note: `slope` is not an argument; it is read from the enclosing (global) scope
    scrPred = slope * scrSim
    # Calculate the log-likelihood (Student-t)
    LL = np.sum(scipy.stats.t.logpdf(scr, scrPred))
    # Calculate the negative log-likelihood
    neg_LL = -1 * LL
    return neg_LL

estLog = []
for i in np.arange(n_subj):
    x0 = [alphalist[i], slopeList[i]]
    estLog.append(scipy.optimize.minimize(llik_td, x0, args=( stimVec,shockVec, subjects[:,i]), method='L-BFGS-B'))
    print(estLog[i].x)

Now I’m not sure how to implement a similar method in PyMC.

That line looks weird. Looking at the scipy.stats.t documentation (SciPy v1.7.1 Manual), it seems you are passing scr as x and scrPred as the degrees of freedom of the Student-t.

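For comparison, a sketch of the argument order SciPy expects, `scipy.stats.t.logpdf(x, df, loc=0, scale=1)`. The data, degrees of freedom and scale below are placeholder values, not estimates from this data; in a PyMC version they would be free parameters.

```python
import numpy as np
from scipy import stats

scr = np.array([0.2, -0.1, 0.4])       # observed SCRs (toy values)
scr_pred = np.array([0.0, 0.0, 0.5])   # model predictions (toy values)
nu, sigma = 5.0, 2.0                   # placeholder degrees of freedom and scale

# As in the post: the second positional argument is df, so the predictions are
# used as degrees of freedom (and df <= 0 gives nan instead of a log-density).
misplaced = stats.t.logpdf(scr, scr_pred)

# Intended: Student-t residuals centred on the prediction, with a scale parameter.
loglik = stats.t.logpdf(scr, df=nu, loc=scr_pred, scale=sigma).sum()

# Equivalent form through the standardised residuals (change of variables).
z = (scr - scr_pred) / sigma
loglik_z = (stats.t.logpdf(z, df=nu) - np.log(sigma)).sum()
print(misplaced, loglik, loglik_z)
```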
If you go back to the Normal model and increase the number of trials, do inferences get closer to the true parameters?

OK. I have doubled the number of simulated trials. Using stats.norm, the results are still not close to the true values.
If I add some noise to the alpha and beta values (the starting values I pass to the optimizer), I’m unable to recover the true parameters with either the t or the normal likelihood.

That could suggest you have some parameter redundancy (or a code bug), although 60 trials is not really that much. If you try something like 500 trials for one participant, do things get any better?

Nope.
Using stats.norm.logpdf doesn’t help, even with around 500 trials.
I guess a code error is the most likely explanation. I just can’t find it yet…

Yeah that is the most likely.

My debug strategy would probably be to remove or set some parameters to the fixed known values until you are sure every component is working. You can also compute the log-likelihood by hand for a very small dataset to compare with the python output.

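A sketch of both ideas for one subject, assuming a Normal likelihood with beta and eps fixed at known values; `neg_ll`, `rw_values` and the randomly generated trial order are hypothetical stand-ins for the real code and design. First the negative log-likelihood of a two-trial dataset is checked against a hand calculation; then, with noise-free simulated data, the likelihood is profiled over a grid of alpha values, and its minimum should sit at the true alpha if the code is correct.

```python
import numpy as np
from scipy import stats

def rw_values(stim, shock, alpha, v0=0.5):
    """Post-update Rescorla-Wagner value of the presented stimulus (one subject)."""
    V = np.array([v0, v0], dtype=float)   # index 0 = CS-, index 1 = CS+
    out = np.empty(len(stim))
    for i, (s, r) in enumerate(zip(stim, shock)):
        V[s] += alpha * (r - V[s])
        out[i] = V[s]
    return out

def neg_ll(alpha, beta, eps, stim, shock, scr):
    """Negative Normal log-likelihood of the SCRs given mean beta * V and sd eps."""
    mu = beta * rw_values(stim, shock, alpha)
    return -stats.norm.logpdf(scr, loc=mu, scale=eps).sum()

# 1) Two-trial dataset checked by hand.
#    alpha = 0.5: CS+ with shock -> V = 0.75; CS- without shock -> V = 0.25.
#    With beta = eps = 1 and scr = [1, 0] the residuals are +0.25 and -0.25, so
#    -LL = 2 * 0.5 * log(2 * pi) + (0.25**2 + 0.25**2) / 2
by_hand = np.log(2 * np.pi) + (0.25**2 + 0.25**2) / 2
by_code = neg_ll(0.5, 1.0, 1.0, np.array([1, 0]), np.array([1, 0]), np.array([1.0, 0.0]))
print(by_hand, by_code)

# 2) Fix beta and eps at known values and profile the likelihood over alpha.
rng = np.random.default_rng(0)
n = 60
stim = rng.integers(0, 2, size=n)                          # hypothetical trial order
shock = ((stim == 1) & (rng.random(n) < 0.3)).astype(int)  # about 30% of CS+ reinforced
true_alpha, true_beta = 0.3, 2.0
scr = true_beta * rw_values(stim, shock, true_alpha)       # noise-free data
grid = np.linspace(0.01, 0.99, 99)
profile = [neg_ll(a, true_beta, 1.0, stim, shock, scr) for a in grid]
best = grid[np.argmin(profile)]
print(best)
```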

OK.
I found the problem in the code. Very silly (insert embarrassed face here :)

I had accidentally fed the LL function the shock vector in place of the stimulus vector, and vice versa. Once this was corrected, 30 trials were enough to recover alpha relatively accurately.

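A cheap guard against this kind of mix-up, sketched as a hypothetical helper `check_design`: in this design shocks only follow the CS+ (stim = 1), so a shock on a CS- trial suggests the two vectors were swapped. It assumes at least one CS+ trial is unreinforced, which holds with a 30% reinforcement rate.

```python
import numpy as np

def check_design(stim, shock):
    """Raise ValueError if the stimulus/shock vectors look inconsistent with the design."""
    stim = np.asarray(stim)
    shock = np.asarray(shock)
    if stim.shape != shock.shape:
        raise ValueError("stim and shock have different shapes")
    if not (set(np.unique(stim)) <= {0, 1} and set(np.unique(shock)) <= {0, 1}):
        raise ValueError("stim and shock must be coded 0/1")
    if np.any((shock == 1) & (stim == 0)):
        raise ValueError("shock on a CS- trial: are stim and shock swapped?")

check_design([1, 0, 1, 1], [1, 0, 0, 0])   # passes silently
```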
Now, back to how to implement this in PyMC.

OK.
I have played with the number of trials, and things look much better now. With 60+ trials the recovery is relatively good, using the following PyMC model:

with pm.Model() as mB:
    
   # betaHyper= pm.Normal('betaH', 0, 1)
    alpha = pm.Beta('alpha', 1,1, shape=n_subj)
    beta = pm.Normal('beta',0, 1, shape=n_subj)
    eps = pm.HalfNormal('eps', 5)
    nu = pm.Gamma('nu',2,0.1)
    #betas = tt.tile(tt.repeat(beta,1), n_trials).reshape([n_trials, n_subj])    # Q_sub.shape
    
    Qs = 0.5 * tt.ones((n_subj,2), dtype='float64') # initial values for both stimuli (CS+, CS-)
    vec = 0.5 * tt.ones((n_subj,1), dtype='float64') # vector storing the expectation for the presented stimulus
    
    [Qs,vec], updates = theano.scan(
        fn=update_Q,
        sequences=[stim, shock],
        outputs_info=[Qs, vec],
        non_sequences=[alpha, n_subj])
   
    #vec = tt.concatenate([[vec0], vec[:-1]], axis=0) # add first value, remove last
    
    vec_ = vec[trials, subj,0] * beta[subj]
    # change to vector
    #vec_reshape = vec_.T.reshape([n_trials * n_subj])#vec_.T.reshape([30*n_subj])
    #scrVec = subjects.T.flatten()
    scrs = pm.Normal('scrs', vec_, eps, observed=subjects) 
    
    trB = pm.sample(target_accept=.9, chains=4, cores=10, return_inferencedata=True)

The update function is the same one presented earlier.
Next, I’ll try adding a hierarchical structure, which might improve accuracy.
Thanks for the help.
If anyone sees any issues with the current model, please let me know.


Just FYI, if people are interested:
I’ve uploaded the simulation script and different models for recovering the alpha parameter to this GitHub repository.
https://github.com/orduek/aversive_learning_simulation


I’m trying to run the same model with PyMC v4 and JAX as the backend.
I have made small changes to the original code (to remove trials in which a shock was administered), and this version works perfectly well with PyMC3.
This is the code I’m currently trying to run:

with pm.Model() as m5:
    
    # α
    phi = pm.Uniform("phi", lower=0.0, upper=1.0)
    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", tt.exp(kappa_log))
    alpha = pm.Beta("alpha", alpha=phi * kappa, beta=(1.0 - phi) * kappa, shape=n_subj)
    
    # β (non-centered reparametrization)
    beta_h = pm.Normal('beta_h', 0,1, shape=n_subj)
    beta_sd = pm.HalfNormal('beta_sd', 5)
    beta = pm.Deterministic('beta',0 + beta_h*beta_sd)
       
    eps = pm.HalfNormal('eps', 5)
    
    Qs = 0.5 * tt.ones((n_subj,2), dtype='float64') # initial values for both stimuli (CS+, CS-)
    vec = 0.5 * tt.ones((n_subj,1), dtype='float64') # vector storing the expectation for the presented stimulus
    
    [Qs,vec, pe], updates = aesara.scan(
        fn=update_Q,
        sequences=[stim, shock],
        outputs_info=[Qs, vec, None],
        non_sequences=[alpha, n_subj])
   
     
    vec_ = vec[trials, subj,0] * beta[subj]
    # add matrix of expected values (trials X subjects)
    ev = pm.Deterministic('expected_value', vec_)
    # add PE
    pe = pm.Deterministic('pe', pe)
    
    # transform to vector
    v = tt.reshape(vec_.T, n_subj*n_trials, ndim=1)
    # drop trials in which a shock was delivered
    vec_clean = v[shockVec==0]
       
    # likelihood function
    scrs = pm.Normal('scrs', mu = vec_clean, sigma = eps, observed=scrClean) 
    trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

And this is the error I receive:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:276, in TensorType.dtype_specs(self)
    275 try:
--> 276     return self.dtype_specs_map[self.dtype]
    277 except KeyError:

KeyError: 'object'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Input In [76], in <cell line: 1>()
      1 with m5:
----> 2     trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:474, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    465 print("Compiling...", file=sys.stdout)
    467 init_params = _get_batched_jittered_initial_points(
    468     model=model,
    469     chains=chains,
    470     initvals=initvals,
    471     random_seed=random_seed,
    472 )
--> 474 logp_fn = get_jaxified_logp(model, negative_logp=False)
    476 if nuts_kwargs is None:
    477     nuts_kwargs = {}

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
    104 if not negative_logp:
    105     model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
    108 def logp_fn_wrap(x):
    109     return logp_fn(*x)[0]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
     75 def get_jaxified_graph(
     76     inputs: Optional[List[TensorVariable]] = None,
     77     outputs: Optional[List[TensorVariable]] = None,
     78 ) -> List[TensorVariable]:
     79     """Compile an Aesara graph into an optimized JAX function"""
---> 81     graph = _replace_shared_variables(outputs)
     83     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     84     # We need to add a Supervisor to the fgraph to be able to run the
     85     # JAX sequential optimizer without warnings. We made sure there
     86     # are no mutable input variables, so we only need to check for
     87     # "destroyers". This should be automatically handled by Aesara
     88     # once https://github.com/aesara-devs/aesara/issues/637 is fixed.

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in _replace_shared_variables(graph)
     63 if any(hasattr(var, "default_update") for var in shared_variables):
     64     raise ValueError(
     65         "Graph contains shared variables with default_update which cannot "
     66         "be safely replaced."
     67     )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
     71 new_graph = clone_replace(graph, replace=replacements)
     72 return new_graph

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in <dictcomp>(.0)
     63 if any(hasattr(var, "default_update") for var in shared_variables):
     64     raise ValueError(
     65         "Graph contains shared variables with default_update which cannot "
     66         "be safely replaced."
     67     )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
     71 new_graph = clone_replace(graph, replace=replacements)
     72 return new_graph

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/basic.py:219, in constant(x, name, ndim, dtype)
    213             raise ValueError(
    214                 f"ndarray could not be cast to constant with {int(ndim)} dimensions"
    215             )
    217     assert x_.ndim == ndim
--> 219 ttype = TensorType(dtype=x_.dtype, shape=x_.shape)
    221 return TensorConstant(ttype, x_, name=name)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:107, in TensorType.__init__(self, dtype, shape, name, broadcastable)
    104         return s
    106 self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
--> 107 self.dtype_specs()  # error checking is done there
    108 self.name = name
    109 self.numpy_dtype = np.dtype(self.dtype)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:278, in TensorType.dtype_specs(self)
    276     return self.dtype_specs_map[self.dtype]
    277 except KeyError:
--> 278     raise TypeError(
    279         f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
    280     )

TypeError: Unsupported dtype for TensorType: object

Any ideas?

Just to make sure, you are using aesara.tensor and not theano.tensor, right?

Yes. I just changed
import theano.tensor as tt
to:
import aesara.tensor as at

For some reason, after restarting the kernel, I get a different error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [36], in <cell line: 1>()
      1 with m5:
----> 2     trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:506, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    503 if chains > 1:
    504     map_seed = jax.random.split(map_seed, chains)
--> 506 pmap_numpyro.run(
    507     map_seed,
    508     init_params=init_params,
    509     extra_fields=(
    510         "num_steps",
    511         "potential_energy",
    512         "energy",
    513         "adapt_state.step_size",
    514         "accept_prob",
    515         "diverging",
    516     ),
    517 )
    519 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    521 tic3 = datetime.now()

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 13 frame]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
   (...)
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
--> 746     init_state = hmc_init_fn(init_params, rng_key)
    747 else:
    748     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    749     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    750     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    751     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    723         dense_mass = [tuple(sorted(z))] if dense_mass else []
    724     assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    729     step_size=self._step_size,
    730     num_steps=self._num_steps,
    731     inverse_mass_matrix=inverse_mass_matrix,
    732     adapt_step_size=self._adapt_step_size,
    733     adapt_mass_matrix=self._adapt_mass_matrix,
    734     dense_mass=dense_mass,
    735     target_accept_prob=self._target_accept_prob,
    736     trajectory_length=self._trajectory_length,
    737     max_tree_depth=self._max_tree_depth,
    738     find_heuristic_step_size=self._find_heuristic_step_size,
    739     forward_mode_differentiation=self._forward_mode_differentiation,
    740     regularize_mass_matrix=self._regularize_mass_matrix,
    741     model_args=model_args,
    742     model_kwargs=model_kwargs,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
    746     init_state = hmc_init_fn(init_params, rng_key)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    323 energy = vv_state.potential_energy + kinetic_fn(
    324     wa_state.inverse_mass_matrix, vv_state.r
    325 )
    326 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    270 """
    271 :param z: Position of the particle.
    272 :param r: Momentum of the particle.
   (...)
    275 :return: initial state for the integrator.
    276 """
    277 if potential_energy is None or z_grad is None:
--> 278     potential_energy, z_grad = _value_and_grad(
    279         potential_fn, z, forward_mode_differentiation
    280     )
    281 return IntegratorState(z, r, potential_energy, z_grad)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 7 frame]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    108 def logp_fn_wrap(x):
--> 109     return logp_fn(*x)[0]

File /tmp/tmpwurs7sxl:44, in jax_funcified_fgraph(phi_interval_, kappa_log_log_, alpha_logodds_, beta_h, beta_sd_log_, eps_log_)
     42 auto_194865 = log(auto_194201)
     43 # forall_inplace,cpu,scan_fn}(TensorConstant{69}, TensorConstant{[[ True  T..se False]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, Elemwise{sigmoid,no_inplace}.0)
---> 44 auto_198126, auto_198127 = scan(auto_191975, auto_194927, auto_193120, auto_192695, auto_197868, auto_197866, auto_194208)
     45 # Elemwise{mul,no_inplace}(beta_h, InplaceDimShuffle{x}.0)
     46 auto_194205 = elemwise(beta_h, auto_194204)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:420, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
    419 def scan(*outer_inputs):
--> 420     scan_args = ScanArgs(
    421         list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    422     )
    424     # `outer_inputs` is a list with the following composite form:
    425     # [n_steps]
    426     # + outer_in_seqs
   (...)
    431     # + outer_in_nit_sot
    432     # + outer_in_non_seqs
    433     n_steps = scan_args.n_steps

TypeError: __init__() missing 1 required positional argument: 'as_while'

Yes, that’s unfortunately expected. Scan is not working in our JAX backend at the moment: JAX backend fails with simple Scan examples · Issue #924 · aesara-devs/aesara · GitHub
