Mixing pytensor.dot and dimensions

I see it now: I need to include the “clusters” dimension in my intercept $\alpha$, but to make it work with $\mu$ I need to take its mean. It may still need some tweaking, but this is essentially what I was hoping to achieve:

# Hierarchical linear regression of birth weight (DBWT) on three numeric
# predictors, with a per-cluster intercept (α) and per-cluster slopes (β)
# partially pooled through shared hyperpriors.
example_data = pd.DataFrame(
    data={
        "BMI": np.array([13.43666, 16.061496, 22.998563, 14.499094, 18.248859,
                         13.811637, 15.061559, 19.873758, 15.436535, 13.186676]),
        "WTGAIN": np.array([20, 39, 99, 31, 47, 99, 99, 1, 34, 29]),
        "M_Ht_In": np.array([12, 12.8, 12.2, 12.2, 12.8, 13.8, 12.4, 13, 12.8, 13.4]),
        "CLUSTERS": np.array([3, 3, 2, 3, 3, 2, 2, 3, 3, 3]),
        "DBWT": np.array([2610, 3190, 3232, 2410, 2780, 3033, 2495, 3518, 3381, 2693]),
    }
)

# Predictor columns, listed once so the coordinate and the data container agree.
num_cols = ["BMI", "WTGAIN", "M_Ht_In"]

# pd.factorize numbers clusters in order of first appearance (here 3 -> 0,
# 2 -> 1); cluster_name holds the original labels in that same order.
cluster_index, cluster_name = pd.factorize(example_data["CLUSTERS"])

with pm.Model() as Hier:
    # Coordinates. The "clusters" labels are derived from cluster_name so that
    # position k of every "clusters" dim is the cluster with cluster_index == k.
    Hier.add_coord("num_cols", num_cols)
    Hier.add_coord("obs", example_data.index)
    Hier.add_coord("clusters", [f"cluster_{c}" for c in cluster_name])

    # Data containers (one row of X per observation).
    X = pm.MutableData("X", example_data[num_cols].values, dims=("obs", "num_cols"))
    y = pm.MutableData("y", example_data["DBWT"].values, dims="obs")

    # Hyperpriors shared by all clusters.
    # NOTE(review): these scales (sigma=3) are tiny next to DBWT values of
    # roughly 2400-3500; consider standardizing X/y or widening the priors.
    αμ = pm.Normal("αμ", mu=0.0, sigma=3.0)
    ασ = pm.HalfNormal("ασ", sigma=3.0)

    βμ = pm.Normal("βμ", mu=0.0, sigma=3.0)
    βσ = pm.HalfNormal("βσ", sigma=3.0)

    # Observation noise. It is used as the likelihood's sigma, so it must be
    # positive; a Normal prior would allow negative (invalid) values.
    ϵ = pm.HalfNormal("ϵ", sigma=3.0)

    # Cluster-level intercepts and slopes drawn from the hyperpriors above.
    α = pm.Normal("α", mu=αμ, sigma=ασ, dims="clusters")
    β = pm.Normal("β", mu=βμ, sigma=βσ, dims=("clusters", "num_cols"))

    # Linear predictor. Each observation i combines its own predictor row X[i]
    # with the slopes and intercept of its cluster:
    #   β[cluster_index] has shape (obs, num_cols) -> row-wise dot product gives (obs,)
    #   α[cluster_index] has shape (obs,)          -> adds element-wise, no mean needed
    μ = (X * β[cluster_index]).sum(axis=1) + α[cluster_index]

    # Likelihood.
    yhat = pm.Normal("yhat", mu=μ, sigma=ϵ, observed=y, dims="obs")

    # Sample the posterior; log-likelihood is stored for LOO/WAIC comparison.
    # NOTE(review): tune=50 is a very short NUTS warm-up; raise it for real runs.
    trace = pm.sample(
        250,
        tune=50,
        chains=4,
        return_inferencedata=True,
        idata_kwargs={"log_likelihood": True},
    )
2 Likes