Hi,
Thanks for the reply. You are correct. I forgot to mention that using NUTS for my model takes an extremely long time, but it's okay to use Metropolis–Hastings.
Here is the code.
def update_decisions(acti_normalized, noise_d_seq, sign_seq, bias, weight, learning_rate, bias_strength):
    """One scan step: compute the trial decision sign, then update bias and weight.

    Parameters
    ----------
    acti_normalized, noise_d_seq, sign_seq
        Per-trial slices supplied by ``pytensor.scan`` as sequences.
    bias, weight
        Recurrent state carried from the previous step (``outputs_info``).
    learning_rate, bias_strength
        Scalar non-sequences (``test[0]`` and ``test[1]`` at the call site).

    Returns
    -------
    (bias, weight, decision)
        Updated recurrent state plus this trial's decision sign.
    """
    # Decision variable: weighted activation minus the bias term plus noise;
    # dividing by its absolute value maps it to +/-1.
    # NOTE(review): this yields NaN when the decision variable is exactly 0 —
    # confirm that case cannot occur (pt.sgn would give 0 instead).
    decision = pt.sum(acti_normalized * weight, axis=1) - bias_strength * bias + noise_d_seq
    decision = decision / pt.abs(decision)
    # Leaky update of the bias toward the current decision.
    # NOTE(review): `bias_update` is not a parameter of this function, so it
    # must exist at module scope — verify it is not a typo for `bias_strength`.
    bias = bias_update * decision + (1 - bias_update) * bias
    # Soft-bounded weight update: positive deltas are scaled by the remaining
    # headroom (w_max - weight), negative deltas by (weight - w_min).
    # Amax, w_max and w_min are taken from module scope.
    delta = acti_normalized * learning_rate * sign_seq * Amax
    weight = (weight
              + (w_max - weight) * pt.switch(delta >= 0, delta, 0)
              + (weight - w_min) * pt.switch(delta < 0, delta, 0))
    return bias, weight, decision
# --- Model definition -------------------------------------------------------
BIP = pm.Model()

# Bounds of the 5-dimensional uniform prior. Presumably:
# [learning_rate, bias_strength, activation gain, decision-noise scale,
#  response-noise scale] — confirm against the parameter usage below.
lower = [0, 0, 0, 0, 0]
upper = [0.01, 3, 3, 0.5, 0.5]

with BIP:
    test = pm.Uniform('test', lower=lower, upper=upper)

    # NOTE(review): these noise draws are generated ONCE with NumPy when the
    # model graph is built, so the same noise realization is reused for every
    # MCMC sample (only the scales test[3]/test[4] vary). Confirm this is
    # intended; otherwise draw the noise inside the model (e.g. pm.Normal).
    noise_d_seq = pt.as_tensor_variable(np.random.normal(0, 1, size=(nTrials, nAverage))) * test[3]
    noise_r_seq = pt.as_tensor_variable(np.random.normal(0, 1, size=(nTrials, nAverage, nUnits))) * test[4]

    acti_withnoise = acti_seq + noise_r_seq
    # Rectify, then squash through Amax * (1 - exp(-g*x)) / (1 + exp(-g*x))
    # with gain g = test[2] (equivalently Amax * tanh(g*x/2) for x >= 0).
    # The rectified term is computed once instead of twice as in the original.
    rectified = pt.switch(acti_withnoise > 0, acti_withnoise, 0)
    acti_normalized = ((1 - pt.exp(-test[2] * rectified))
                       / (1 + pt.exp(-test[2] * rectified))) * Amax

    # Iterate the decision/learning rule across trials, carrying (bias, weight)
    # as recurrent state; results[2] collects the per-trial decision signs.
    results, _ = pytensor.scan(fn=update_decisions,
                               sequences=[acti_normalized, noise_d_seq, sign_seq],
                               outputs_info=[bias, weight, None],
                               non_sequences=[test[0], test[1]])

    # Signed mean decision over the nAverage repeats of each trial.
    correctness = (results[2].sum(axis=1) / nAverage) * sign_seq

    # Per-block success probabilities for the six condition blocks
    # (IL/IN/IH/CL/CN/CH index arrays from module scope). The (+50)/100 mapping
    # presumably assumes 50 trials per block — confirm. Clipping keeps the
    # Binomial probability away from exactly 0/1 for numerical stability.
    block_correctness_prob = pt.clip(
        (pt.stack([correctness[IL].sum(axis=1), correctness[IN].sum(axis=1),
                   correctness[IH].sum(axis=1), correctness[CL].sum(axis=1),
                   correctness[CN].sum(axis=1), correctness[CH].sum(axis=1)],
                  axis=0) + 50) / 100,
        0.001, 0.999)

    # Likelihood: number of correct responses per block.
    nCorrect = pm.Binomial('nCorrect', n=trials_p_block,
                           p=block_correctness_prob, observed=block_correct)