Any way to leverage multiple GPUs while compiling?

Hello,

Nutpie is excellent work and significantly accelerates our code on GPU. However, it currently uses only one GPU on Azure while compiling. Is there any way for the code to use all the detected GPUs?

Here are the logs:

2025-03-27 23:37:02|INFO|hbmmm_model:1355|sample_model(): Default backend: gpu
2025-03-27 23:37:02|INFO|hbmmm_model:1357|sample_model(): Available devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
DEBUG:2025-03-27 23:38:00,167:jax._src.dispatch:184: Finished tracing + transforming convert_element_type for pjit in 0.000442028 sec
DEBUG:2025-03-27 23:38:00,192:jax._src.interpreters.pxla:1911: Compiling convert_element_type with global shapes and types [ShapedArray(uint8[25254])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-03-27 23:38:00,254:jax._src.dispatch:184: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.061057568 sec
DEBUG:2025-03-27 23:38:00,254:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:2025-03-27 23:38:00,255:jax._src.compiler:260: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
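For reference, here is a minimal JAX sketch, independent of nutpie, of how a single process can keep every listed device busy. `jax.jit` runs a call on whichever device its committed input lives on, so placing one input per device with `jax.device_put` gives one computation per GPU. The `step` function is a hypothetical stand-in, not nutpie's sampler. Whether nutpie can be made to place its chains this way is an open assumption.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for one unit of work (not nutpie's sampler).
    return x * 2.0 + 1.0


results = []
for i, dev in enumerate(jax.devices()):
    # device_put commits the input to `dev`; the jitted call then runs on `dev`.
    x = jax.device_put(jnp.full((3,), float(i)), dev)
    # JAX dispatches asynchronously, so calls on different devices can overlap.
    results.append(step(x))

for dev, r in zip(jax.devices(), results):
    # Each result reports the single device it was computed on.
    print(dev, r.devices(), r)
```

On the machine from the logs, this would print one line each for `CudaDevice(id=0)` through `CudaDevice(id=3)`.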

I was wondering whether passing some parameters could help here. Specifically, where can I change these values?

get_compile_options: num_replicas=1 num_partitions=1

Can they be set in these calls?

compiled_model = nutpie.compile_pymc_model(self.model, **nutpie_args)
idata = nutpie.sample(compiled_model, **kwargs)
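As far as I understand, JAX usually does not take `num_replicas` and `num_partitions` from a keyword argument. It derives them from how a computation is distributed. Sharding the inputs of a `jax.jit` function across devices yields a partitioned program, and `jax.pmap` yields replicas. The sketch below assumes a plain JAX function rather than nutpie's compiled model. It shards an array's leading axis over a `Mesh` of all visible devices. The axis name `"chains"` and the function `f` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# One-dimensional mesh over every visible device; "chains" is a hypothetical axis name.
mesh = Mesh(np.array(devices), axis_names=("chains",))

# Split the leading axis (one row per device) across the mesh.
x = jnp.arange(4 * n, dtype=jnp.float32).reshape(n, 4)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("chains")))


@jax.jit
def f(x):
    # Row-wise reduction; each device can reduce its own row locally.
    return jnp.sum(x ** 2, axis=1)


out = f(x_sharded)
print(x_sharded.sharding, out)
```

With JAX debug logging enabled as in the logs above, compiling `f` on a four-GPU machine should report `num_partitions` equal to the number of devices rather than 1.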

Thanks for your help! I have also posted this on the nutpie GitHub.