Hi folks,
Recently I started getting an error when using the blackjax NUTS sampler. For example, I call `pm.sample` like this:
```python
idata = pm.sample(draws=2000, nuts_sampler='blackjax', step=pm.NUTS(target_accept=0.95), progressbar=False)
```
I get the following error:
```
TypeError: clip() got an unexpected keyword argument 'max'
```
From what I can tell, the error originates in the `jax.numpy` library.
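To double-check that the failing call really is `jax.numpy.clip`, something like the sketch below prints the relevant package versions and tests whether the installed `jnp.clip` accepts a `max` keyword at all. It is only a minimal sketch: `accepts_keyword` and `report` are hypothetical helper names of my own, and `inspect.signature` may not see through every wrapper jax puts around its functions, so treat the result as a hint rather than proof. My unverified assumption is that a `False` result means the installed jax predates the `min`/`max` keyword names for `clip`.

```python
# Hypothetical diagnostic helpers (names are mine, not part of jax/blackjax/PyMC).
import importlib.metadata
import inspect
from typing import Callable, Optional


def accepts_keyword(func: Callable, name: str) -> Optional[bool]:
    """Return True if `func` can be called with keyword argument `name`,
    False if it cannot, or None if no signature is available."""
    try:
        # inspect.signature follows __wrapped__ set by functools.wraps.
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        return None
    for param in sig.parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs parameter accepts any keyword
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


def report() -> None:
    """Print installed versions and whether jnp.clip accepts `max=`."""
    for pkg in ("jax", "jaxlib", "blackjax", "pymc"):
        try:
            print(f"{pkg}=={importlib.metadata.version(pkg)}")
        except importlib.metadata.PackageNotFoundError:
            print(f"{pkg}: not installed")
    try:
        import jax.numpy as jnp
    except ImportError:
        print("jax.numpy could not be imported")
        return
    print("jnp.clip accepts max=:", accepts_keyword(jnp.clip, "max"))


if __name__ == "__main__":
    report()
```

Running it as a script prints one line per package, followed by a `True`/`False`/`None` line for `jnp.clip`.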
Here are my dependencies in case it helps:
`Python 3.10.7`
```python
['blackjax==1.2.1',
 'numpy==1.23.4',
 'scipy==1.9.1',
 'pandas==1.5.3',
 'jax==0.4.23',
 'matplotlib==3.8.3',
 'scikit-learn==1.2.2',
 'pymc==5.10.2',
 'arviz==0.16.1',
 'altair==5.3.0',
 'tensorflow==2.12.0']
```
I wasn't able to find any other information on this error in PyMC-related posts, and it only started happening a few weeks ago.
Any pointers in the right direction would be much appreciated. Thanks very much for all the great work on PyMC.