PyTensor-JAX does not support slicing arrays with a dynamic slice length

Consider this simple PyTensor function, which should fetch data from a shared variable using an index range specified by integer inputs. Note that ‘train_data’ and ‘input_pp’ are shared variables.
The purpose is to avoid host-to-device transfers: the shared variables reside on the GPU, and this function efficiently copies a subset of the data from ‘train_data’ into ‘input_pp’ directly on the device.

import collections
import pytensor
import pytensor.tensor as T

inp = T.tensor4()            # batch that will be written into input_pp
start_index = T.lscalar()    # first row of train_data to fetch
end_index = T.lscalar()      # one past the last row to fetch
updates = collections.OrderedDict()
updates[input_pp] = inp      # input_pp and train_data are shared variables defined elsewhere
fn = pytensor.function(inputs=[start_index, end_index], updates=updates,
                       givens={inp: train_data[start_index:end_index]})
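
For context, a minimal sketch of how the same function would be compiled against the JAX backend and called for one mini-batch (the mode="JAX" argument and the example indices are my own illustration, not part of the original post):

fn_jax = pytensor.function(inputs=[start_index, end_index], updates=updates,
                           givens={inp: train_data[start_index:end_index]},
                           mode="JAX")   # select the JAX linker / JIT compilation
fn_jax(0, 64)                            # e.g. copy rows 0..63 of train_data into input_pp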

Compiling this with the JAX backend raises an error. I'm not sure whether this is a JAX issue or a PyTensor issue.

If it is not fixable, perhaps someone can suggest a different way to accomplish this.

It’s a JAX limitation for jitted functions; see the alert in this section: 🔪 JAX - The Sharp Bits 🔪 — JAX documentation

It should work if the slice length is fixed. You can change the start index, but the stop must always be the same distance from the start.

Whether JAX can infer that this is the case from the generated PyTensor graph, I don't know.
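
As a minimal JAX-only illustration of that limitation (independent of PyTensor; the function names here are just for the example): basic slicing with a traced start index fails inside jit, while a slice with a dynamic start but a statically known length works via jax.lax.dynamic_slice_in_dim:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

@jax.jit
def take_basic(start):
    return x[start:start + 4]   # fails under jit: basic slicing needs static start/stop/step

@jax.jit
def take_fixed(start):
    return lax.dynamic_slice_in_dim(x, start, 4, axis=0)   # dynamic start, fixed length 4

print(take_fixed(3))   # [3 4 5 6]
# take_basic(3)        # raises: "Array slice indices must have static start/stop/step ..."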

Yes, it is always a fixed batch size. Perhaps if I only pass the start index, and calculate the end index by adding a constant, that might fix it.

I tried using [start_index : start_index+BATCHSIZE], and it still gives the error. Any ideas?

I looked deeper, and it appears that in pytensor/link/jax/dispatch/subtensor.py the function subtensor_assert_indices_jax_compatible() requires ALL slice parameters to be constant, not just the start index. I wonder if that could be fixed?

def subtensor_assert_indices_jax_compatible(node, idx_list):
    from pytensor.graph.basic import Constant
    from pytensor.tensor.var import TensorVariable

    ilist = indices_from_subtensor(node.inputs[1:], idx_list)
    for idx in ilist:
        if isinstance(idx, TensorVariable):
            if idx.type.dtype == "bool":
                raise NotImplementedError(BOOLEAN_MASK_ERROR)
        elif isinstance(idx, slice):
            for slice_arg in (idx.start, idx.stop, idx.step):
                if slice_arg is not None and not isinstance(slice_arg, Constant):
                    # rejects any slice bound that is not a compile-time constant
                    raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
Yes we should remove the check if we can demonstrate we can still generate a valid JAX graph from the PyTensor one.

This issue arose recently in another PR, where we found the naive checks were too strict, because sometimes JAX was still happy to compile a given function: Add JAX support for `pt.tri` by jessegrabowski · Pull Request #302 · pymc-devs/pytensor · GitHub

I removed the check, and my first attempt failed. I'm not sure whether this has to do with the dynamic start index, but I got the error:

fpbn failed Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Apply node that caused the error: DeepCopyOp(TensorConstant{0})

It seems we need special logic for how we convert Subtensor operations, so that they use that specific slice operator when possible. Can you open an issue on the PyTensor GitHub repository linking to this thread?
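
A rough, hypothetical sketch of what that conversion could produce when the stop index is the start plus a constant length (the helper name is made up; the real dispatch would have to detect this pattern in the PyTensor graph):

from jax import lax

def jit_safe_take(x, start, length, axis=0):
    # what x[start:start + length] could be lowered to when `length` is a
    # compile-time constant and only `start` is dynamic
    return lax.dynamic_slice_in_dim(x, start, length, axis=axis)

# e.g. train_data[start_index:start_index + BATCHSIZE] would become
# jit_safe_take(train_data, start_index, BATCHSIZE, axis=0)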

Yes, I’ll do that, thanks

Here is the bug report


I wrote this simple JAX-capable Op to do dynamic indexing on the first dimension of a tensor, with a fixed block size.

dynamic_idx.py (2.8 KB)


Looks neat. Why do you need the reshape at the end?

It would be nice to figure out when we have that type of graph in PyTensor; in that case we could introduce the Op in the JAX backend, so users don't have to implement it or even know about it.

Would you be interested in opening a PR in PyTensor? Even if it's just a draft that you end up abandoning, it can be very useful for the devs.

By the way, thanks for poking around the JAX backend and reporting bugs / issues. It really helps with moving it from “experimental” to “fully supported” :slight_smile:

The reshape is there because I found it easier to first convert to a matrix (2 dimensions) before using jax.lax.dynamic_slice(), and then reshape back to the original tensor shape afterwards. I suppose I could use jax.lax.dynamic_slice() on the tensor directly, but I would need to keep track of all the dimensions, and I thought this was easier. I think dynamic indexing is necessary whenever processing data in mini-batches, so I'm surprised it has not come up as an issue before. Maybe most of the NN people have moved to TensorFlow, but I thought Theano was much nicer. #:^) OK, I can open a pull request (PR); I have to read up on how to do that…
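
For readers without the attachment, here is a sketch of that reshape-then-slice idea (my reconstruction of the approach described above, not the actual dynamic_idx.py):

import jax.numpy as jnp
from jax import lax

def dynamic_batch(x, start, batch_size):
    # flatten everything after the first dimension so only a matrix is sliced,
    # take a fixed-size block starting at a dynamic row, then restore the shape
    flat = x.reshape((x.shape[0], -1))
    block = lax.dynamic_slice_in_dim(flat, start, batch_size, axis=0)
    return block.reshape((batch_size,) + x.shape[1:])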


Hi there (and sorry for the 9 month bump…),

I was just wondering whether it's still impossible to do dynamic slicing within jitted functions? Is there any chance that dynamic_update_slice could be used instead (although I don't know how)?

Cheers,
Vian

Freezing the model prior to sampling with JAX can fix some of the problems: pymc.model.transform.optimization.freeze_dims_and_data — PyMC dev documentation
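
For what it's worth, a minimal sketch of how that freezing step might be used before JAX-based sampling (the toy model and the numpyro sampler choice are illustrative assumptions, not from the original post):

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as model:
    x = pm.Data("x", [1.0, 2.0, 3.0])
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1.0, observed=x)

# Freeze mutable dims/data to constants so the JAX-based samplers see static shapes
frozen = freeze_dims_and_data(model)
with frozen:
    idata = pm.sample(nuts_sampler="numpyro")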