Hi there,
I have recently started using PyMC3, and my model samples very slowly on both GPU and CPU.
My model is the following:
import numpy as np
import pymc3 as pm
import theano.tensor as tt
import theano
with pm.Model() as model_generator:
    sigma_error_prec = pm.Uniform('sigma_error_prec', 0, 100)
    B = pm.Normal("B", 0, 100, shape=(n,))
    log_Sig = pm.Uniform('log_Sig', -8, 8, shape=(n,))
    SQ = tt.diag(tt.sqrt(tt.exp(log_Sig)))
    Func_Covm = tt.dot(tt.dot(SQ, mFunc_t), SQ)
    Struct_Convm = tt.dot(tt.dot(SQ, Struct_t), SQ)
    L_fc_vec = tt.reshape(tt.slinalg.cholesky(tt.squeeze(Func_Covm)).T[np.triu_indices(n)], (n*(n+1)//2,))
    L_st_vec = tt.reshape(tt.slinalg.cholesky(tt.squeeze(Struct_Convm)).T[np.triu_indices(n)], (n*(n+1)//2,))
    Struct_vec = tt.reshape(Struct_t[np.triu_indices(n)], (n*(n+1)//2,))
    lambdaw = pm.Beta('lambdaw', alpha=1, beta=1, shape=(n*(n+1)//2,))
    Kf = pm.Beta('Kf', alpha=1, beta=1, shape=(n*(n+1)//2,))
    rhonn = Kf*( (1-lambdaw)*L_fc_vec + lambdaw*L_st_vec ) + \
            (1-Kf)*( (1-Struct_vec*lambdaw)*L_fc_vec + Struct_vec*lambdaw*L_st_vec )
    Cov_temp = tt.triu(tt.ones((n, n)))
    Cov_temp = tt.set_subtensor(Cov_temp[np.triu_indices(n)], rhonn)
    Cov_mat_v = tt.dot(Cov_temp.T, Cov_temp)
    d = tt.sqrt(tt.diagonal(Cov_mat_v))
    rho = (Cov_mat_v.T/d).T/d
    rhoNew = pm.Deterministic("rhoNew", rho[np.triu_indices(n, 1)])
    D = pm.MvNormal('D', mu=tt.zeros(n), cov=Cov_mat_v, shape=(n,))
    phi_s = pm.Uniform("phi_s", 0, 20, shape=(n,))
    spat_prec = pm.Uniform("spat_prec", 0, 100, shape=(n,))
    # 14 ROIs
    # ROI 1
    phi_s1 = pm.Uniform('phi_s1', 0, 20)
    spat_prec1 = pm.Uniform('spat_prec1', 0, 100)
    H_temp1 = tt.sqr(spat_prec1)*tt.exp(-phi_s1*Dist[0])
    H1 = pm.MvNormal('H1', mu=tt.zeros(m), cov=H_temp1, shape=(m,))
    # ROI 2
    phi_s2 = pm.Uniform('phi_s2', 0, 20)
    spat_prec2 = pm.Uniform('spat_prec2', 0, 100)
    H_temp2 = tt.sqr(spat_prec2)*tt.exp(-phi_s2*Dist[1])
    H2 = pm.MvNormal('H2', mu=tt.zeros(m), cov=H_temp2, shape=(m,))
    # ROI 3
    phi_s3 = pm.Uniform('phi_s3', 0, 20)
    spat_prec3 = pm.Uniform('spat_prec3', 0, 100)
    H_temp3 = tt.sqr(spat_prec3)*tt.exp(-phi_s3*Dist[2])
    H3 = pm.MvNormal('H3', mu=tt.zeros(m), cov=H_temp3, shape=(m,))
    # ROI 4
    phi_s4 = pm.Uniform('phi_s4', 0, 20)
    spat_prec4 = pm.Uniform('spat_prec4', 0, 100)
    H_temp4 = tt.sqr(spat_prec4)*tt.exp(-phi_s4*Dist[3])
    H4 = pm.MvNormal('H4', mu=tt.zeros(m), cov=H_temp4, shape=(m,))
    # ROI 5
    phi_s5 = pm.Uniform('phi_s5', 0, 20)
    spat_prec5 = pm.Uniform('spat_prec5', 0, 100)
    H_temp5 = tt.sqr(spat_prec5)*tt.exp(-phi_s5*Dist[4])
    H5 = pm.MvNormal('H5', mu=tt.zeros(m), cov=H_temp5, shape=(m,))
    # ROI 6
    phi_s6 = pm.Uniform('phi_s6', 0, 20)
    spat_prec6 = pm.Uniform('spat_prec6', 0, 100)
    H_temp6 = tt.sqr(spat_prec6)*tt.exp(-phi_s6*Dist[5])
    H6 = pm.MvNormal('H6', mu=tt.zeros(m), cov=H_temp6, shape=(m,))
    # ROI 7
    phi_s7 = pm.Uniform('phi_s7', 0, 20)
    spat_prec7 = pm.Uniform('spat_prec7', 0, 100)
    H_temp7 = tt.sqr(spat_prec7)*tt.exp(-phi_s7*Dist[6])
    H7 = pm.MvNormal('H7', mu=tt.zeros(m), cov=H_temp7, shape=(m,))
    # ROI 8
    phi_s8 = pm.Uniform('phi_s8', 0, 20)
    spat_prec8 = pm.Uniform('spat_prec8', 0, 100)
    H_temp8 = tt.sqr(spat_prec8)*tt.exp(-phi_s8*Dist[7])
    H8 = pm.MvNormal('H8', mu=tt.zeros(m), cov=H_temp8, shape=(m,))
    # ROI 9
    phi_s9 = pm.Uniform('phi_s9', 0, 20)
    spat_prec9 = pm.Uniform('spat_prec9', 0, 100)
    H_temp9 = tt.sqr(spat_prec9)*tt.exp(-phi_s9*Dist[8])
    H9 = pm.MvNormal('H9', mu=tt.zeros(m), cov=H_temp9, shape=(m,))
    # ROI 10
    phi_s10 = pm.Uniform('phi_s10', 0, 20)
    spat_prec10 = pm.Uniform('spat_prec10', 0, 100)
    H_temp10 = tt.sqr(spat_prec10)*tt.exp(-phi_s10*Dist[9])
    H10 = pm.MvNormal('H10', mu=tt.zeros(m), cov=H_temp10, shape=(m,))
    # ROI 11
    phi_s11 = pm.Uniform('phi_s11', 0, 20)
    spat_prec11 = pm.Uniform('spat_prec11', 0, 100)
    H_temp11 = tt.sqr(spat_prec11)*tt.exp(-phi_s11*Dist[10])
    H11 = pm.MvNormal('H11', mu=tt.zeros(m), cov=H_temp11, shape=(m,))
    # ROI 12
    phi_s12 = pm.Uniform('phi_s12', 0, 20)
    spat_prec12 = pm.Uniform('spat_prec12', 0, 100)
    H_temp12 = tt.sqr(spat_prec12)*tt.exp(-phi_s12*Dist[11])
    H12 = pm.MvNormal('H12', mu=tt.zeros(m), cov=H_temp12, shape=(m,))
    # ROI 13
    phi_s13 = pm.Uniform('phi_s13', 0, 20)
    spat_prec13 = pm.Uniform('spat_prec13', 0, 100)
    H_temp13 = tt.sqr(spat_prec13)*tt.exp(-phi_s13*Dist[12])
    H13 = pm.MvNormal('H13', mu=tt.zeros(m), cov=H_temp13, shape=(m,))
    # ROI 14
    phi_s14 = pm.Uniform('phi_s14', 0, 20)
    spat_prec14 = pm.Uniform('spat_prec14', 0, 100)
    H_temp14 = tt.sqr(spat_prec14)*tt.exp(-phi_s14*Dist[13])
    H14 = pm.MvNormal('H14', mu=tt.zeros(m), cov=H_temp14, shape=(m,))
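    # (Side note: ROI blocks 1-14 above are identical apart from the index. A more
    #  compact, equivalent way to write them -- just a sketch, not what I currently
    #  run -- would be a loop that builds the same PyMC3 variables and keeps them
    #  in a list, e.g.
    #      H_list = []
    #      for r in range(1, n + 1):
    #          phi_r = pm.Uniform('phi_s%d' % r, 0, 20)
    #          prec_r = pm.Uniform('spat_prec%d' % r, 0, 100)
    #          H_list.append(pm.MvNormal('H%d' % r, mu=tt.zeros(m),
    #                                    cov=tt.sqr(prec_r)*tt.exp(-phi_r*Dist[r-1]),
    #                                    shape=(m,)))
    #  with H_list[r-1] used in place of H1..H14 further down.)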
    # temporal correlation through W
    # AR(1)
    # block-wise means of the first column of Y_t (one 300-row block per ROI),
    # centred by the corresponding intercept B
    muW1 = tt.stack([tt.mean(Y_t[:300, 0]) - B[0],
                     tt.mean(Y_t[300:600, 0]) - B[1],
                     tt.mean(Y_t[600:900, 0]) - B[2],
                     tt.mean(Y_t[900:1200, 0]) - B[3],
                     tt.mean(Y_t[1200:1500, 0]) - B[4],
                     tt.mean(Y_t[1500:1800, 0]) - B[5],
                     tt.mean(Y_t[1800:2100, 0]) - B[6],
                     tt.mean(Y_t[2100:2400, 0]) - B[7],
                     tt.mean(Y_t[2400:2700, 0]) - B[8],
                     tt.mean(Y_t[2700:3000, 0]) - B[9],
                     tt.mean(Y_t[3000:3300, 0]) - B[10],
                     tt.mean(Y_t[3300:3600, 0]) - B[11],
                     tt.mean(Y_t[3600:3900, 0]) - B[12],
                     tt.mean(Y_t[3900:4200, 0]) - B[13]])
    phi_T = pm.Uniform('phi_T', 0, 1, shape=(n,))
    sigW_T = pm.Uniform('sigW_T', 0, 100, shape=(n,))
    # stationary mean and precision of an AR(1) process:
    # mean = mu/(1 - phi), variance = sigma^2/(1 - phi^2)
    mean_overall = muW1/(1.0 - phi_T)
    tau_overall = (1.0 - tt.sqr(phi_T))/tt.sqr(sigW_T)
    W_T1 = pm.Normal('W_T1', mu=mean_overall[0], tau=tau_overall[0], shape=(k,))
    W_T2 = pm.Normal('W_T2', mu=mean_overall[1], tau=tau_overall[1], shape=(k,))
    W_T3 = pm.Normal('W_T3', mu=mean_overall[2], tau=tau_overall[2], shape=(k,))
    W_T4 = pm.Normal('W_T4', mu=mean_overall[3], tau=tau_overall[3], shape=(k,))
    W_T5 = pm.Normal('W_T5', mu=mean_overall[4], tau=tau_overall[4], shape=(k,))
    W_T6 = pm.Normal('W_T6', mu=mean_overall[5], tau=tau_overall[5], shape=(k,))
    W_T7 = pm.Normal('W_T7', mu=mean_overall[6], tau=tau_overall[6], shape=(k,))
    W_T8 = pm.Normal('W_T8', mu=mean_overall[7], tau=tau_overall[7], shape=(k,))
    W_T9 = pm.Normal('W_T9', mu=mean_overall[8], tau=tau_overall[8], shape=(k,))
    W_T10 = pm.Normal('W_T10', mu=mean_overall[9], tau=tau_overall[9], shape=(k,))
    W_T11 = pm.Normal('W_T11', mu=mean_overall[10], tau=tau_overall[10], shape=(k,))
    W_T12 = pm.Normal('W_T12', mu=mean_overall[11], tau=tau_overall[11], shape=(k,))
    W_T13 = pm.Normal('W_T13', mu=mean_overall[12], tau=tau_overall[12], shape=(k,))
    W_T14 = pm.Normal('W_T14', mu=mean_overall[13], tau=tau_overall[13], shape=(k,))
    one_m_vec = tt.ones((m, 1))
    one_k_vec = tt.ones((1, k))
    MU_all_1 = B[0] + D[0] + one_m_vec*tt.reshape(W_T1, (1, k)) + tt.reshape(H1, (m, 1))*one_k_vec
    MU_all_2 = B[1] + D[1] + one_m_vec*tt.reshape(W_T2, (1, k)) + tt.reshape(H2, (m, 1))*one_k_vec
    MU_all_3 = B[2] + D[2] + one_m_vec*tt.reshape(W_T3, (1, k)) + tt.reshape(H3, (m, 1))*one_k_vec
    MU_all_4 = B[3] + D[3] + one_m_vec*tt.reshape(W_T4, (1, k)) + tt.reshape(H4, (m, 1))*one_k_vec
    MU_all_5 = B[4] + D[4] + one_m_vec*tt.reshape(W_T5, (1, k)) + tt.reshape(H5, (m, 1))*one_k_vec
    MU_all_6 = B[5] + D[5] + one_m_vec*tt.reshape(W_T6, (1, k)) + tt.reshape(H6, (m, 1))*one_k_vec
    MU_all_7 = B[6] + D[6] + one_m_vec*tt.reshape(W_T7, (1, k)) + tt.reshape(H7, (m, 1))*one_k_vec
    MU_all_8 = B[7] + D[7] + one_m_vec*tt.reshape(W_T8, (1, k)) + tt.reshape(H8, (m, 1))*one_k_vec
    MU_all_9 = B[8] + D[8] + one_m_vec*tt.reshape(W_T9, (1, k)) + tt.reshape(H9, (m, 1))*one_k_vec
    MU_all_10 = B[9] + D[9] + one_m_vec*tt.reshape(W_T10, (1, k)) + tt.reshape(H10, (m, 1))*one_k_vec
    MU_all_11 = B[10] + D[10] + one_m_vec*tt.reshape(W_T11, (1, k)) + tt.reshape(H11, (m, 1))*one_k_vec
    MU_all_12 = B[11] + D[11] + one_m_vec*tt.reshape(W_T12, (1, k)) + tt.reshape(H12, (m, 1))*one_k_vec
    MU_all_13 = B[12] + D[12] + one_m_vec*tt.reshape(W_T13, (1, k)) + tt.reshape(H13, (m, 1))*one_k_vec
    MU_all_14 = B[13] + D[13] + one_m_vec*tt.reshape(W_T14, (1, k)) + tt.reshape(H14, (m, 1))*one_k_vec
    MU_all = tt.concatenate([MU_all_1, MU_all_2, MU_all_3, MU_all_4,
                             MU_all_5, MU_all_6, MU_all_7, MU_all_8, MU_all_9,
                             MU_all_10, MU_all_11, MU_all_12, MU_all_13, MU_all_14], axis=0)
    Y1 = pm.Normal('Y1', mu=MU_all, sd=sigma_error_prec, observed=Y_t)
with model_generator:
    step = pm.NUTS()
    trace = pm.sample(3000, step=step, chains=1)
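For context, the snippet relies on a few objects that are created earlier in my script: n (the number of ROIs, 14), m (spatial locations per ROI, 300, matching the 300-row blocks used for muW1), k (the number of columns of Y_t), the n x n matrices mFunc_t and Struct_t, the list Dist of m x m distance matrices, and the data matrix Y_t with n*m rows. A rough stand-in with the same structure (placeholder values only, not my real data) would be:
import numpy as np
n = 14    # number of ROIs
m = 300   # spatial locations per ROI
k = 5     # number of columns of Y_t -- placeholder, my real k differs
rng = np.random.RandomState(0)
# symmetric positive-definite stand-ins for the functional and structural matrices
A_f = rng.randn(n, n)
mFunc_t = A_f.dot(A_f.T)/n + np.eye(n)
A_s = rng.randn(n, n)
Struct_t = A_s.dot(A_s.T)/n + np.eye(n)
# one m x m distance matrix per ROI (here the same random geometry for each)
coords = rng.rand(m, 3)
D0 = np.sqrt(((coords[:, None, :] - coords[None, :, :])**2).sum(axis=-1))
Dist = [D0.copy() for _ in range(n)]
# data matrix with the same layout as my real Y_t: n*m rows, k columns
Y_t = rng.randn(n*m, k)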
It currently runs at roughly 80 s per iteration. I have already tried setting chains=1 and running on the GPU, but sampling remains very slow in this high-dimensional setting. Any advice on how to speed it up would be much appreciated.
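For reference, this is roughly how I enable the GPU run (the script name is just a placeholder); the flags have to be set before theano/pymc3 are imported:
# from the shell:
#   THEANO_FLAGS='device=cuda,floatX=float32' python my_model.py
# or, equivalently, at the very top of the script:
import os
os.environ['THEANO_FLAGS'] = 'device=cuda,floatX=float32'  # older Theano versions use device=gpu
Even with this, sampling speed is about the same as on the CPU.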