JAX Sampling Error with TruncatedNormal Distribution | 6 | 458 | July 27, 2023
Dynamic shaping, "round" function, JAX, and a "few" more questions | 22 | 552 | July 15, 2023
Numpyro JAX sampling very slow | 1 | 245 | July 9, 2023
Jax on Apple Silicon GPU | 2 | 460 | June 13, 2023
NameError: unbound axis name raised during transformation of variables after sample_numpyro_nuts | 6 | 476 | June 13, 2023
Wild results for masked multinomial with jax sampler(s) | 2 | 149 | May 7, 2023
Pymc3 on GPU using jax | 2 | 501 | April 20, 2023
Out of Memory when using pm.sampling.jax.sample_blackjax_nuts | 2 | 305 | March 23, 2023
How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters | 2 | 244 | March 16, 2023
How can I output a gradient in vector format in Op.grad instance? | 7 | 446 | January 14, 2023
How to use JAX ODEs and Neural Networks in PyMC | 0 | 183 | January 4, 2023
Numpyro Convergence Diagnostics | 1 | 238 | December 30, 2022
Out of memory when "transforming variables" in Numpyro & JAX | 10 | 719 | December 12, 2022
Sample_numpyro_nuts hangs when parallelizing over datasets with multiprocessing | 0 | 301 | November 9, 2022
JAX Sampling Error when same model had previously worked without it | 1 | 183 | October 25, 2022
Slow inference for numpyro sampling on Colab GPU | 3 | 458 | October 5, 2022
Var_names not working with sample_numpyro_nuts | 5 | 285 | September 12, 2022
Pm.sampling_jax to sample a MvNormal() | 4 | 438 | August 3, 2022
Open all NUTS kwargs for sampling with Numpyro | 5 | 368 | August 1, 2022
Use of numpyro/Jax with pymc-dev | 5 | 1193 | July 28, 2022
Implementation of generalized linear mixed in the form X\beta + Zb | 3 | 289 | July 27, 2022
Gradients from external model in likelihood | 3 | 300 | July 1, 2022
How to debug a model that is not sampling? | 0 | 324 | June 30, 2022
JAX Numpyro backend "IndentationError: unexpected indent" | 16 | 806 | June 29, 2022