Hi Developers,
I have PyMC code that involves sorting model parameters. I am currently using the sort function from PyTensor, which works fine with the default PyMC sampler. Unfortunately, when I use numpyro as the nuts_sampler, I get a compilation error because JAX cannot compile sort. Is there any way to use sort with numpyro? Kindly give me some suggestions.
Thank you,
Soumya.
Are you running the latest version of PyMC? If not, try updating. Newer PyMC versions depend on a newer PyTensor, which should support Sort in the JAX backend.
Thank you so much @ricardoV94. I have updated to version 5.16.2. Unfortunately, this brought a new problem: DensityDist now gives me an error, although the same code compiled fine in my older version, 5.9.0. The error message is:
ValueError: Could not broadcast dimensions. Incompatible shapes were [(ScalarConstant(ScalarType(int64), data=499500),), (ScalarConstant(ScalarType(int64), data=1000),)]
Can you kindly help me with this? Below is the part of the code that creates the problem. Omega_lower_star is generated from a standard normal with shape=s (s = p(p-1)/2 = 499500), and Omega_diag is generated from a standard half-normal with shape=p (p = 1000). ll is a two-dimensional NumPy integer array of dimension (p, p), and y is a two-dimensional NumPy array of dimension (n, p). To narrow down the problem, I removed Omega_lower_star from the arguments and removed the term
- (0.5) * at.sum(at.tensordot(y, at.concatenate((Omega_lower_star, Omega_diag))[ll], axes = 1)**2)
and there was no error, but this does not solve my problem.
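The error message is a shape-broadcasting failure between the two parameter lengths, s = 499500 and p = 1000, which fits the experiment above: dropping Omega_lower_star (and the term that uses it) makes the error go away. One plausible reading, which is an assumption and not confirmed in this thread, is that the newer DensityDist treats its parameters as elementwise inputs and tries to broadcast them against each other. The same clash can be reproduced with NumPy's broadcasting rules (n = 5 below is an arbitrary illustrative value):

```python
import numpy as np

s, p, n = 499500, 1000, 5  # s and p from the error message; n is arbitrary

# Lengths (s,) and (p,) differ in their only axis and neither is 1, so they cannot broadcast
try:
    np.broadcast_shapes((s,), (p,))
    lower_and_diag_ok = True
except ValueError:
    lower_and_diag_ok = False

# With Omega_diag alone, (p,) broadcasts cleanly against an observed y of shape (n, p)
diag_and_y_shape = np.broadcast_shapes((p,), (n, p))

print(lower_and_diag_ok)   # False
print(diag_and_y_shape)    # (5, 1000)
```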
def my_logp(y, Omega_lower_star, Omega_diag):
    return (n * at.sum(at.log(Omega_diag))) - (0.5) * at.sum(
        at.tensordot(y, at.concatenate((Omega_lower_star, Omega_diag))[ll], axes=1) ** 2
    )

lik = pm.DensityDist("lik", Omega_lower_star, Omega_diag, logp=my_logp, observed=y)
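To make the shapes in my_logp concrete, here is a NumPy mirror of it. It returns one scalar for the whole (n, p) observation matrix, not one value per element, which may matter if the newer DensityDist infers shapes assuming elementwise support (an assumption). The post only says ll is a (p, p) integer array, so make_ll below is hypothetical: it builds a symmetric index map into concatenate((lower, diag)) purely for illustration, and the sizes and random inputs are made up.

```python
import numpy as np

def make_ll(p):
    """Hypothetical index map (not from the original post): entry (i, j) indexes
    theta = concatenate((lower, diag)). Off-diagonal pairs share one index into
    the lower block; diagonal entries index the diag block."""
    s = p * (p - 1) // 2
    ll = np.empty((p, p), dtype=int)
    k = 0
    for i in range(p):
        for j in range(i):
            ll[i, j] = ll[j, i] = k
            k += 1
        ll[i, i] = s + i
    return ll

def my_logp_np(y, lower, diag, ll):
    n = y.shape[0]
    theta = np.concatenate((lower, diag))  # length s + p
    M = theta[ll]                          # (p, p) matrix gathered through ll
    # tensordot with axes=1 contracts y's last axis with M's first, i.e. y @ M
    return n * np.sum(np.log(diag)) - 0.5 * np.sum(np.tensordot(y, M, axes=1) ** 2)

rng = np.random.default_rng(0)
p, n = 4, 3
y = rng.normal(size=(n, p))
lower = rng.normal(size=p * (p - 1) // 2)
diag = np.abs(rng.normal(size=p)) + 0.1  # keep diag positive for the log
ll = make_ll(p)
val = my_logp_np(y, lower, diag, ll)
print(np.ndim(val))  # 0: a single scalar for the whole data set
```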
Can you give me some guidance on how to update the code for the newer version of PyMC?
Thank you,
Soumya.
I will need to see more of your code to advise how to fix the problem, ideally a fully reproducible snippet.
Hi @ricardoV94,
Thank you so much for helping me out. Here is the reproducible code.
import numpy as np

p = 10
n = 10
y = np.random.multivariate_normal(mean=np.zeros(p), cov=np.eye(p), size=n)

# pymc code
import pymc as pm
import pytensor
from pytensor import tensor as at
import jax
import numpyro

n_param = int(p*(p-1)/2)
indc = at.as_tensor_variable(np.array(range(n_param)))

# NOTE: ll, a (p, p) integer index array into concatenate((Omega_lower_star, Omega_diag)),
# must be defined before this point; it is not defined in this snippet

with pm.Model() as l1_ball_model:
    # Diagonal elements of Omega
    Omega_diag_z = pm.HalfNormal("Omega_diag_z", sigma=1, shape=p)
    Omega_diag = Omega_diag_z * 2
    # Lower triangular elements of Omega (off-diagonal elements)
    Omega_lower = pm.Normal("Omega_lower", mu=0, sigma=1, shape=n_param)
    # Slab precision parameters
    #lambda_ = pm.Gamma("lambda_", alpha=1e-4, beta=1e-8, shape=K*(K-1)//2)
    #lambda_z = pm.HalfNormal("lambda_z", sigma = 1, shape=n_param)
    #lambda_ = lambda_z * 2
    r = pm.HalfCauchy("r", beta = 1)
    #Omega_lower = Omega_lower_z * 1
    abs_x = at.abs(Omega_lower)
    # |Omega_lower| sorted in descending order
    sorted_abs_x = (abs_x[at.argsort(abs_x)])[::-1]
    mu = at.cumsum(sorted_abs_x) - r
    # Last index k where sorted_abs_x[k] - mu[k]/(k+1) > 0 (boolean-mask indexing)
    KK = indc[(sorted_abs_x - (mu/(indc + 1))) > 0]
    KK = KK[-1]
    #KK = at.switch(at.gt(KK.shape[0], 0), KK[0], n_param)
    # Soft-threshold Omega_lower with threshold mu[KK]/(KK+1)
    Omega_lower_star = at.maximum(abs_x - (mu[KK]/(KK+1)), 0) * at.sign(Omega_lower)

    def my_logp(y, Omega_lower_star, Omega_diag):
        return (n * at.sum(at.log(Omega_diag))) - (0.5) * at.sum(
            at.tensordot(y, at.concatenate((Omega_lower_star, Omega_diag))[ll], axes=1) ** 2
        )

    lik = pm.DensityDist("lik", Omega_lower_star, Omega_diag, logp=my_logp, observed=y)

with l1_ball_model:
    l1_ball_sample2 = pm.sample(tune = 100, draws = 100, chains = 1, nuts_sampler="numpyro")
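The sort-and-threshold block in this model follows the usual sort-based projection onto an l1 ball of radius r: when sum |x| > r, the result has l1 norm exactly r. Below is a NumPy sketch of the same steps, with variable names mirroring the model and made-up input values. One detail that may matter for the JAX backend, offered as an assumption rather than the confirmed cause here: indc[mask] returns an array whose length depends on the data, and jax.jit generally cannot compile boolean-mask indexing with a traced mask. Because sorted_abs_x is descending and r > 0, the quantity sorted_abs_x[k] - mu[k]/(k+1) is non-increasing in the sense that (k+1) times it never increases, so the entries satisfying the condition form a prefix and the last such index also equals mask.sum() - 1, which has a fixed shape. The sketch checks that both give the same KK.

```python
import numpy as np

def l1_ball_threshold(x, r):
    """Mirror of the model's steps: sort |x| in descending order, find the last
    index KK with sorted_abs_x[k] - mu[k] / (k + 1) > 0, then soft-threshold x."""
    abs_x = np.abs(x)
    sorted_abs_x = np.sort(abs_x)[::-1]  # same values as abs_x[argsort(abs_x)][::-1]
    mu = np.cumsum(sorted_abs_x) - r
    indc = np.arange(x.size)
    mask = sorted_abs_x - mu / (indc + 1) > 0

    KK_mask = indc[mask][-1]             # data-dependent shape, as in the model
    KK_count = int(mask.sum()) - 1       # fixed-shape alternative (prefix property)
    assert KK_mask == KK_count

    threshold = mu[KK_count] / (KK_count + 1)
    return np.maximum(abs_x - threshold, 0) * np.sign(x), KK_count

x = np.array([0.8, -1.5, 0.3, 2.0, -0.1])
r = 1.0
x_star, KK = l1_ball_threshold(x, r)
print(KK, x_star, np.abs(x_star).sum())
```

Here sum |x| = 4.7 > r, so only the two largest entries survive the threshold and the l1 norm of x_star equals r.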
Thank you,
Soumya Sahu.
Hi @ricardoV94,
Can you kindly update me on this? Let me know if the code is not reproducible.
Thank you,
Soumya.