If you like your model but are concerned about it being slow, try the following:
trace = pm.sample(target_accept=0.9, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'})
You’ll need to install nutpie and jax first (e.g. pip install nutpie jax). This can give a substantial speedup over PyMC’s default NUTS sampler, though how much depends on the model.
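If the same script also has to run on machines where nutpie or jax might not be installed, one option is to build the sampler arguments conditionally. This is a minimal sketch that assumes you want to fall back to PyMC’s default NUTS sampler when either package is missing. `fast_sampler_kwargs` is a hypothetical helper, not part of PyMC or nutpie.

```python
import importlib.util


def fast_sampler_kwargs(prefer_backend: str = "jax") -> dict:
    """Hypothetical helper: build extra keyword arguments for pm.sample.

    Returns nutpie settings when both nutpie and the requested backend
    package are importable; otherwise returns an empty dict so that
    pm.sample uses PyMC's default NUTS sampler.
    """
    # find_spec returns None when a top-level package is not installed
    has_nutpie = importlib.util.find_spec("nutpie") is not None
    has_backend = importlib.util.find_spec(prefer_backend) is not None
    if has_nutpie and has_backend:
        return {
            "nuts_sampler": "nutpie",
            # PyMC forwards 'backend' to nutpie when compiling the model
            "nuts_sampler_kwargs": {"backend": prefer_backend},
        }
    # Missing dependency: no extra kwargs, PyMC's default sampler is used
    return {}
```

Inside your model context you would then call trace = pm.sample(target_accept=0.9, **fast_sampler_kwargs()). Checking for the packages instead of wrapping pm.sample in try/except keeps real sampling errors from being silently swallowed.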