The error above seems to indicate that Softmax is applied to the transformed RV of the Dirichlet distribution. However, the transformation currently used is `aeppl.transforms.Simplex`, which does not explicitly use the Softmax function:
```python
# Quoted from aeppl/transforms.py (the RVTransform base class is defined in the same module).
import aesara.tensor as at


class Simplex(RVTransform):
    name = "simplex"

    def forward(self, value, *inputs):
        # Centered log transform; the last coordinate is dropped because it is redundant.
        log_value = at.log(value)
        shift = at.sum(log_value, -1, keepdims=True) / value.shape[-1]
        return log_value[..., :-1] - shift

    def backward(self, value, *inputs):
        # Re-append the implied last coordinate, then normalise (a stabilised softmax).
        value = at.concatenate([value, -at.sum(value, -1, keepdims=True)], axis=-1)
        exp_value_max = at.exp(value - at.max(value, -1, keepdims=True))
        return exp_value_max / at.sum(exp_value_max, -1, keepdims=True)

    def log_jac_det(self, value, *inputs):
        N = value.shape[-1] + 1
        sum_value = at.sum(value, -1, keepdims=True)
        value_sum_expanded = value + sum_value
        value_sum_expanded = at.concatenate(
            [value_sum_expanded, at.zeros(sum_value.shape)], -1
        )
        # ... (the rest of log_jac_det is truncated in the quoted file)
```
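For what it's worth, the `backward` method above is mathematically a (numerically stabilised) softmax over the re-expanded vector, even though no explicit `Softmax` op appears in the class. Here is a minimal NumPy sketch of the same math (my own re-implementation with hypothetical names, not the aeppl code) showing the round trip:

```python
import numpy as np


def simplex_forward(x):
    # Mirrors Simplex.forward: centered log transform, dropping the last entry.
    log_x = np.log(x)
    shift = log_x.sum(-1, keepdims=True) / x.shape[-1]
    return (log_x - shift)[..., :-1]


def simplex_backward(y):
    # Mirrors Simplex.backward: append the negative sum, then apply a
    # numerically stabilised softmax over the expanded vector.
    y_full = np.concatenate([y, -y.sum(-1, keepdims=True)], axis=-1)
    e = np.exp(y_full - y_full.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)


x = np.array([0.2, 0.3, 0.5])                 # a point on the 2-simplex
y = simplex_forward(x)                        # unconstrained, length 2
print(np.allclose(simplex_backward(y), x))    # True: backward is a softmax
```

So an error message that mentions Softmax is at least consistent with how this transform maps back to the constrained space, even though the `Simplex` class never calls a Softmax op by name.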