If I understand correctly, all of the existing PyTensor backends work with NumPy arrays seamlessly (although with JAX this means we cannot control where those arrays are placed in GPU memory).
This is not the case for PyTorch, which can only consume NumPy arrays after they have been converted via `torch.from_numpy()`.
@twiecki @ricardoV94 Should we deal with this by creating a `GraphRewriter` that scans the graph for `TensorConstant`s holding NumPy data and wraps their data with `torch.from_numpy()`?
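
For concreteness, here is a minimal sketch of the kind of rewriter I have in mind, assuming the `GraphRewriter.apply(fgraph)` interface. The `TorchifyConstants` name and the `tag.torch_data` attribute are made up for illustration, and exactly how the torch linker would pick up the converted tensors is left open:

```python
import numpy as np
import torch

from pytensor.graph.rewriting.basic import GraphRewriter
from pytensor.tensor.variable import TensorConstant


class TorchifyConstants(GraphRewriter):
    """Cache a torch view of every NumPy-backed TensorConstant in the graph."""

    def apply(self, fgraph):
        for var in fgraph.variables:
            if isinstance(var, TensorConstant) and isinstance(var.data, np.ndarray):
                # torch.from_numpy shares memory with the NumPy array (no copy);
                # stash the wrapped tensor on the variable's tag so the torch
                # backend could retrieve it instead of re-converting at runtime.
                var.tag.torch_data = torch.from_numpy(var.data)
```

The alternative would be to do the conversion inside the torch linker/thunk itself, which keeps the graph untouched but pays the (cheap, zero-copy) `from_numpy()` call at function-construction time instead of rewrite time.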