Potential with array of logps

Thanks for taking a look. Here is my use case:
I have a black-box model (in PyTorch) that accepts a D-dimensional parameter vector and a D-dimensional observation vector as input and returns a scalar logp value. I can also get the gradients with respect to the parameters.

I use PyMC to set the priors for the D-dimensional parameter variables and want to sample posteriors based on the logp returned by the black box. I have an Aesara Op that wraps this torch code and implements perform/grad as usual (i.e. it sets the scalar logp, grad, etc. in outputs). This Op is called from pm.Potential, followed by pm.sample(). All works as expected, similar to the example here.
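Roughly, my current scalar Op looks like the sketch below (simplified; `torch_logp` is a stand-in for the real black-box model, and all names are placeholders):

```python
import aesara.tensor as at
import numpy as np
import pymc as pm
import torch
from aesara.graph.op import Op


def torch_logp(params, obs):
    # Stand-in for the black-box model: (D,) params, (D,) obs -> scalar logp.
    return -0.5 * torch.sum((params - obs) ** 2)


class LogpGrad(Op):
    itypes = [at.dvector, at.dvector]  # params (D,), obs (D,)
    otypes = [at.dvector]              # d logp / d params, shape (D,)

    def perform(self, node, inputs, outputs):
        params, obs = inputs
        params_t = torch.tensor(params, requires_grad=True)
        torch_logp(params_t, torch.tensor(obs)).backward()
        outputs[0][0] = params_t.grad.numpy()


class Logp(Op):
    itypes = [at.dvector, at.dvector]  # params (D,), obs (D,)
    otypes = [at.dscalar]              # scalar logp

    def perform(self, node, inputs, outputs):
        params, obs = inputs
        logp = torch_logp(torch.tensor(params), torch.tensor(obs))
        outputs[0][0] = np.asarray(logp.item())

    def grad(self, inputs, output_grads):
        params, obs = inputs
        (g,) = output_grads
        # Observations are constants, so their gradient is zero.
        return [g * LogpGrad()(params, obs), at.zeros_like(obs)]


obs_data = np.random.randn(5)  # dummy D = 5 observations
with pm.Model():
    theta = pm.Normal("theta", 0.0, 1.0, shape=5)
    pm.Potential("llh", Logp()(theta, at.as_tensor_variable(obs_data)))
    idata = pm.sample()
```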

The black-box model is vectorizable, i.e. it can also accept NxD parameters and NxD observations and return N logp values. Each logp value is calculated independently from its own observation/parameter row (let's just say that is the intended model design).

In PyMC, I now have priors for the NxD parameter variables and have set the Aesara inputs accordingly. If I sum the logps and return a scalar logp for pm.Potential, everything works correctly (so I know the inputs are not the issue). But the summed logp is NOT what I want. I want N x D x num_samples posteriors generated independently for each likelihood, i.e. N different logps with their corresponding gradients.
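For reference, the working-but-unwanted summed variant looks roughly like this, continuing the sketch above (`torch_logp_vec` is again a stand-in, mapping (N, D) params and (N, D) obs to (N,) logps; its grad Op mirrors the scalar case and is omitted):

```python
def torch_logp_vec(params, obs):
    # Stand-in vectorized black box: row-wise independent logps, shape (N,).
    return -0.5 * torch.sum((params - obs) ** 2, dim=1)


class LogpSum(Op):
    itypes = [at.dmatrix, at.dmatrix]  # params (N, D), obs (N, D)
    otypes = [at.dscalar]              # one summed logp

    def perform(self, node, inputs, outputs):
        params, obs = inputs
        logps = torch_logp_vec(torch.tensor(params), torch.tensor(obs))
        outputs[0][0] = np.asarray(logps.sum().item())
```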

Of course, I could instead call the scalar Op in a loop N times, but that is pretty slow.

I see now that wrapping the Aesara Op's vector output in pm.Deterministic is the correct way (since pm.Potential implies a sum).

But I am stuck on making the Aesara Op's output an N-dimensional vector (instead of a scalar) and passing it to pm.Deterministic. In the JAX example you provided, outputs is also [at.dscalar()]; an example with [at.dvector()] that is then used by pm.Deterministic would be useful. My current attempt is sketched below.
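What I imagine (but have not gotten to work) is something like the following, continuing from the code above: `otypes = [at.dvector]`, with the gradient computed as a vector-Jacobian product through torch in a single backward pass:

```python
class LogpVecGrad(Op):
    itypes = [at.dmatrix, at.dmatrix, at.dvector]  # params (N, D), obs (N, D), g (N,)
    otypes = [at.dmatrix]                          # VJP w.r.t. params, shape (N, D)

    def perform(self, node, inputs, outputs):
        params, obs, g = inputs
        params_t = torch.tensor(params, requires_grad=True)
        logps = torch_logp_vec(params_t, torch.tensor(obs))
        # backward(g) backpropagates sum_i g[i] * logps[i] in one pass.
        logps.backward(torch.tensor(g))
        outputs[0][0] = params_t.grad.numpy()


class LogpVec(Op):
    itypes = [at.dmatrix, at.dmatrix]  # params (N, D), obs (N, D)
    otypes = [at.dvector]              # N logp values

    def perform(self, node, inputs, outputs):
        params, obs = inputs
        logps = torch_logp_vec(torch.tensor(params), torch.tensor(obs))
        outputs[0][0] = logps.numpy()

    def grad(self, inputs, output_grads):
        params, obs = inputs
        (g,) = output_grads
        return [LogpVecGrad()(params, obs, g), at.zeros_like(obs)]


obs_data = np.random.randn(10, 5)  # dummy N = 10, D = 5 observations
with pm.Model():
    theta = pm.Normal("theta", 0.0, 1.0, shape=obs_data.shape)
    logps = LogpVec()(theta, at.as_tensor_variable(obs_data))
    pm.Deterministic("logps", logps)  # track the N logp values
    pm.Potential("llh", logps)        # Potential adds (sums) them into the joint logp
```

Is this roughly the right pattern for a vector-valued Op output, or am I missing something?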

I will wait for you to comment on the use case before posting the full code (it is difficult to extract a self-contained working snippet).