Dynamic shaping, "round" function, JAX, and a "few" more questions

@ricardoV94, @jessegrabowski, thank you very much for all the information provided in your replies above. It is extremely useful to get so many insights into how PyTensor and PyMC work, and I will need some time to digest everything. I am still trying to get my model working and I think I am very close. I have only been using PyTensor/PyMC/JAX for a few weeks, so it is a lot to assimilate in a short time! I am still facing a few challenges, and any comments or help would be greatly appreciated:

1- I ran into exactly the same issue as the one reported in this thread (How can I output a gradient in vector format in Op.grad instance?) with my custom PyTensor Op (which also wraps a JAX function). Going back to the “simple” example given by @HajimeKawahara, the issue comes from the fact that the parameter phase is a single scalar shared by all the elements of the vector x. In other words, PyTensor doesn’t know that phase needs to be broadcast, so it expects a zero-dimensional gradient for the gradient of the Op with respect to that variable. I solved this in my case by adding inputs = list(pt.broadcast_arrays(*inputs)) at the beginning of make_node (see the sketch below). I don’t know if there is a recommended strategy for this case. I suppose it is only an implementation detail in this specific example, since the problem doesn’t occur elsewhere in PyTensor/PyMC, but it might be useful to cover this edge case in the PyMC examples, particularly in (How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs). It is not relevant for (https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html), since there the cost function is already reduced to a scalar via a sum before being passed to pm.Potential.
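
For reference, here is a minimal sketch of what I did (the Op below is a made-up stand-in for my real JAX-wrapping Op, with a NumPy perform for simplicity):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class SinePhaseOp(Op):
    # Hypothetical stand-in for the JAX-wrapping Op: y = sin(x + phase).
    def make_node(self, x, phase):
        x = pt.as_tensor_variable(x)
        phase = pt.as_tensor_variable(phase)
        # Broadcast the scalar phase against the vector x so that PyTensor
        # expects gradients with the same (vector) shape for both inputs.
        inputs = list(pt.broadcast_arrays(x, phase))
        return Apply(self, inputs, [inputs[0].type()])

    def perform(self, node, inputs, output_storage):
        x, phase = inputs
        output_storage[0][0] = np.sin(x + phase)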

2- Thanks @jessegrabowski. By taking the sum of the function before differentiating, as indicated on this page (Derivatives in PyTensor — PyTensor dev documentation), I get the expected result, i.e. the “vectorized” gradient of the function with respect to its input parameters. I don’t find it very intuitive to have to take the sum and I still don’t really understand the logic, but now it works!
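
For reference, the pattern that works for me looks roughly like this (a toy example using the same sin(x + phase) function):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
phase = pt.scalar("phase")
y = pt.sin(x + phase)              # vector-valued output
# pt.grad needs a scalar cost, so the outputs are summed first; since each
# y[i] depends only on x[i], the gradient of the sum recovers the
# element-wise derivative dy[i]/dx[i].
g = pt.grad(y.sum(), wrt=x)
grad_fn = pytensor.function([x, phase], g)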

3- Is it possible to take a CustomDist of a CustomDist? Let’s take the simplified example below:

import pymc as pm
import pytensor.tensor as pt

# _myrndfunc, _mylogp and _mylogcdf are defined elsewhere in my code.

def _custom_dist1(mu, sigma, size):
    # Base distribution built from custom random/logp/logcdf implementations.
    return pm.CustomDist.dist(
        mu, sigma,
        random=_myrndfunc, logp=_mylogp, logcdf=_mylogcdf,
        size=size,
    )

def _round_dist(mu, sigma, size):
    # Round the draws of the base distribution.
    return pt.round(_custom_dist1(mu, sigma, size=size))

def _custom_dist2(mu, sigma, size):
    # Second CustomDist defined from the graph returned by _round_dist.
    return pm.CustomDist.dist(
        mu, sigma,
        dist=_round_dist,
        size=size,
    )

I get the following error when I try to “instantiate” _custom_dist2: “Model variables cannot be created in the dist function. Use the .dist API”.

4- This may be a recurring question, sorry about that: is it possible to get a logp function from the PyMC model (via model.logp or model.compile_logp) that accepts the un-transformed parameters of the model as inputs? Or, how can we recover from PyMC the functions associated with these transformations? (It would be great to be able to overlay the model logp curve on the histogram of the observed data; again, I suppose that is a fairly classic question!)
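
For the overlay part, the kind of thing I would like to do looks like this (a minimal sketch; the Normal and the point estimates below are just placeholders for my actual model and posterior means):

import numpy as np
import pymc as pm

mu_hat, sigma_hat = 1.2, 0.5               # made-up posterior point estimates
grid = np.linspace(-2.0, 4.0, 500)         # grid covering the observed data range

dist = pm.Normal.dist(mu=mu_hat, sigma=sigma_hat)
pdf_grid = np.exp(pm.logp(dist, grid).eval())   # density evaluated on the grid
# plt.hist(data, density=True) followed by plt.plot(grid, pdf_grid)
# would overlay the curve on the histogram of the observed data.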

5- My Mixture model is made up of a finite number of components weighted by a truncated Poisson distribution (roughly a convolution of a Normal distribution with a Poisson). However, the random process generating the observed data is not “truncated”: there may be data points that are not “covered” by any component of the Mixture and are therefore seen as outliers by the model. Even a few such points can have a dramatic effect, “pulling the whole distribution to the right” and inducing a significant bias in the inferred parameters. I was wondering if there is a common strategy for dealing with this scenario. We can of course increase the number of components, but at the expense of a heavier/slower model, and not necessarily a more accurate one. Another option, provided PyMC allows it, would be to trim “in real time”, at each iteration of the MCMC procedure, the data located too “far” from the last component of the Mixture.
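
For clarity, here is a simplified sketch of the kind of mixture I mean (all names and numbers are made up, and the components are plain Normals rather than my actual custom distribution):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

data = np.random.normal(3.0, 0.7, size=200)   # made-up observed data
K = 10                                        # fixed, finite number of components

with pm.Model() as model:
    lam = pm.Gamma("lam", 2.0, 1.0)
    mu0 = pm.Normal("mu0", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)

    k = pt.arange(K)
    # Weights proportional to a Poisson pmf truncated at K components.
    logw = pm.logp(pm.Poisson.dist(mu=lam), k)
    w = pt.exp(logw)
    w = w / w.sum()

    # Components centered at mu0 + k: roughly a Normal convolved with a
    # (truncated) Poisson.
    components = pm.Normal.dist(mu=mu0 + k, sigma=sigma, shape=K)
    y = pm.Mixture("y", w=w, comp_dists=components, observed=data)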

Thank you again for your time and help.