As the title says. Below is a minimal example. This code breaks with the error below even if dummy is never used in logp_fn. The code works if we remove dummy from the parameters and function signature, or if we remove the second dimension in the size of dummy, i.e., (5,) instead of (5,5).
Error:
ValueError: Size length is incompatible with batched dimensions of parameter 2 dummy:
len(size) = 1, len(batched dims dummy) = 2. Size length must be 0 or >= 2
Code:
from typing import Any
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor.tensor as ptt
def logp_fn(x: Any,
mu: Any,
sigma: Any,
dummy: Any):
return pm.logp(pm.Normal.dist(mu=mu, sigma=sigma), x).sum()
if __name__ == "__main__":
evidence = np.random.normal(loc=3, size=1000)
with pm.Model() as model:
logp_params = [
ptt.ones(1) * 3,
ptt.ones(1),
pm.Normal("dummy", mu=0, sigma=1, size=(5,5))
]
latent = pm.DensityDist("custom", *logp_params, logp=logp_fn, size=(1,))
obs = pm.Normal("observation", mu=latent, sigma=1, observed=evidence)
idata = pm.sample(1000, init="jitter+adapt_diag", tune=1000, chains=2, random_seed=0, cores=2)
az.plot_trace(idata)
plt.show()