Optimization suggestions for a hierarchical model using NUTS on CPU/GPU

Hi there,

I have recently started using PyMC3. My model runs very slowly on both the GPU and the CPU.

My model looks like this:

import numpy as np
import pymc3 as pm
import theano.tensor as tt
import theano

with pm.Model() as model_generator:
    sigma_error_prec = pm.Uniform('sigma_error_prec', 0, 100)
    
    B = pm.Normal("B", 0, 100, shape=(n, ))
    log_Sig = pm.Uniform('log_Sig', -8, 8, shape=(n, ))
    SQ = tt.diag(tt.sqrt(tt.exp(log_Sig)))
    
    Func_Covm = tt.dot(tt.dot(SQ, mFunc_t), SQ)
    Struct_Convm = tt.dot(tt.dot(SQ, Struct_t), SQ)
    
    L_fc_vec = tt.reshape(tt.slinalg.cholesky(tt.squeeze(Func_Covm)).T[np.triu_indices(n)], (n*(n+1)//2, ))
    L_st_vec = tt.reshape(tt.slinalg.cholesky(tt.squeeze(Struct_Convm)).T[np.triu_indices(n)], (n*(n+1)//2, ))
    Struct_vec = tt.reshape(Struct_t[np.triu_indices(n)], (n*(n+1)//2, ))
    
    lambdaw = pm.Beta('lambdaw', alpha=1, beta=1, shape=(n*(n+1)//2, ))
    Kf = pm.Beta('Kf', alpha=1, beta=1, shape=(n*(n+1)//2, ))
    rhonn = Kf*( (1-lambdaw)*L_fc_vec + lambdaw*L_st_vec ) + \
        (1-Kf)*( (1-Struct_vec*lambdaw)*L_fc_vec + Struct_vec*lambdaw*L_st_vec )

    Cov_temp = tt.triu(tt.ones((n,n)))
    Cov_temp = tt.set_subtensor(Cov_temp[np.triu_indices(n)], rhonn)
    Cov_mat_v = tt.dot(Cov_temp.T, Cov_temp)

    d = tt.sqrt(tt.diagonal(Cov_mat_v))
    rho = (Cov_mat_v.T/d).T/d
    rhoNew = pm.Deterministic("rhoNew", rho[np.triu_indices(n,1)])

    D = pm.MvNormal('D', mu=tt.zeros(n), cov=Cov_mat_v, shape = (n, ))
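    # note: phi_s and spat_prec are not used below (the per-ROI phi_s1..14 and
    # spat_prec1..14 are), but they are still sampled as free parameters
    phi_s = pm.Uniform("phi_s", 0, 20, shape = (n, ))
    spat_prec = pm.Uniform("spat_prec", 0, 100, shape = (n, ))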

    # 14 ROIs
    # ROI 1
    phi_s1 = pm.Uniform('phi_s1', 0, 20)
    spat_prec1 = pm.Uniform('spat_prec1', 0, 100)
    H_temp1 = tt.sqr(spat_prec1)*tt.exp(-phi_s1*Dist[0])
    H1 = pm.MvNormal('H1', mu=tt.zeros(m), cov=H_temp1, shape = (m, ))

    # ROI 2
    phi_s2 = pm.Uniform('phi_s2', 0, 20)
    spat_prec2 = pm.Uniform('spat_prec2', 0, 100)
    H_temp2 = tt.sqr(spat_prec2)*tt.exp(-phi_s2*Dist[1])
    H2 = pm.MvNormal('H2', mu=tt.zeros(m), cov=H_temp2, shape = (m, ))

    # ROI 3
    phi_s3 = pm.Uniform('phi_s3', 0, 20)
    spat_prec3 = pm.Uniform('spat_prec3', 0, 100)
    H_temp3 = tt.sqr(spat_prec3)*tt.exp(-phi_s3*Dist[2])
    H3 = pm.MvNormal('H3', mu=tt.zeros(m), cov=H_temp3, shape = (m, ))

    # ROI 4
    phi_s4 = pm.Uniform('phi_s4', 0, 20)
    spat_prec4 = pm.Uniform('spat_prec4', 0, 100)
    H_temp4 = tt.sqr(spat_prec4)*tt.exp(-phi_s4*Dist[3])
    H4 = pm.MvNormal('H4', mu=tt.zeros(m), cov=H_temp4, shape = (m, ))

    # ROI 5
    phi_s5 = pm.Uniform('phi_s5', 0, 20)
    spat_prec5 = pm.Uniform('spat_prec5', 0, 100)
    H_temp5 = tt.sqr(spat_prec5)*tt.exp(-phi_s5*Dist[4])
    H5 = pm.MvNormal('H5', mu=tt.zeros(m), cov=H_temp5, shape = (m, ))
    
    # ROI 6
    phi_s6 = pm.Uniform('phi_s6', 0, 20)
    spat_prec6 = pm.Uniform('spat_prec6', 0, 100)
    H_temp6 = tt.sqr(spat_prec6)*tt.exp(-phi_s6*Dist[5])
    H6 = pm.MvNormal('H6', mu=tt.zeros(m), cov=H_temp6, shape = (m, ))
    
    # ROI 7
    phi_s7 = pm.Uniform('phi_s7', 0, 20)
    spat_prec7 = pm.Uniform('spat_prec7', 0, 100)
    H_temp7 = tt.sqr(spat_prec7)*tt.exp(-phi_s7*Dist[6])
    H7 = pm.MvNormal('H7', mu=tt.zeros(m), cov=H_temp7, shape = (m, ))
    
    # ROI 8
    phi_s8 = pm.Uniform('phi_s8', 0, 20)
    spat_prec8 = pm.Uniform('spat_prec8', 0, 100)
    H_temp8 = tt.sqr(spat_prec8)*tt.exp(-phi_s8*Dist[7])
    H8 = pm.MvNormal('H8', mu=tt.zeros(m), cov=H_temp8, shape = (m, ))
    
    # ROI 9
    phi_s9 = pm.Uniform('phi_s9', 0, 20)
    spat_prec9 = pm.Uniform('spat_prec9', 0, 100)
    H_temp9 = tt.sqr(spat_prec9)*tt.exp(-phi_s9*Dist[8])
    H9 = pm.MvNormal('H9', mu=tt.zeros(m), cov=H_temp9, shape = (m, ))
    
    # ROI 10
    phi_s10 = pm.Uniform('phi_s10', 0, 20)
    spat_prec10 = pm.Uniform('spat_prec10', 0, 100)
    H_temp10 = tt.sqr(spat_prec10)*tt.exp(-phi_s10*Dist[9])
    H10 = pm.MvNormal('H10', mu=tt.zeros(m), cov=H_temp10, shape = (m, ))
    
    # ROI 11
    phi_s11 = pm.Uniform('phi_s11', 0, 20)
    spat_prec11 = pm.Uniform('spat_prec11', 0, 100)
    H_temp11 = tt.sqr(spat_prec11)*tt.exp(-phi_s11*Dist[10])
    H11 = pm.MvNormal('H11', mu=tt.zeros(m), cov=H_temp11, shape = (m, ))
    
    # ROI 12
    phi_s12 = pm.Uniform('phi_s12', 0, 20)
    spat_prec12 = pm.Uniform('spat_prec12', 0, 100)
    H_temp12 = tt.sqr(spat_prec12)*tt.exp(-phi_s12*Dist[11])
    H12 = pm.MvNormal('H12', mu=tt.zeros(m), cov=H_temp12, shape = (m, ))
    
    # ROI 13
    phi_s13 = pm.Uniform('phi_s13', 0, 20)
    spat_prec13 = pm.Uniform('spat_prec13', 0, 100)
    H_temp13 = tt.sqr(spat_prec13)*tt.exp(-phi_s13*Dist[12])
    H13 = pm.MvNormal('H13', mu=tt.zeros(m), cov=H_temp13, shape = (m, ))
    
    # ROI 14
    phi_s14 = pm.Uniform('phi_s14', 0, 20)
    spat_prec14 = pm.Uniform('spat_prec14', 0, 100)
    H_temp14 = tt.sqr(spat_prec14)*tt.exp(-phi_s14*Dist[13])
    H14 = pm.MvNormal('H14', mu=tt.zeros(m), cov=H_temp14, shape = (m, ))

    # temporal correlation through W
    # AR(1)
    # pass a list to tt.stack; the stack(*tensors) form is deprecated in Theano
    muW1 = tt.stack([tt.mean(Y_t[:300,0]) - B[0],
                     tt.mean(Y_t[300:600,0]) - B[1],
                     tt.mean(Y_t[600:900,0]) - B[2],
                     tt.mean(Y_t[900:1200,0]) - B[3],
                     tt.mean(Y_t[1200:1500,0]) - B[4],
                     tt.mean(Y_t[1500:1800,0]) - B[5],
                     tt.mean(Y_t[1800:2100,0]) - B[6],
                     tt.mean(Y_t[2100:2400,0]) - B[7],
                     tt.mean(Y_t[2400:2700,0]) - B[8],
                     tt.mean(Y_t[2700:3000,0]) - B[9],
                     tt.mean(Y_t[3000:3300,0]) - B[10],
                     tt.mean(Y_t[3300:3600,0]) - B[11],
                     tt.mean(Y_t[3600:3900,0]) - B[12],
                     tt.mean(Y_t[3900:4200,0]) - B[13]])

    phi_T = pm.Uniform('phi_T', 0, 1, shape=(n, ))
    sigW_T = pm.Uniform('sigW_T', 0, 100, shape=(n, ))

    mean_overall = muW1/(1.0-phi_T)
    tau_overall = (1.0-tt.sqr(phi_T))/tt.sqr(sigW_T)

    W_T1 = pm.Normal('W_T1', mu = mean_overall[0], tau=tau_overall[0], shape=(k, ))
    W_T2 = pm.Normal('W_T2', mu = mean_overall[1], tau=tau_overall[1], shape=(k, ))
    W_T3 = pm.Normal('W_T3', mu = mean_overall[2], tau=tau_overall[2], shape=(k, ))
    W_T4 = pm.Normal('W_T4', mu = mean_overall[3], tau=tau_overall[3], shape=(k, ))
    W_T5 = pm.Normal('W_T5', mu = mean_overall[4], tau=tau_overall[4], shape=(k, ))
    W_T6 = pm.Normal('W_T6', mu = mean_overall[5], tau=tau_overall[5], shape=(k, ))
    W_T7 = pm.Normal('W_T7', mu = mean_overall[6], tau=tau_overall[6], shape=(k, ))
    W_T8 = pm.Normal('W_T8', mu = mean_overall[7], tau=tau_overall[7], shape=(k, ))
    W_T9 = pm.Normal('W_T9', mu = mean_overall[8], tau=tau_overall[8], shape=(k, ))
    W_T10 = pm.Normal('W_T10', mu = mean_overall[9], tau=tau_overall[9], shape=(k, ))
    W_T11 = pm.Normal('W_T11', mu = mean_overall[10], tau=tau_overall[10], shape=(k, ))
    W_T12 = pm.Normal('W_T12', mu = mean_overall[11], tau=tau_overall[11], shape=(k, ))
    W_T13 = pm.Normal('W_T13', mu = mean_overall[12], tau=tau_overall[12], shape=(k, ))
    W_T14 = pm.Normal('W_T14', mu = mean_overall[13], tau=tau_overall[13], shape=(k, ))

    one_m_vec = tt.ones((m, 1))
    one_k_vec = tt.ones((1, k))
    MU_all_1 = B[0] + D[0] + one_m_vec*tt.reshape(W_T1, (1, k)) + tt.reshape(H1, (m, 1))*one_k_vec
    MU_all_2 = B[1] + D[1] + one_m_vec*tt.reshape(W_T2, (1, k)) + tt.reshape(H2, (m, 1))*one_k_vec
    MU_all_3 = B[2] + D[2] + one_m_vec*tt.reshape(W_T3, (1, k)) + tt.reshape(H3, (m, 1))*one_k_vec
    MU_all_4 = B[3] + D[3] + one_m_vec*tt.reshape(W_T4, (1, k)) + tt.reshape(H4, (m, 1))*one_k_vec
    MU_all_5 = B[4] + D[4] + one_m_vec*tt.reshape(W_T5, (1, k)) + tt.reshape(H5, (m, 1))*one_k_vec
    MU_all_6 = B[5] + D[5] + one_m_vec*tt.reshape(W_T6, (1, k)) + tt.reshape(H6, (m, 1))*one_k_vec
    MU_all_7 = B[6] + D[6] + one_m_vec*tt.reshape(W_T7, (1, k)) + tt.reshape(H7, (m, 1))*one_k_vec
    MU_all_8 = B[7] + D[7] + one_m_vec*tt.reshape(W_T8, (1, k)) + tt.reshape(H8, (m, 1))*one_k_vec
    MU_all_9 = B[8] + D[8] + one_m_vec*tt.reshape(W_T9, (1, k)) + tt.reshape(H9, (m, 1))*one_k_vec
    MU_all_10 = B[9] + D[9] + one_m_vec*tt.reshape(W_T10, (1, k)) + tt.reshape(H10, (m, 1))*one_k_vec
    MU_all_11 = B[10] + D[10] + one_m_vec*tt.reshape(W_T11, (1, k)) + tt.reshape(H11, (m, 1))*one_k_vec
    MU_all_12 = B[11] + D[11] + one_m_vec*tt.reshape(W_T12, (1, k)) + tt.reshape(H12, (m, 1))*one_k_vec
    MU_all_13 = B[12] + D[12] + one_m_vec*tt.reshape(W_T13, (1, k)) + tt.reshape(H13, (m, 1))*one_k_vec
    MU_all_14 = B[13] + D[13] + one_m_vec*tt.reshape(W_T14, (1, k)) + tt.reshape(H14, (m, 1))*one_k_vec
    MU_all = tt.concatenate([MU_all_1, MU_all_2, MU_all_3, MU_all_4, \
        MU_all_5, MU_all_6, MU_all_7, MU_all_8, MU_all_9, \
        MU_all_10, MU_all_11, MU_all_12, MU_all_13, MU_all_14], axis = 0)

    Y1 = pm.Normal('Y1', mu = MU_all, sd = sigma_error_prec, observed = Y_t)

with model_generator:
    step = pm.NUTS()
    trace = pm.sample(3000, step = step, chains = 1)
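
Before looking at speed, it can help to check the covariance construction at the top of the model outside Theano. The NumPy-only sketch below assumes n = 14, random correlation matrices as hypothetical stand-ins for `mFunc_t` and `Struct_t` (the real ones are not shown), and fixed draws in place of `log_Sig`, `lambdaw` and `Kf`; `random_corr` and `rebuild_cov` are illustrative helpers, not part of the original code. It shows that the `SQ @ M @ SQ` sandwich is an elementwise scaling, that with `lambdaw = 0` the Cholesky-factor mixture gives back `Func_Covm`, and that the normalisation producing `rho` is a division by `outer(d, d)`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 14  # number of ROIs, as in the model

def random_corr(n, rng):
    # Hypothetical stand-in for mFunc_t / Struct_t: a random SPD correlation matrix.
    A = rng.normal(size=(n, 2 * n))
    C = A @ A.T
    d = np.sqrt(np.diag(C))
    return C / np.outer(d, d)

mFunc = random_corr(n, rng)
Struct = random_corr(n, rng)
log_Sig = rng.uniform(-1, 1, size=n)

# SQ @ M @ SQ with SQ = diag(s) equals M scaled elementwise by outer(s, s),
# which avoids two dense n x n matrix products.
s = np.sqrt(np.exp(log_Sig))
SQ = np.diag(s)
Func_Covm = SQ @ mFunc @ SQ
Struct_Covm = SQ @ Struct @ SQ
assert np.allclose(Func_Covm, mFunc * np.outer(s, s))

# Upper-triangular Cholesky entries, read row by row as in the model.
iu = np.triu_indices(n)
L_fc_vec = np.linalg.cholesky(Func_Covm).T[iu]
L_st_vec = np.linalg.cholesky(Struct_Covm).T[iu]
Struct_vec = Struct[iu]
n_tri = n * (n + 1) // 2

def rebuild_cov(lambdaw, Kf):
    # Same mixture as rhonn in the model, then Cov_mat_v = U^T U.
    rhonn = (Kf * ((1 - lambdaw) * L_fc_vec + lambdaw * L_st_vec)
             + (1 - Kf) * ((1 - Struct_vec * lambdaw) * L_fc_vec
                           + Struct_vec * lambdaw * L_st_vec))
    U = np.zeros((n, n))
    U[iu] = rhonn
    return U.T @ U

# With lambdaw = 0 both terms reduce to L_fc_vec, so the
# reconstruction reproduces Func_Covm exactly.
Cov0 = rebuild_cov(np.zeros(n_tri), rng.uniform(size=n_tri))
assert np.allclose(Cov0, Func_Covm)

# The model's normalisation (Cov.T / d).T / d is Cov / outer(d, d).
Cov = rebuild_cov(rng.uniform(size=n_tri), rng.uniform(size=n_tri))
d = np.sqrt(np.diag(Cov))
rho = (Cov.T / d).T / d
assert np.allclose(rho, Cov / np.outer(d, d))
assert np.allclose(np.diag(rho), 1.0)
```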

It currently runs at ~80 s per iteration. I have tried setting chains = 1 and running on the GPU for this high-dimensional model.
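
To put "high-dimensional" into numbers, here is a rough count of the free (sampled) parameters declared in the model above. It assumes the number of ROIs equals n = 14 and takes m = 300 as an assumption inferred from the 300-row slices of `Y_t`; k is left as an argument because its value is not shown. The function `n_free_params` is a hypothetical helper.

```python
def n_free_params(n, m, k):
    """Count the free random variables declared in the model above.

    Assumes the number of ROIs equals n (14 in the model).
    """
    tri = n * (n + 1) // 2
    count = 0
    count += 1          # sigma_error_prec
    count += n          # B
    count += n          # log_Sig
    count += 2 * tri    # lambdaw, Kf
    count += n          # D
    count += 2 * n      # phi_s, spat_prec (declared but unused)
    count += 2 * n      # phi_s1..14, spat_prec1..14 (one pair per ROI)
    count += n * m      # H1..H14, each of length m
    count += 2 * n      # phi_T, sigW_T
    count += n * k      # W_T1..W_T14, each of length k
    return count

print(n_free_params(14, 300, 100))  # 4537 + 14 * 100 = 5937
```

The m-dependent term dominates: the 14 vectors `H1..H14` alone contribute 4200 dimensions, and each gradient evaluation presumably also has to factorise their 14 m × m covariance matrices.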

This seems to be quite a big model, so I am not too surprised that it takes time to sample from.
For what it's worth, I would quickly try initialising with ADVI:

with model_generator:
    trace = pm.sample(1000, tune=1000, init='advi')

If sampling is still slow, then some careful reparameterization is needed (better priors, exploiting the structure of the model, using mathematical shortcuts/approximations, etc.).
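
One way to exploit the structure of this model is to replace the 14 copy-pasted ROI blocks with array operations. The NumPy sketch below checks, on random stand-in data, that the per-ROI loops match their vectorised forms for `muW1`, `H_temp1..14` and `MU_all`. It assumes m = 300 rows per ROI, k = 5 (illustrative), and random stand-ins for `Y_t`, `B`, `D`, the stacked `W_T`/`H` values and a hypothetical `Dist` built from random 2-D points. In PyMC3 terms this roughly corresponds to declaring one `phi_s`/`spat_prec` with `shape=(n,)` and one `W_T` with `shape=(n, k)` instead of 14 separate variables; the `H_i` terms with different covariances would still need separate `MvNormal`s, which is not shown here. Vectorising mostly shrinks the graph and compile time rather than changing the posterior geometry.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, k = 14, 300, 5   # ROIs, rows per ROI, columns (k = 5 is illustrative)

Y_t = rng.normal(size=(n * m, k))   # stand-in for the observed data
B = rng.normal(size=n)
D = rng.normal(size=n)
W = rng.normal(size=(n, k))         # W_T1..W_T14 stacked as rows
H = rng.normal(size=(n, m))         # H1..H14 stacked as rows

# muW1: 14 slice means become one reshape + mean.
muW1_loop = np.array([Y_t[i * m:(i + 1) * m, 0].mean() - B[i] for i in range(n)])
muW1_vec = Y_t[:, 0].reshape(n, m).mean(axis=1) - B
assert np.allclose(muW1_loop, muW1_vec)

# H_temp1..14: one broadcast expression over a stacked (n, m, m) distance array.
pts = rng.uniform(size=(n, m, 2))   # hypothetical locations
Dist = np.linalg.norm(pts[:, :, None, :] - pts[:, None, :, :], axis=-1)
phi_s = rng.uniform(0, 20, size=n)
spat_prec = rng.uniform(0, 100, size=n)
H_temp_loop = np.stack([spat_prec[i] ** 2 * np.exp(-phi_s[i] * Dist[i])
                        for i in range(n)])
H_temp_vec = (spat_prec ** 2)[:, None, None] * np.exp(-phi_s[:, None, None] * Dist)
assert np.allclose(H_temp_loop, H_temp_vec)

# MU_all: the ones-vector products are broadcasting, and concatenating the
# (m, k) blocks along axis 0 is a reshape of an (n, m, k) array.
one_m_vec = np.ones((m, 1))
one_k_vec = np.ones((1, k))
MU_loop = np.concatenate(
    [B[i] + D[i] + one_m_vec * W[i].reshape(1, k) + H[i].reshape(m, 1) * one_k_vec
     for i in range(n)], axis=0)
MU_vec = ((B + D)[:, None, None] + W[:, None, :] + H[:, :, None]).reshape(n * m, k)
assert np.allclose(MU_loop, MU_vec)
```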

Thank you for your reply. Do you have any suggestions for using the GPU to speed up my model?

I don't have much experience using GPUs, but I would expect to see some advantage over the CPU, since there are quite a lot of large matrix operations happening in the multivariate normal (Cholesky factorization, etc.). You can profile your model to check whether the computation is performed on the GPU and which part of the computation is slow: http://docs.pymc.io/notebooks/profiling.html
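
As a rough NumPy illustration of the matrix work mentioned above (not PyMC3's actual implementation), the log-density of a zero-mean multivariate normal can be computed from a single Cholesky factor. `np.linalg.cholesky` and `np.linalg.solve` accept stacked matrices, so all 14 ROI terms can be evaluated in one batched call. The covariances here are random SPD stand-ins, m is reduced to 50 to keep the check fast, and `mvn_logpdf_zero_mean` is a hypothetical helper.

```python
import numpy as np

def mvn_logpdf_zero_mean(x, cov):
    """Log-density of N(0, cov) at x, batched over leading axes.

    x: (..., m), cov: (..., m, m), symmetric positive definite.
    """
    m = x.shape[-1]
    L = np.linalg.cholesky(cov)                    # cov = L @ L^T
    # Keep b as (..., m, 1) so batched solve behaves the same across NumPy versions.
    z = np.linalg.solve(L, x[..., None])[..., 0]   # z = L^{-1} x
    log_det = 2.0 * np.log(np.diagonal(L, axis1=-2, axis2=-1)).sum(axis=-1)
    return -0.5 * (m * np.log(2 * np.pi) + log_det + (z ** 2).sum(axis=-1))

rng = np.random.default_rng(2)
n, m = 14, 50
A = rng.normal(size=(n, m, 2 * m))
cov = A @ np.swapaxes(A, -1, -2) + 1e-3 * np.eye(m)   # random SPD stand-ins
x = rng.normal(size=(n, m))

batched = mvn_logpdf_zero_mean(x, cov)

# Reference: slogdet + solve, one ROI at a time.
ref = []
for i in range(n):
    sign, logdet = np.linalg.slogdet(cov[i])
    quad = x[i] @ np.linalg.solve(cov[i], x[i])
    ref.append(-0.5 * (m * np.log(2 * np.pi) + logdet + quad))
assert np.allclose(batched, ref)
```

In the model above each `H_i` contributes one such term with `cov=H_temp_i`, and because the covariance depends on `phi_s_i` and `spat_prec_i`, the factorisation has to be redone whenever those parameters change.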

Junpeng,

What I found is that my model runs faster with NUTS on the CPU than on the GPU. Also, chains = 1 is faster than the default setting, chains=None. What is the right way to set chains?

The current setup does not take full advantage of the GPU because of a limitation in Theano: when you use the GPU, PyMC3 produces a lot of overhead copying data back and forth between the GPU and the CPU.

Preferably you should always use multiple chains. If the sampler hangs with chains=4 etc., that usually means that some of the chains are in a region of the parameter space that is difficult to sample for whatever reason (an indication that you should improve the model). So in general you should do something like trace = pm.sample(1000, tune=1000, cores=4, chains=4).

My model runs with both chains=1 and chains=4, but on some Linux computers it hangs with chains=4. From what I have read on GitHub, this is related to the joblib package. One last question: will chains = 4, cores = 4 be faster?

PyMC3 will try to run each chain on its own core, but if you have more chains than CPU cores, it will run some of the chains sequentially, which can make the run time roughly 2x longer (for example, 4 chains on 2 cores).
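
A back-of-the-envelope sketch of that arithmetic, assuming every chain takes about the same time and ignoring process start-up and data-copying overhead; `relative_wall_time` is a hypothetical helper, not part of PyMC3.

```python
import math

def relative_wall_time(chains, cores):
    """Approximate wall-clock time relative to one chain.

    Chains beyond the core count wait for a free core, so they run in
    ceil(chains / cores) sequential rounds.
    """
    return math.ceil(chains / cores)

for chains, cores in [(4, 4), (4, 2), (4, 1)]:
    print(chains, cores, relative_wall_time(chains, cores))
```

Under these assumptions chains=4, cores=4 takes about as long as a single chain while producing four chains' worth of draws, provided the machine really has four free cores.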