Consider this simple PyTensor function, which should fetch data from a shared variable using an index range specified by integer inputs. Note that ‘train_data’ and ‘input_pp’ are shared variables.
The purpose is to avoid data transfers: the shared variables reside on the GPU, so the function fetches a subset of the data from ‘train_data’ into ‘input_pp’ without copying anything to or from the host.
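A minimal sketch of that pattern (the shapes, batch size, and the use of updates= are my assumptions, not the original code; only the names train_data and input_pp come from the post above):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Hypothetical shapes and batch size, chosen only for illustration.
train_data = pytensor.shared(np.zeros((1000, 20), dtype="float32"), name="train_data")
input_pp = pytensor.shared(np.zeros((100, 20), dtype="float32"), name="input_pp")

start = pt.iscalar("start")              # dynamic start index, supplied at call time
batch = train_data[start : start + 100]  # Subtensor with non-constant start/stop

# Nothing is returned; the slice is written into input_pp via a shared-variable update.
fetch = pytensor.function([start], [], updates=[(input_pp, batch)], mode="JAX")
```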
I looked deeper and it appears that in:
pytensor/link/jax/dispatch/subtensor.py
the function
subtensor_assert_indices_jax_compatible()
wants ALL slice parameters to be constant, not just the start index.
I wonder if that could be fixed?
```python
def subtensor_assert_indices_jax_compatible(node, idx_list):
    from pytensor.graph.basic import Constant
    from pytensor.tensor.var import TensorVariable

    ilist = indices_from_subtensor(node.inputs[1:], idx_list)
    for idx in ilist:
        if isinstance(idx, TensorVariable):
            if idx.type.dtype == "bool":
                raise NotImplementedError(BOOLEAN_MASK_ERROR)
        elif isinstance(idx, slice):
            for slice_arg in (idx.start, idx.stop, idx.step):
                if slice_arg is not None and not isinstance(slice_arg, Constant):
                    # this raise is the check referred to below
                    raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
```
I removed the check, but my first attempt still failed. I'm not sure whether this has to do with the dynamic start index, but I got the error:
fpbn failed Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Apply node that caused the error: DeepCopyOp(TensorConstant{0})
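For context, this is JAX's standard complaint about slicing with traced start/stop values inside jit. Here is a small standalone reproduction, independent of PyTensor (the function names are mine), together with the lax.dynamic_slice form the message suggests:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(20.0)

@jax.jit
def static_slice(x, start):
    # Fails under jit: `start` is a traced value, so the slice bounds are not static.
    return x[start : start + 5]

@jax.jit
def dynamic_slice(x, start):
    # Works: lax.dynamic_slice takes runtime start indices plus a *static* slice size.
    return lax.dynamic_slice(x, (start,), (5,))

print(dynamic_slice(x, 3))   # [3. 4. 5. 6. 7.]
# static_slice(x, 3)         # raises the "static start/stop/step" error quoted above
```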
Seems like we need special logic when converting Subtensor operations, so that we use that specific slice operator when possible. Can you open an issue on the pytensor GitHub repository linking to this thread?
Looks neat. Why do you need the reshape at the end?
Would be nice to figure out when we have that type of graph in PyTensor, in which case we could introduce that Op in the JAX backend, so users don’t have to implement it / know about it.
Would you be interested in opening a PR in PyTensor? Even if it's just a draft that you end up abandoning, it can be very useful for devs.
By the way, thanks for poking around the JAX backend and reporting bugs / issues. It really helps with moving it from “experimental” to “fully supported”
The reshape is there because I found it easier to first flatten to a matrix (2 dims) before using jax.lax.dynamic_slice(), and then reshape back to the original tensor shape. I suppose I could use jax.lax.dynamic_slice() on the tensor directly, but I would need to keep track of all the dimensions, and this seemed easier. I think dynamic indexing is necessary whenever you process data in mini-batches, so I'm surprised it hasn't been an issue yet. Maybe most of the NN people have moved to TensorFlow, but I thought Theano was much nicer. #:^) OK, I can open a pull request (PR); I'll have to read up on how to do that…
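A rough sketch of the flatten-then-slice-then-reshape workaround described above (shapes and names are my own; this is not the actual PR code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def fetch_batch(data, start, batch_size):
    """Take `batch_size` rows along the leading axis starting at a dynamic
    `start`: flatten to 2-D, dynamic-slice, then reshape back."""
    n = data.shape[0]
    trailing = data.shape[1:]
    flat = data.reshape((n, -1))                        # collapse trailing dimensions
    block = lax.dynamic_slice(
        flat, (start, 0), (batch_size, flat.shape[1])   # slice sizes must be static
    )
    return block.reshape((batch_size, *trailing))       # restore the original shape

data = jnp.arange(4 * 3 * 2.0).reshape((4, 3, 2))
print(jax.jit(fetch_batch, static_argnums=2)(data, 1, 2).shape)   # (2, 3, 2)
```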
I was just wondering whether it's still impossible to do dynamic slicing within jitted functions? Is there any chance that dynamic_update_slice could be used instead (although I don't know how)?
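For reference (not something this thread resolved), jax.lax.dynamic_update_slice is the write-side counterpart of dynamic_slice: it copies a block into an array at a runtime offset, which is the kind of operation a shared-variable update would need. A tiny standalone example:

```python
import jax.numpy as jnp
from jax import lax

buffer = jnp.zeros((6, 3))
block = jnp.ones((2, 3))

# Write `block` into `buffer` starting at row 2; the offset may be a traced value under jit.
updated = lax.dynamic_update_slice(buffer, block, (2, 0))
print(updated)
```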