Thanks for your reply!
There are no missing values in my inputs. For ease of reading, I only include the first 5 elements of any input longer than 5. Below are the inputs I checked:
```
> combined_v.eval()
array([[0.5],
       [0.5],
       [0.5],
       [0.5],
       [0.5]])
> vS.eval()
array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]])
> stim_matrix[:5].eval()
array([[1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0]])
> shock_matrix[:5].eval()
array([[1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0]])
> no_subjects
5
> est_scr_matrix[:5]
array([[-1.9896787 , -2.27735786, -2.07463294, -2.20171174, -2.19734734],
       [-0.49669095, -0.20901179, -0.41173671, -0.28465791, -0.28902232],
       [-0.75986582, -0.31556975, -0.64359855, -0.4463102 , -0.4535504 ],
       [-1.81928559, -2.1817793 , -1.90742927, -2.06732156, -2.06118915],
       [-0.23351609, -0.10245383, -0.17987487, -0.12300563, -0.12449423]])
```
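Since only the first 5 rows are shown above, one quick way to confirm that the full arrays contain no NaN or infinite values is a small NumPy check like the sketch below. `check_finite` is a hypothetical helper name and the toy arrays are placeholders, not the real inputs; for Theano tensors you would pass the result of `.eval()` (e.g. `stim_matrix.eval()`).

```python
import numpy as np

def check_finite(name, arr):
    """Print NaN/inf counts for an array and return True if it is all finite."""
    arr = np.asarray(arr, dtype=float)
    n_nan = int(np.isnan(arr).sum())
    n_inf = int(np.isinf(arr).sum())
    print(f"{name}: shape={arr.shape}, NaN={n_nan}, inf={n_inf}")
    return n_nan == 0 and n_inf == 0

# Toy stand-ins for the real inputs (hypothetical values).
inputs = {
    "toy_v": 0.5 * np.ones((5, 1)),
    "toy_stim": np.array([[1, 0, 1], [0, 1, 0]]),
    "toy_scr_with_nan": np.array([[-1.99, -2.28], [np.nan, -0.21]]),
}
# A list (not a generator) so every array is reported, not just up to the first failure.
results = [check_finite(name, arr) for name, arr in inputs.items()]
print("all finite:", all(results))
```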
I forgot to include one piece of information in my original post: these error messages appear at around 33%-45% of the progress bar.
```
33.65% [2692/8000 1:47:22<3:31:43 Sampling 4 chains, 0 divergences]
```
Edit:
Instead of the priors in the original post, I tried fitting with the simplest priors I could think of (see the code block below), and I got the same error messages.
```python
beta0 = pm.Normal('beta0', 0, 1)
beta1 = pm.Normal('beta1', 0, 1)
# Beta(1, 1) is uniform on [0, 1]: one learning rate per subject
lr = pm.Beta('lr', 1, 1, shape=no_subjects)
```
I’ve also checked my update function with the following setup:
```python
import numpy as np
import theano
import theano.tensor as tt

# Test theano.scan and theano.function with fn = update_RW
# (update_RW, stim_matrix, shock_matrix and no_subjects are defined elsewhere)
lr_ = np.random.uniform(low=0.0, high=1.0, size=no_subjects)
lr = tt.vector('lr')
vS = 0.5 * tt.ones((no_subjects, 2), dtype='float64')
combined_v = 0.5 * tt.ones((no_subjects, 1), dtype='float64')

outputs, updates = theano.scan(
    fn=update_RW,
    sequences=[stim_matrix, shock_matrix],
    outputs_info=[vS, combined_v],
    non_sequences=[lr, no_subjects])

op = theano.function(
    inputs=[lr],
    outputs=outputs
)
oval = op(lr_)
```
```
> oval[1].reshape(est_scr_matrix.shape)[:5]
array([[0.82372604, 0.94848156, 0.86056746, 0.91567671, 0.91378403],
       [0.17627396, 0.05151844, 0.13943254, 0.08432329, 0.08621597],
       [0.29040291, 0.09772858, 0.23998221, 0.15442574, 0.15756555],
       [0.74983301, 0.90703277, 0.78805758, 0.85739679, 0.8547374 ],
       [0.06214502, 0.0053083 , 0.03888286, 0.01422083, 0.01486639]])
```
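Since the body of `update_RW` isn't shown here, one way to cross-check the `theano.scan` output is a plain NumPy loop computing the same recursion. The sketch below assumes (hypothetically) a standard Rescorla-Wagner delta rule with a single cue per subject, V ← V + lr·(shock − V), applied only on trials where the stimulus is present. `rw_reference` and `v0` are illustrative names, and the recursion would need adapting to whatever `update_RW` actually computes (e.g. the two columns of `vS`).

```python
import numpy as np

def rw_reference(stim, shock, lr, v0=0.5):
    """Hypothetical single-cue Rescorla-Wagner recursion, one value per subject.

    stim, shock: (n_trials, n_subjects) arrays of 0/1 entries.
    lr: (n_subjects,) learning rates in [0, 1].
    Returns the value after each trial's update, shape (n_trials, n_subjects),
    mirroring theano.scan, which stores the state produced at each step.
    """
    stim = np.asarray(stim, dtype=float)
    shock = np.asarray(shock, dtype=float)
    lr = np.asarray(lr, dtype=float)
    n_trials, n_subjects = stim.shape
    v = np.full(n_subjects, float(v0))
    values = np.empty((n_trials, n_subjects))
    for t in range(n_trials):
        # Delta rule, gated by whether the cue was presented on this trial.
        v = v + stim[t] * lr * (shock[t] - v)
        values[t] = v
    return values

demo = rw_reference(stim=[[1, 0], [1, 1]], shock=[[1, 0], [0, 1]], lr=[0.5, 0.2])
print(demo)
```

With a recursion that matches `update_RW`, something like `np.allclose(rw_reference(stim_matrix.eval(), shock_matrix.eval(), lr_), oval[1].reshape(est_scr_matrix.shape))` would be one way to confirm the two implementations agree.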