Aesara aeppl.logprob.ParameterValueError

Hi there,

I’m trying to update my script from 3.11.4 to 3.11.5 and ran into two issues:

ERROR (aesara.graph.opt): Optimization failure due to: transform_values
ERROR (aesara.graph.opt): node: truncated_normal_rv{0, (0, 0, 0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FA8557BA840>), TensorConstant{}, TensorConstant{11}, TensorConstant{2.0}, TensorConstant{2.0}, TensorConstant{1.0}, TensorConstant{3.0})
ERROR (aesara.graph.opt): TRACEBACK:
  File "python3.9/site-packages/aesara/graph/op.py", line 522, in rval
    r = p(n, [x[0] for x in i], o)
  File "python3.9/site-packages/aeppl/transforms.py", line 48, in perform
    raise NotImplementedError(
NotImplementedError: These Ops should be removed from graphs used for computation.

This seems to be just a warning, and the code proceeds. However, a random error occurs with SMC (apparently never with NUTS):

log_thunk_trace: There was a problem executing an Op.
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "python3.9/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "python3.9/site-packages/pymc/smc/sample_smc.py", line 439, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
  File "python3.9/site-packages/pymc/smc/sample_smc.py", line 363, in _sample_smc_int
    smc._initialize_kernel()
  File "python3.9/site-packages/pymc/smc/smc.py", line 214, in _initialize_kernel
    initial_point, [self.model.varlogpt], self.variables, shared
  File "python3.9/site-packages/pymc/model.py", line 871, in varlogpt
    return self.logpt(vars=self.free_RVs)
  File "python3.9/site-packages/pymc/model.py", line 763, in logpt
    rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
  File "python3.9/site-packages/pymc/distributions/logprob.py", line 223, in joint_logpt
    temp_logp_var_dict = factorized_joint_logprob(
  File "python3.9/site-packages/aeppl/joint_logprob.py", line 186, in factorized_joint_logprob
    q_logprob_vars = _logprob(
  File "python3.9/functools.py", line 877, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "python3.9/site-packages/pymc/distributions/distribution.py", line 122, in logp
    return class_logp(value, *dist_params)
  File "python3.9/site-packages/pymc/distributions/continuous.py", line 785, in logp
    return check_parameters(logp, *bounds)
  File "python3.9/site-packages/pymc/distributions/dist_math.py", line 67, in check_parameters
    return CheckParameterValue(msg)(logp, all_true_scalar)
  File "python3.9/site-packages/aesara/graph/op.py", line 297, in __call__
    compute_test_value(node)
  File "python3.9/site-packages/aesara/graph/op.py", line 135, in compute_test_value
    required = thunk()
  File "python3.9/site-packages/aesara/link/c/op.py", line 103, in rval
    thunk()
  File "python3.9/site-packages/aesara/link/c/basic.py", line 1766, in __call__
    raise exc_value.with_traceback(exc_trace)
aeppl.logprob.ParameterValueError
"""

Interestingly, when I change one of the RVs from TruncatedNormal to Uniform, the error seems never to happen, whereas it occurs randomly with TruncatedNormal. I still get the warning message in both cases, however.

Anyway, any help guiding me through debugging this would be appreciated! I assume this is somehow related to the prior sampling in SMC? By the way, p_acc_rate seems to be gone from SMC? I’ll try to extract a minimal example script, but that’s not straightforward.

Cheers,
Vian

While trying to make a minimal script, I realized the issue comes from:
theano.config.compute_test_value = 'ignore'

which I had left in, along with importing aesara as theano. Now I no longer get the warning message or the random error.
