Recent error from JAX's (numpy) clip method

Hi folks,

Recently I began getting an error when using the blackjax NUTS sampler. For example, when executing pm.sample as follows:

`idata = pm.sample(draws=2000, nuts_sampler='blackjax', step=pm.NUTS(target_accept=0.95), progressbar=False)`

I get the following error:

TypeError: clip() got an unexpected keyword argument 'max'

As far as I can tell, the error originates in the jax.numpy library.

Here are my dependencies in case it helps:

`Python 3.10.7`

['blackjax==1.2.1',
 'numpy==1.23.4',
 'scipy==1.9.1',
 'pandas==1.5.3',
 'jax==0.4.23',
 'matplotlib==3.8.3',
 'scikit-learn==1.2.2',
 'pymc==5.10.2',
 'arviz==0.16.1',
 'altair==5.3.0',
 'tensorflow==2.12.0']

I was not able to find any other information on this error in PyMC-related posts, and it only started happening a few weeks ago.

Any pointers in the right direction would be much appreciated. Thanks very much for all the great work on PyMC.

You need to update blackjax and jax to their most recent versions. A breaking change in JAX caused blackjax to start failing.
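
If it helps to confirm the mismatch before and after upgrading, here is a minimal sketch that prints the installed versions and checks whether the installed `jax.numpy.clip` accepts the `max` keyword named in the TypeError. The helper names `installed_version` and `accepts_keyword` are made up for this example and are not part of jax, blackjax or PyMC. I believe the `a_min`/`a_max` to `min`/`max` rename landed in JAX 0.4.27, but treat that version number as an assumption and check the JAX changelog.

```python
import inspect
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def accepts_keyword(func, name):
    """Return True if func can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY:
        return True
    # A **kwargs catch-all accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    # Versions of the packages involved in the traceback.
    for pkg in ("jax", "jaxlib", "blackjax", "pymc"):
        print(pkg, installed_version(pkg))

    try:
        import jax.numpy as jnp
    except ImportError:
        print("jax is not installed")
    else:
        try:
            # False here would match the "unexpected keyword argument 'max'" error.
            print("jnp.clip accepts max=:", accepts_keyword(jnp.clip, "max"))
        except (TypeError, ValueError):
            print("could not inspect the signature of jnp.clip")
```

If the check prints False, upgrading with something like `pip install --upgrade jax jaxlib blackjax` and running it again should show whether the environment now matches what blackjax expects.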