I restarted the Jupyter kernel and ran the notebook again. The process was not hanging; the sampling progress bar simply was not refreshing. The step completed in 29 minutes.
As you suggested, I also ran the notebook in Colab today. The step completed in about 29 minutes there as well, and the sampling progress bar refreshed to show progress.
You can try a different compute backend by passing compile_kwargs={'mode': 'JAX'} or compile_kwargs={'mode': 'NUMBA'} to the pm.sample_posterior_predictive call. Since the model is essentially one big matmul, JAX mode might give you some speedup if you have a GPU. Numba should perform about the same as the default, since the work is all BLAS calls anyway.
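To compare backends on the same call, a small timing harness like the one below could help. It is a sketch, not part of PyMC: `time_backends` and `BACKENDS` are hypothetical names, and the harness only assumes that you wrap your own pm.sample_posterior_predictive call in a function that accepts a compile_kwargs dict. An empty dict stands for "use PyTensor's default mode".

```python
import time
from typing import Callable, Dict, Optional

# Hypothetical backend table: an empty dict keeps PyTensor's default mode.
BACKENDS: Dict[str, dict] = {
    "default": {},
    "JAX": {"mode": "JAX"},
    "NUMBA": {"mode": "NUMBA"},
}


def time_backends(
    run: Callable[[dict], object],
    backends: Optional[Dict[str, dict]] = None,
) -> Dict[str, float]:
    """Call run(compile_kwargs) once per backend; return wall-clock seconds.

    The timing includes graph compilation / JIT, which is a one-off cost
    for JAX and Numba, so a single short run can understate their benefit.
    """
    backends = BACKENDS if backends is None else backends
    timings: Dict[str, float] = {}
    for name, compile_kwargs in backends.items():
        start = time.perf_counter()
        run(compile_kwargs)
        timings[name] = time.perf_counter() - start
    return timings


# Hypothetical usage, assuming `model` and `idata` come from the notebook:
# import pymc as pm
# timings = time_backends(
#     lambda ck: pm.sample_posterior_predictive(idata, model=model, compile_kwargs=ck)
# )
```

Passing the call in as a function keeps the harness independent of how the notebook sets up the new prediction points.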
You can also use a coarser grid. The notebook uses new_lonlat = elevations[["y", "x"]].to_numpy(), which is probably huge, so you could reduce the number of points you’re sampling.
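One way to coarsen the grid is to keep only every `step`-th grid line along each axis. The sketch below assumes that `elevations[["y", "x"]].to_numpy()` yields an (N, 2) array of points on a regular grid; `coarsen_grid` and `step` are hypothetical names, not part of the notebook.

```python
import numpy as np


def coarsen_grid(points: np.ndarray, step: int) -> np.ndarray:
    """Keep every `step`-th grid line along each axis of an (N, 2) array of (y, x) points.

    Assumes the points lie on a regular grid; order of the kept points is preserved.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    # Unique coordinate values along each axis, thinned by `step`.
    ys = np.unique(points[:, 0])[::step]
    xs = np.unique(points[:, 1])[::step]
    # Keep a point only if both its y and its x lie on a retained grid line.
    keep = np.isin(points[:, 0], ys) & np.isin(points[:, 1], xs)
    return points[keep]


# Hypothetical usage in the notebook:
# new_lonlat = coarsen_grid(elevations[["y", "x"]].to_numpy(), step=4)
```

On a regular grid, `step=4` keeps roughly 1/16 of the points, so the posterior predictive work shrinks by about that factor. Any later code that reshapes predictions back onto the grid would need to use the same coarsened coordinates.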
JAX didn’t help much, perhaps because I didn’t install JAX with CUDA support. That would be another effort navigating the requirements for Windows 11 Pro, and I don’t think it is possible on Windows. NUMBA didn’t help either (as I had expected). The coarser grid did help.
So the best option on Windows 11 is to use a coarser grid.