I am trying to run a model with an embedded optimizer, but I’m running into (what seems like) a bug where, occasionally, the wrong data is stored in the inference data. My model looks like this:
x = pt.dscalar('x')
y = pt.dscalar('y')
eq1 = x ** 2 + y - 1
eq2 = x - y ** 2 + 1
system = pt.stack([eq1, eq2])
jac = pt.stack(pytensor.gradient.jacobian(system, [x, y]))
jac_inv = pt.linalg.solve(jac, pt.identity_like(jac))
f = pytensor.compile.builders.OpFromGraph([x, y], [system])
f_jac_inv = pytensor.compile.builders.OpFromGraph([x, y], [jac_inv])
with pm.Model(coords={'optim_step': np.arange(100), 'params': ['x', 'y']}) as m:
    x_val = pm.Normal('x', mu=0, sigma=20)
    y_val = pm.Normal('y', mu=0, sigma=20)

    root, converged, step_size, n_steps = pytensor_root(f, f_jac_inv, [x_val, y_val])

    root = pm.Deterministic('root', root, dims=['optim_step', 'params'])
    success = pm.Deterministic('success', converged, dims=['optim_step'])
    n_steps = pm.Deterministic('n_steps', n_steps, dims=['optim_step'])
    step_size = pm.Deterministic('step_size', step_size, dims=['optim_step'])

    idata = pm.sample_prior_predictive()
pytensor_root is a globalized (damped) Newton iteration: each step computes X - step_size * J_inv(X) @ F(X), halves the step size whenever the L1 norm of the iterate fails to decrease, and stops updating once the step or residual norm drops below tol. Here’s the full code, which can be copied and pasted to run:
import pytensor.tensor as pt
import pytensor
import pymc as pm
import numpy as np
import functools as ft
def _newton_step(X, F, J_inv, step_size, tol):
    F_X = F(*X)
    J_inv_X = J_inv(*X)
    new_X = X - step_size * J_inv_X @ F_X
    return (X, new_X, F_X)

def no_op(X, F, J_inv, step_size, tol):
    return (X, X, X)

def compute_norms(X, new_X, F_X):
    norm_X = pt.linalg.norm(X, ord=1)
    norm_new_X = pt.linalg.norm(new_X, ord=1)
    norm_root = pt.linalg.norm(F_X, ord=1)
    norm_step = pt.linalg.norm(new_X - X, ord=1)
    return norm_X, norm_new_X, norm_root, norm_step

def _check_convergence(norm_step, norm_root, tol):
    new_converged = pt.or_(pt.lt(norm_step, tol), pt.lt(norm_root, tol))
    return new_converged

def check_convergence(norm_step, norm_root, converged, tol):
    return pytensor.ifelse.ifelse(converged,
                                  np.array(True),
                                  _check_convergence(norm_step, norm_root, tol))

def check_stepsize(norm_X, norm_new_X, step_size, initial_step_size):
    is_decreasing = pt.le(norm_new_X, norm_X)
    return pytensor.ifelse.ifelse(is_decreasing,
                                  (is_decreasing, initial_step_size),
                                  (is_decreasing, step_size * 0.5))

def backtrack_if_not_decreasing(is_decreasing, X, new_X):
    return pytensor.ifelse.ifelse(is_decreasing, new_X, X)

def scan_body(X, converged, step_size, n_steps, tol, F, J_inv, initial_step_size):
    # skip the Newton update entirely once the iteration has converged
    out = pytensor.ifelse.ifelse(converged,
                                 no_op(X, F, J_inv, step_size, tol),
                                 _newton_step(X, F, J_inv, step_size, tol))
    X, new_X, F_X = [out[i] for i in range(3)]

    norm_X, norm_new_X, norm_root, norm_step = compute_norms(X, new_X, F_X)
    is_converged = check_convergence(norm_step, norm_root, converged, tol)
    is_decreasing, new_step_size = check_stepsize(norm_X, norm_new_X, step_size, initial_step_size)

    # reject the step (and halve the step size) if the iterate norm didn't decrease
    new_X = backtrack_if_not_decreasing(is_decreasing, X, new_X)
    new_n_steps = n_steps + (1 - is_converged)

    return new_X, is_converged, new_step_size, new_n_steps

def pytensor_root(f, f_jac_inv, x0, step_size=1, max_iter=100, tol=1e-8):
    root_func = ft.partial(scan_body, F=f, J_inv=f_jac_inv, initial_step_size=np.float64(step_size))

    converged = np.array(False)
    step_size = np.float64(step_size)
    n_steps = 0

    outputs, updates = pytensor.scan(root_func,
                                     outputs_info=[pt.stack(x0), converged, step_size, n_steps],
                                     non_sequences=[tol],
                                     n_steps=max_iter,
                                     strict=True)

    root, converged, step_size, n_steps = outputs
    return root, converged, step_size, n_steps
x = pt.dscalar('x')
y = pt.dscalar('y')
eq1 = x ** 2 + y - 1
eq2 = x - y ** 2 + 1
system = pt.stack([eq1, eq2])
jac = pt.stack(pytensor.gradient.jacobian(system, [x, y]))
jac_inv = pt.linalg.solve(jac, pt.identity_like(jac))
f = pytensor.compile.builders.OpFromGraph([x, y], [system])
f_jac_inv = pytensor.compile.builders.OpFromGraph([x, y], [jac_inv])
with pm.Model(coords={'optim_step': np.arange(100), 'params': ['x', 'y']}) as m:
    x_val = pm.Normal('x', mu=0, sigma=20)
    y_val = pm.Normal('y', mu=0, sigma=20)

    root, converged, step_size, n_steps = pytensor_root(f, f_jac_inv, [x_val, y_val])

    root = pm.Deterministic('root', root, dims=['optim_step', 'params'])
    success = pm.Deterministic('success', converged, dims=['optim_step'])
    n_steps = pm.Deterministic('n_steps', n_steps, dims=['optim_step'])
    step_size = pm.Deterministic('step_size', step_size, dims=['optim_step'])

    idata = pm.sample_prior_predictive()
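For what it’s worth, here’s a plain-NumPy sketch of the same damped Newton iteration that I’ve been using to sanity-check single draws outside of pytensor. numpy_root and its internals are just my own helper, not part of the model above (it relies on the numpy import from the script):

def numpy_root(x0, step_size=1.0, max_iter=100, tol=1e-8):
    """Plain-NumPy mirror of pytensor_root, for cross-checking single draws."""
    def F(x, y):
        return np.array([x ** 2 + y - 1, x - y ** 2 + 1])

    def J(x, y):
        return np.array([[2.0 * x, 1.0],
                         [1.0, -2.0 * y]])

    X = np.asarray(x0, dtype=float)
    alpha = step_size
    history = [X.copy()]
    for _ in range(max_iter):
        F_X = F(*X)
        new_X = X - alpha * np.linalg.solve(J(*X), F_X)

        norm_step = np.abs(new_X - X).sum()   # L1 norms, as in the pytensor code
        norm_root = np.abs(F_X).sum()

        # backtracking: only accept the step if the L1 norm of the iterate shrank
        if np.abs(new_X).sum() <= np.abs(X).sum():
            X, alpha = new_X, step_size
        else:
            alpha *= 0.5

        history.append(X.copy())
        if norm_step < tol or norm_root < tol:
            break
    return np.array(history)

Calling numpy_root([40.95065948, 0.09735453]) returns the iterate history for the problematic draw shown below, without going through scan or OpFromGraph.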
The reason I think it’s a bug is that when I inspect the idata, some draws are reported as nothing but the initial value. For example:
array([[40.95065948,  0.09735453],
       [40.95065948,  0.09735453],
       [40.95065948,  0.09735453],
       [40.95065948,  0.09735453],
       [40.95065948,  0.09735453],
       [40.95065948,  0.09735453],
       #...
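(For reference, this is roughly how I’m pulling those values out of the prior group — the chain/draw indices here are just illustrative:)

bad = idata.prior['root'].isel(chain=0, draw=0)   # dims: (optim_step, params)
print(bad.values[:6])
# and the corresponding initial values sampled for x and y
print(idata.prior['x'].isel(chain=0, draw=0).values,
      idata.prior['y'].isel(chain=0, draw=0).values)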
But if I take this x0 and plug it into pytensor_root manually, the algorithm correctly converges to a root:
x0 = [40.95065948, 0.09735453]
root, converged, step_size, n_steps = pytensor_root(f, f_jac_inv, x0)
root.eval()
Out: array([[-11.17560205,   5.79732092],
            [ -5.52029685,   2.50879961],
            [ -2.7690399 ,   0.90183281],
            [ -1.53747434,   0.15292639],
            [ -1.53747434,   0.15292639],
            [ -1.53747434,   0.15292639],
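The other scan outputs from the manual run can be inspected the same way (each .eval() recompiles the graph, which is slow but fine for poking around):

converged.eval()    # per-step convergence flag, shape (max_iter,)
n_steps.eval()      # cumulative count of Newton steps taken, shape (max_iter,)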
So I’m not really sure what to make of all this. Obviously this is far from an MRE, but I’m hoping someone else has seen something like this before?