I’d debug this by drilling down into each component function and making sure it’s differentiable with respect to all its (non-data) inputs. For example, to test probs_languages_to_probs_scenes_multiple_participants, I’d do something like this:
import pytensor
import pytensor.tensor as pt

interpretation_weights = pt.tensor('interpretation_weights', shape=(1, None, None, None), dtype='float64')
word_order_weights = pt.tensor('word_order_weights', shape=(1, 1, 1, 1), dtype='float64')
signals = pt.tensor('signals', shape=(None, None, None), dtype='int64')
scenes_trials = pt.tensor('scenes_trials', shape=(1, None, None, None), dtype='int64')
word_orders = pt.tensor('word_orders', shape=(None, None), dtype='int64')
softmax_alphas = pt.tensor('softmax_alphas', shape=(None, None), dtype='float64')
output = probs_languages_to_probs_scenes_multiple_participants(
    interpretation_weights,
    word_order_weights,
    signals,
    scenes_trials,
    word_orders,
    softmax_alphas,
)
That is, create symbolic tensor variables for each input (with the expected shapes and dtypes), call the function on them, and then test pytensor.grad(output.sum(), input) for each non-data input. If you hit a NotImplementedError or a "grad not implemented" error, the traceback will give you more information about what is breaking the flow of gradients in the model. At that point I’d drill down again and test every individual computation inside the offending function to see exactly which operation doesn’t have a gradient.
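To make the per-input check systematic, here's a minimal sketch that continues from the snippet above (the dict of inputs and the printed messages are just illustrative, and the exact exception type you see will depend on what fails):

float_inputs = {
    'interpretation_weights': interpretation_weights,
    'word_order_weights': word_order_weights,
    'softmax_alphas': softmax_alphas,
}

for name, inp in float_inputs.items():
    try:
        # Try to build the symbolic gradient of the scalar cost w.r.t. this input.
        g = pytensor.grad(output.sum(), inp)
        print(f'{name}: gradient OK ({g.type})')
    except Exception as err:  # e.g. a "grad not implemented" or disconnected-input error
        print(f'{name}: gradient FAILED with {type(err).__name__}: {err}')

The integer-typed inputs (signals, scenes_trials, word_orders) are left out of the loop because they're data: gradients with respect to integer tensors aren't defined anyway.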