Hi there
I have a model that estimates a correlation coefficient r. It looks like this (all the code is also in the attached file bayesian_correlation.py):
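```python
import numpy as np
import pymc3 as pm
import theano.tensor as T


def covariance(sigma, rho):
    # Correlation matrix C: 1 on the diagonal, rho off the diagonal
    C = T.fill_diagonal(T.alloc(rho, 2, 2), 1.)
    # S = diag(sigma), so S C S is the covariance matrix
    S = T.diag(sigma)
    M = S.dot(C).dot(S)
    # Return the inverse of S C S, i.e. the precision matrix
    # (this is why it is passed to MvNormal as tau below)
    M1 = T.nlinalg.matrix_inverse(M)
    return M1


def create_correlation_model(data):
    with pm.Model() as model:
        # priors
        mu = pm.Normal('mu', mu=0., sd=1, shape=2,
                       testval=np.mean(data, axis=1))
        sigma = pm.HalfCauchy('sigma', beta=5, shape=2,
                              testval=np.std(data, axis=1))
        rho = pm.Uniform('r', lower=-1, upper=1,
                         testval=0,
                         transform=None)
        # note: despite its name, 'cov' holds the precision matrix
        cov = pm.Deterministic('cov', covariance(sigma, rho))
        # likelihood: one row of data.T per observation
        mult_norm = pm.MvNormal('mult_norm', mu=mu,
                                tau=cov, observed=data.T)
    return model


my_model = create_correlation_model(data)
```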
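Because covariance() returns the inverse of S C S, what reaches MvNormal is a precision matrix, which is what the `tau` argument expects. The NumPy sketch below uses a hypothetical helper `precision_from_sigma_rho` (not part of the attached file) that mirrors that construction for a single `sigma`/`rho` pair, so the entries can be checked by hand.

```python
import numpy as np


def precision_from_sigma_rho(sigma, rho):
    """Hypothetical NumPy mirror of covariance(): build S C S and invert it."""
    corr = np.array([[1.0, rho], [rho, 1.0]])  # C: unit diagonal, rho off-diagonal
    S = np.diag(sigma)                         # S = diag(sigma)
    cov = S @ corr @ S                         # covariance matrix S C S
    return cov, np.linalg.inv(cov)             # (covariance, precision)


cov, prec = precision_from_sigma_rho([2.0, 3.0], 0.5)
```

For sigma = [2, 3] and rho = 0.5 the covariance is [[4, 3], [3, 9]] (the off-diagonal is rho * sigma_1 * sigma_2), and its inverse, the precision, is [[9, -3], [-3, 4]] / 27.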
I want to sample from the prior.
The aim is to calculate a Bayes factor by comparing the posterior with the prior, so that I know how confident I can be in the result (i.e. that the null finding is not due to a lack of power).
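If "comparing posterior and prior" is meant as a Savage-Dickey density ratio at r = 0 (an assumption; the method is not named above), the prior side needs no sampling: r ~ Uniform(-1, 1) has density 1/2 everywhere on (-1, 1). A minimal sketch, using a hypothetical helper `savage_dickey_bf01` that takes posterior draws of r (for example `my_model_trace['r']`) and estimates the posterior density at 0 with `scipy.stats.gaussian_kde`; the KDE is just one choice of density estimator.

```python
import numpy as np
from scipy.stats import gaussian_kde

# Density of the Uniform(-1, 1) prior on r, valid for any r0 in (-1, 1)
PRIOR_DENSITY_UNIFORM = 0.5


def savage_dickey_bf01(r_samples, r0=0.0, prior_density=PRIOR_DENSITY_UNIFORM):
    """Hypothetical helper: BF01 = p(r0 | data) / p(r0).

    Values above 1 favour r = r0; values below 1 favour r != r0.
    The posterior density at r0 is estimated with a Gaussian KDE.
    """
    posterior_density = gaussian_kde(np.asarray(r_samples))(r0)[0]
    return posterior_density / prior_density


# Synthetic draws only, to illustrate the behaviour (not results from the model)
rng = np.random.default_rng(0)
bf_flat = savage_dickey_bf01(rng.uniform(-1, 1, 20000))    # posterior == prior
bf_peaked = savage_dickey_bf01(rng.normal(0, 0.1, 20000))  # posterior piled up at 0
bf_far = savage_dickey_bf01(rng.normal(0.6, 0.05, 20000))  # posterior away from 0
```

When the posterior equals the prior the ratio is close to 1 (the data are uninformative about r), which is exactly the "lack of power" case; a posterior concentrated around 0 gives BF01 well above 1. Metropolis draws are autocorrelated, so a long trace helps the KDE.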
But I get an error when using the function sample_prior_predictive:
```python
with my_model:
    my_model_prior_trace = pm.sample_prior_predictive(nsamples_corr)
```
The error is: “ValueError: Input array needs to be 2 dimensional but received a 3d-array.”
Any idea what is going on here?
It looks like a simple shape problem, but I cannot figure it out, because ordinary posterior sampling with pm.sample works fine:
```python
with my_model:
    my_model_trace = pm.sample(nsamples_corr, tune=ntune_corr, chains=nchains,
                               step=pm.Metropolis(), random_seed=21412,
                               progressbar=True)
```
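The error message suggests that a stack of 2x2 matrices (one per prior draw) reached code expecting a single matrix, but that is only a guess. Until the cause is clear, one way to inspect the prior is to simulate it by hand. The sketch below uses a hypothetical NumPy helper `simulate_prior_predictive` (not a PyMC3 function) that draws `mu`, `sigma` and `r` from the same priors as the model and simulates data sets shaped like `data.T`. It builds a Cholesky factor of S C S analytically rather than inverting a matrix, which is a choice of convenience.

```python
import numpy as np


def simulate_prior_predictive(n_draws, n_obs, seed=0):
    """Hypothetical helper: prior draws plus simulated (n_obs, 2) data per draw."""
    rng = np.random.default_rng(seed)
    mu = rng.normal(0.0, 1.0, size=(n_draws, 2))                  # mu ~ Normal(0, 1)
    sigma = 5.0 * np.abs(rng.standard_cauchy(size=(n_draws, 2)))  # sigma ~ HalfCauchy(beta=5)
    r = rng.uniform(-1.0, 1.0, size=n_draws)                      # r ~ Uniform(-1, 1)

    # Lower Cholesky factor of the correlation matrix [[1, r], [r, 1]]
    chol_corr = np.zeros((n_draws, 2, 2))
    chol_corr[:, 0, 0] = 1.0
    chol_corr[:, 1, 0] = r
    chol_corr[:, 1, 1] = np.sqrt(1.0 - r ** 2)
    # diag(sigma) @ chol_corr is a Cholesky factor of S C S
    chol_cov = sigma[:, :, None] * chol_corr

    # y = mu + L z with z ~ Normal(0, I), for every draw and observation
    z = rng.standard_normal(size=(n_draws, n_obs, 2))
    y = mu[:, None, :] + np.einsum('dij,dnj->dni', chol_cov, z)
    return {'mu': mu, 'sigma': sigma, 'r': r, 'mult_norm': y}


out = simulate_prior_predictive(200, 36)  # 36 observations, like the example data
```

For a Bayes factor on r alone, the prior draws of r are all that matter, and their density is known exactly; the simulated `mult_norm` data sets are mainly useful for checking that the priors produce plausible data.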
Thank you for your support!
P.S. Here is the example input data:
```python
data = np.array([[-0.47, 0.48, 0.37, 0.63, 0.22,
                  0.03, 0.21, 0.4, -0.35, 0.5,
                  -0.54, 0.64, 0.2, 0.27, 0.26,
                  1.01, -0.29, 0.51, 0.51, 0.7,
                  0.61, 0.2, 0.51, 1.31, 0.82,
                  -0.11, 0.91, 0.66, 0.72, 0.74,
                  -0.92, 0.1, 1.16, -0.84, 0.57,
                  0.66],
                 [-0.5710566, 0.13970771, -0.78708349, -0.72088874, 0.66054004,
                  -0.58559405, 0.45714096, 0.31919394, -0.31985194, -1.27500595,
                  -0.4261229, -0.49993372, -0.30406762, -0.90723603, 0.28160052,
                  -0.50950595, -0.14980235, 0.32254719, -1.05328163, -0.23188237,
                  -1.0048891, -1.10937546, 0.06748289, -0.35484265, -0.77177524,
                  -0.0073387, -0.84602639, 0.28435654, 0.17244235, -0.24576276,
                  -0.5445294, -0.15638686, -0.56165015, -0.10054537, 0.0668456,
                  0.284587]])
```