Is there a C implementation of `erfcinv` (or `erfinv`) that we could add to Pytensor to allow loop fusion? Can we write one?

I’m using an erfcinv in a model and get a warning from initialize_fuseable_mappings saying that loop fusion isn’t possible because there’s no C implementation available.

I’m out of my depth here, but a quick search suggests this function simply doesn’t exist in the standard C library and developers have to create their own — see the Stack Overflow question “Computing the inverse of the complementary error function, erfcinv(), in C”. Relatedly, I see there’s a TODO in the docstring of the related Erfinv to go find such a C implementation.

My question: Is it worth us (me?) trying to implement such a function in Pytensor directly? Or is this liable to just lead to implementation pain?

I think I’ve found a C implementation of erfinv here, based on the implementation in Go, which could get us halfway there…
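For what it’s worth, the usual recipe (and roughly the structure a C port would take) is a good initial approximation followed by Newton polishing against erfc — and erfc itself *is* in C99’s math.h. A minimal Python sketch of that structure, using nothing beyond the standard library (the function name and seed choice are mine, not anything in PyTensor):

```python
import math
from statistics import NormalDist

_SQRT_PI_OVER_2 = math.sqrt(math.pi) / 2.0

def erfcinv(x):
    """Inverse of erfc: returns y such that erfc(y) == x, for x in (0, 2)."""
    if not 0.0 < x < 2.0:
        raise ValueError("erfcinv is only defined on (0, 2)")
    # Seed with the standard normal inverse CDF (a rational approximation in
    # the stdlib), using the identity erfc(y) = 2 * Phi(-y * sqrt(2)):
    y = -NormalDist().inv_cdf(x / 2.0) / math.sqrt(2.0)
    # Polish with Newton steps on f(y) = erfc(y) - x, where
    # f'(y) = -2/sqrt(pi) * exp(-y**2).  A real C port would need to guard
    # exp(y * y) against overflow in the extreme tails (x very close to 0).
    for _ in range(2):
        y += (math.erfc(y) - x) * _SQRT_PI_OVER_2 * math.exp(y * y)
    return y
```

The same seed-then-polish shape is what most C implementations in the wild use, just with a hand-rolled rational approximation in place of the stdlib seed.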

Additional reading:

It’s a bit rough but would be a great addition. Perhaps the C code for the gamma functions is a good template (along with how it’s called from the related gamma Ops)? https://github.com/pymc-devs/pytensor/blob/main/pytensor/scalar/c_code/gamma.c

Okay cool - I’ll look into it a little more then :slight_smile:

Do you have any feel for the kind of speedup one might get from this? I know that’s fairly impossible to guess, but are we talking fractional percentage gains in speed, or multiples?

It can provide nice speedups but it’s very context dependent.

You can benchmark your d+logp function in JAX to get an idea of a lower bound (they also do loop fusion and are generally faster than PyTensor’s C backend).

You could introduce a rewrite that replaces erfcinv(x) with the less numerically stable erfinv(1 - x), if erfinv has C code, to get a more direct answer.
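To illustrate the stability trade-off in that rewrite (a sketch assuming SciPy is available, which is what the Python-mode Op calls anyway): the identity erfcinv(x) = erfinv(1 - x) holds exactly in real arithmetic, but computing 1 - x in floating point throws away precision when x is tiny.

```python
from scipy.special import erfinv, erfcinv

# For moderate arguments the two forms agree to machine precision:
x = 0.5
assert abs(erfcinv(x) - erfinv(1 - x)) < 1e-14

# In the tail, 1 - x rounds before erfinv ever sees it, so the rewritten
# form is evaluated at a perturbed argument and loses digits:
tiny = 1e-14
print(erfcinv(tiny))     # direct evaluation
print(erfinv(1 - tiny))  # same value in exact arithmetic, less accurate here
```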

Having said that, I wouldn’t suggest doing a C implementation as an attempt to speed up your model as a user. Too much work for that.

Hmmm… food for thought, thanks!

I’ve not used JAX before and TBH am a little apprehensive about what could be days of work finding and fixing numerical issues, etc. (unfounded?)

But I do see JAX has an erfinv — and (without knowing anything about how it works), might it be possible in theory to JAX-ify scipy.special.erfcinv too?

Those functions are automatically used by PyTensor’s jaxify machinery based on their names; there should be no need for you to do anything.

I mentioned it more as a benchmark to assess the max speedup you could hope for by doing a C impl.

Using numpyro/nutpie for sampling is likely to shift sampling speed much more than adding C code for this one Op, but that requires more care so I didn’t suggest that directly.

Scipy already has a C implementation of erfcinv here I think? I’m not sure what would be required to bring it into pytensor.

It seems like it should be possible to write a numba overload for it though, and use nutpie for a speedup. I’m surprised it’s not already overloaded in numba-scipy like most of the scipy.special module.

Indeed it looks like pytensor’s erfcinv already falls back onto the scipy implementation…

I’m all ears about a nutpie / JAX implementation that overloads the scipy one… but that’s a little above my pay grade!

Indeed it looks like pytensor’s erfcinv already falls back onto the scipy implementation…

Yes but that will run in Python mode and indeed break the loop fusion as the warning tells you.


Scipy already has a C implementation of erfcinv here I think? I’m not sure what would be required to bring it into pytensor.

Just copy the files and make sure the Op loads the headers / references the function in the c_code method. It’s not too bad, but requires some trial and error.

Importantly, that would be good for users in general, but I doubt it will fix the bottleneck of any specific model (probabilistically speaking).

Okay, that sounds like something a rank C/C++ amateur (like yours truly) could at least attempt :wink:

I’ll look into it in a week or so


In terms of standard C libraries, “erfcinv” is actually in the GSL; it’s just slightly hidden: one has to use the inverse CDF, gsl_cdf_gaussian_Pinv.
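For the record, the relation being used there is erfcinv(x) = -Φ⁻¹(x/2) / √2, where Φ⁻¹ is the standard normal inverse CDF (what gsl_cdf_gaussian_Pinv computes with sigma = 1). A quick stdlib-Python check of that identity, with NormalDist().inv_cdf standing in for the GSL call:

```python
import math
from statistics import NormalDist

def erfcinv_via_inv_cdf(x):
    # erfc(y) = 2 * Phi(-y * sqrt(2)), hence erfcinv(x) = -Phi^{-1}(x/2) / sqrt(2).
    # NormalDist().inv_cdf plays the role of gsl_cdf_gaussian_Pinv(x / 2, 1.0).
    return -NormalDist().inv_cdf(x / 2.0) / math.sqrt(2.0)

# Round-trip through math.erfc to confirm the identity:
for x in (0.01, 0.5, 1.0, 1.5):
    assert abs(math.erfc(erfcinv_via_inv_cdf(x)) - x) < 1e-9
```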

Otherwise, less standard but still common, you can find it directly in Boost for C++ code:
https://live.boost.org/doc/libs/1_67_0/libs/math/doc/html/math_toolkit/sf_erf/error_inv.html

Thanks for sharing. I don’t know how easy it is to include GSL stuff for our Windows users, so my inclination would be to go with the scipy/cephes implementation.

Moreover, GSL is GPL-licensed, which is incompatible with BSD/Apache 2.0.