Is there a way to pass a function to the logp function in CustomDist?

I have a custom distribution that has a neural network inside. To the logp function that I pass to CustomDist, I would like to pass, as a parameter, a function to be used as the activation function inside logp. I managed to do this by passing a number and mapping that number to the desired function inside logp (for instance, 0 for linear, 1 for ReLU, and so on), but maybe there's a better way that lets me pass the function directly.
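A minimal sketch of the integer-code workaround described above. It uses NumPy as a stand-in for PyTensor so it runs on its own. The names `apply_activation`, `neural_net_logp`, `weights` and `activation_code`, and the Gaussian likelihood, are hypothetical and only for illustration. A code passed through CustomDist arrives inside logp as a tensor rather than a Python int, so the dispatch below uses element-wise selection (`np.where`, whose PyTensor counterpart would be `pt.switch`) instead of a dictionary lookup.

```python
import numpy as np

# Hypothetical activation codes: 0 = linear, 1 = ReLU, anything else = tanh.
def apply_activation(x, activation_code):
    # Branch-free dispatch: all candidates are computed and the code selects one.
    # In PyTensor the same pattern would be written with pt.switch / pt.maximum.
    return np.where(
        activation_code == 0,
        x,
        np.where(activation_code == 1, np.maximum(x, 0.0), np.tanh(x)),
    )

def neural_net_logp(value, inputs, weights, activation_code, sigma=1.0):
    # A single hypothetical layer: linear map followed by the selected activation.
    mu = apply_activation(inputs @ weights, activation_code)
    # Gaussian log-density of the observed values around the network output.
    resid = (value - mu) / sigma
    return np.sum(-0.5 * resid**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi))
```

This shows the readability problem from the post: the meaning of `0`, `1` and `2` lives only in a comment.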

Passing a function does not work because logp expects a tensor. I tried to wrap the function in a PyTensor Op, but that does not work either. I get: Cannot convert <coordination.common.activation_function.ActivationFunction object at 0x7fc5094fe9a0> to a tensor variable.
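One possible alternative, offered as an untested sketch rather than a confirmed answer, is to keep the activation out of the distribution parameters and bind it to logp with `functools.partial` (a closure works the same way) before handing logp to CustomDist. This assumes that CustomDist converts only the positional distribution parameters to tensors and calls `logp(value, *dist_params)`, so a Python function captured by the partial never has to become a tensor. NumPy again stands in for PyTensor. The names `relu`, `logp_with_activation`, `logp_relu`, `inputs`, `weights` and `y_obs` are hypothetical.

```python
from functools import partial

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def logp_with_activation(value, inputs, weights, *, activation):
    # `activation` is an ordinary Python callable, not a distribution parameter.
    mu = activation(inputs @ weights)
    resid = value - mu
    # Unit-variance Gaussian log-density, summed over observations.
    return np.sum(-0.5 * resid**2 - 0.5 * np.log(2 * np.pi))

# Bind the activation up front. The result only takes the value and tensor-like params.
logp_relu = partial(logp_with_activation, activation=relu)

# In a PyMC model this would be passed roughly as (not executed here):
# pm.CustomDist("y", inputs, weights, logp=logp_relu, observed=y_obs)
```

Swapping activations then means binding a different callable, e.g. `partial(logp_with_activation, activation=np.tanh)`, with no numeric codes involved.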

Is there a way to accomplish that or should I stick to my original, less readable, solution?