How to set up a PyMC environment on Google Cloud Platform?

Hello,

I just took a new job that uses Google Cloud Platform for compute. Has anyone set up a PyMC3/4 environment on Google Cloud Platform before? I’m the only data scientist there interested in Bayesian processes, so I don’t have anyone at work who knows enough to help with this.

I have an extensive budget, so compute cores and GPUs are not a problem for now. I would appreciate the help.

It’s pretty easy to set up on GCE; I use it regularly in my work. Following the Linux setup guide should get you there. If it doesn’t, post the details here and I can try to help you out.

It took longer than I expected for my new employer to get me notebook access, but it worked great. Thank you.

Well, I thought it worked great. It seems I’m getting the following error:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
/tmp/ipykernel_2309/275736203.py in <module>
      1 import pymc3 as pm
----> 2 import pymc3.sampling_jax
      3 
      4 import aesara
      5 from aesara import tensor as at

/opt/conda/lib/python3.7/site-packages/pymc3/sampling_jax.py in <module>
     16 import theano.graph.fg
     17 
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
     19 
     20 import pymc3 as pm

/opt/conda/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in <module>
     84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
     85 try:
---> 86     jax.config.disable_omnistaging()
     87 except AttributeError:
     88     pass

/opt/conda/lib/python3.7/site-packages/jax/_src/config.py in disable_omnistaging(self)
    183   def disable_omnistaging(self):
    184     raise Exception(
--> 185       "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
    186       "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.")
    187 

Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.

That said, I’m getting confused about which installation instructions to use. The time series data I’ll be working with is pretty big, which is why I was interested in trying the JAX sampler. Could you point me to the right installation instructions?
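For anyone who hits the same traceback, here is a minimal diagnostic sketch. It only checks whether the installed JAX is at or above 0.2.12, the version the exception message names as no longer supporting `disable_omnistaging()`. The helper names (`version_tuple`, `rejects_disable_omnistaging`, `installed_jax_version`) are made up for illustration and are not part of PyMC, Theano, or JAX.

```python
import re


def version_tuple(version):
    """Return the leading numeric components of a version string.

    For example '0.2.12' becomes (0, 2, 12); parsing stops at the first
    component that does not start with a digit (such as 'dev0').
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def rejects_disable_omnistaging(jax_version):
    """True if this JAX version raises on jax.config.disable_omnistaging().

    The 0.2.12 threshold is taken from the exception message in the traceback.
    """
    return version_tuple(jax_version) >= (0, 2, 12)


def installed_jax_version():
    """Return the installed jax version string, or None if it cannot be determined."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:
        # importlib.metadata is missing on Python 3.7, as used in the traceback.
        return None
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_jax_version()
    if found is None:
        print("jax version could not be determined")
    elif rejects_disable_omnistaging(found):
        print(f"jax {found} will raise when disable_omnistaging() is called")
    else:
        print(f"jax {found} predates the 0.2.12 cutoff")
```

If the check reports a version at or above the cutoff while `theano/link/jax/jax_dispatch.py` still calls `disable_omnistaging()`, the installed JAX and Theano-PyMC versions do not match each other.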

Disregard. It is up and working.

Hello @fonnesbeck.

Regarding Google Cloud Platform: what do I have to do to use a GPU instance once it is spun up? Is there a special install of PyMC and JAX? Do I need to convert my data to Aesara tensors?

This is what I currently have in my instance.

Glad you seem to have gotten it working. To use the GPU, you just have to make sure that jax and numpyro are installed, and then call sample_numpyro_nuts instead of sample to run the model.
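A rough sketch of what that can look like in a notebook, under some assumptions: the import path `pymc.sampling_jax` is the PyMC v4 location (PyMC3 3.11 uses `pymc3.sampling_jax`, and other releases may differ), and the helpers `missing_packages`, `jax_backend`, and `sample_on_jax` are hypothetical names used only for this example.

```python
import importlib.util


def missing_packages(names=("jax", "numpyro")):
    """Return the names from `names` that cannot be imported in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def jax_backend():
    """Return the backend JAX will use by default, e.g. 'cpu' or 'gpu'."""
    import jax

    return jax.default_backend()


def sample_on_jax(model, **kwargs):
    """Run NUTS through NumPyro instead of the default pm.sample().

    Assumes PyMC v4, where the function lives in pymc.sampling_jax;
    adjust the import for the version you actually have installed.
    """
    import pymc.sampling_jax as sampling_jax

    with model:
        return sampling_jax.sample_numpyro_nuts(**kwargs)


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not installed:", ", ".join(missing))
    else:
        # If this prints 'cpu' on a GPU instance, the installed jaxlib
        # was probably built without GPU support.
        print("JAX default backend:", jax_backend())
```

With a model defined as usual, something like `idata = sample_on_jax(model, draws=1000, tune=1000, chains=4)` would replace the `pm.sample(...)` call. The data itself can stay as NumPy arrays, since the model graph is what gets compiled to JAX.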
