Hi again! Thank you for getting back to me, and apologies for all the questions; I'm trying to cram some JAX knowledge in tonight.
NumPyro does in fact speed things up a lot! However, I now run into this error when I sample from the conditional posterior after using your recommended state-space setup (copied directly from your code):
idata_post = bvar_mod.sample_conditional_posterior(idata)
/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pytensor/link/jax/linker.py:28: UserWarning: The RandomType SharedVariables [RNG(<Generator(PCG64) at 0x7F80FABA3680>), RNG(<Generator(PCG64) at 0x7F8157418BA0>), RNG(<Generator(PCG64) at 0x7F8157419380>), RNG(<Generator(PCG64) at 0x7F815741AB20>), RNG(<Generator(PCG64) at 0x7F815741A340>), RNG(<Generator(PCG64) at 0x7F80FABA1A80>), RNG(<Generator(PCG64) at 0x7F80FABA34C0>), RNG(<Generator(PCG64) at 0x7F80FABA0900>), RNG(<Generator(PCG64) at 0x7F80FABA0AC0>), RNG(<Generator(PCG64) at 0x7F80FABA33E0>), RNG(<Generator(PCG64) at 0x7F8157419FC0>)] will not be used in the compiled JAX graph. Instead a copy will be used.
warnings.warn(
Sampling: [filtered_posterior, filtered_posterior_observed, predicted_posterior, predicted_posterior_observed, smoothed_posterior, smoothed_posterior_observed]
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /tmp/tmpgvhrcj7b:1 for jit. This concrete value was not available in Python because it depends on the value of the argument observed_state.
The error occurred while tracing the function jax_funcified_fgraph at /tmp/tmpgvhrcj7b:1 for jit. This concrete value was not available in Python because it depends on the value of the argument ar_lag.
The error occurred while tracing the function jax_funcified_fgraph at /tmp/tmpgvhrcj7b:1 for jit. This concrete value was not available in Python because it depends on the value of the argument observed_state_aux.
Apply node that caused the error: Scan{scan_fn&scan_fn, while_loop=False, inplace=none}(Cast{int32}.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, RNG(<Generator(PCG64) at 0x7F80F1F209E0>), RNG(<Generator(PCG64) at 0x7F80F1F20AC0>), Cast{int32}.0, Cast{int32}.0)
Toposort index: 389
Inputs types: [TensorType(int32, shape=()), TensorType(float64, shape=(None, 8)), TensorType(float64, shape=(None, 8, 8)), TensorType(float64, shape=(None, 4)), TensorType(float64, shape=(None, 4, 4)), RandomGeneratorType, RandomGeneratorType, TensorType(int32, shape=()), TensorType(int32, shape=())]
Inputs shapes: [(23, 4), 'No shapes', (), 'No shapes', 'No shapes', (), (), (), 'No shapes', 'No shapes', 'No shapes', 'No shapes', 'No shapes', 'No shapes', 'No shapes', 'No shapes']
Inputs strides: [(8, 184), 'No strides', (), 'No strides', 'No strides', (), (), (), 'No strides', 'No strides', 'No strides', 'No strides', 'No strides', 'No strides', 'No strides', 'No strides']
Inputs values: ['not shown', {'bit_generator': 1, 'state': {'state': -1875307623921742080, 'inc': -2117585609497974999}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3858338214, 3750760192], dtype=uint32)}, array(8), {'bit_generator': 1, 'state': {'state': 5892849408795900207, 'inc': 3467606628746858139}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1372035920, 3458627887], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': -4804107393757949562, 'inc': 7567433675770801199}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3176423879, 1413140870], dtype=uint32)}, array(4), array(2), array(4), {'bit_generator': 1, 'state': {'state': 6642182048432215167, 'inc': -6602606547338022137}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1546503521, 2588365951], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': -4968201944813567915, 'inc': -6434350206830792963}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3138217639, 1660649557], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': 5623932624895122884, 'inc': 7668392534115438363}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1309423852, 3952778692], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': 196117885448786772, 'inc': 9192249282833447751}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([ 45662253, 2152108884], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': 4985419079470314668, 'inc': -5619302220929038176}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1160758333, 675837100], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': 5052423089649604090, 'inc': 7414153455853755774}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1176358919, 4186691066], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': 2408003874362962264, 'inc': 126647561998871904}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([ 560657091, 4247466328], dtype=uint32)}, {'bit_generator': 1, 'state': {'state': -2526832023300368508, 'inc': 5498351235899043001}, 
'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3706643369, 2618922884], dtype=uint32)}]
Outputs clients: [[output[0](filtered_posterior)], [output[1](filtered_posterior_observed)], [output[11](Scan{scan_fn&scan_fn, while_loop=False, inplace=none}.2)], [output[12](Scan{scan_fn&scan_fn, while_loop=False, inplace=none}.3)]]
HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
193 for thunk, node, old_storage in zip(
194 thunks, order, post_thunk_old_storage
195 ):
--> 196 thunk()
197 for old_s in old_storage:
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/jax/_src/core.py:1702, in canonicalize_shape(shape, context)
1700 except TypeError:
1701 pass
-> 1702 raise _invalid_shape_error(shape, context)
From the description, it looks like it's telling me I have some misaligned vectors somewhere… but to be honest, I don't know where to start here.
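For what it's worth, my current reading of the TypeError is that it is JAX's generic complaint when an array shape is computed from a traced (runtime) value inside `jax.jit`, rather than a shape mismatch between vectors. The traceback says the offending shape depends on `observed_state`, `ar_lag`, and `observed_state_aux`, which fits that reading. Here is a minimal sketch, independent of PyMC, that reproduces the same kind of error and the `static_argnums` workaround the message suggests. `make_noise` is a hypothetical function for illustration only, and this is not a fix for `sample_conditional_posterior` itself.

```python
import jax

def make_noise(key, n_steps, k_states):
    # Hypothetical helper: the output shape is built from two of the arguments.
    return jax.random.normal(key, shape=(n_steps, k_states))

key = jax.random.PRNGKey(0)

# Under jit, n_steps and k_states are traced, so the shape is not a concrete
# tuple of ints and JAX raises a TypeError while tracing.
try:
    jax.jit(make_noise)(key, 23, 4)
    failed = False
except TypeError as err:
    failed = True
    print(type(err).__name__, str(err).splitlines()[0])

# Marking the shape arguments as static makes them plain Python ints at trace
# time, so the shape is concrete and compilation succeeds.
make_noise_static = jax.jit(make_noise, static_argnums=(1, 2))
draws = make_noise_static(key, 23, 4)
print(draws.shape)  # (23, 4)
```

If that reading is right, the analogous question for the state-space model is which of those three inputs ends up determining an array shape inside the compiled Scan, but I haven't confirmed how that maps onto the PyTensor graph.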