How to improve the efficiency of a loop, and find a root

Hi, I want to identify some parameters in a physical model using a multivariate normal likelihood.

The model has 2 inputs {s, z} and 3 outputs {Fn, Fdp, Mr}, all of length n_length. It's complicated to calculate the outputs from the RVs for each input, so I use a for loop over the inputs, even though a loop isn't strictly necessary for the current model.

For the current model, the for-loop version with numpyro sampling takes a long time to run (12 min). I tried scan, but it took even longer (30 min), since the JAX backend doesn't work for it.

Is there some way to improve the efficiency of the for loop?

Moreover, in the future I want to find a root (as with scipy.optimize.root) based on one of the outputs, e.g. solve for z_pred such that Fn_pred equals the measured Fn. The input set would then be {s, Fn} and the output set {z, Fdp, Mr}.

Is there a function like scipy.optimize.root(), or some other method, to estimate the root within this model?

This is the current model with the for loop:

import numpy as np
import pymc as pm
import aesara
import aesara.tensor as at
from pymc.sampling_jax import sample_numpyro_nuts

# s, z, Fn, Fdp, Mr are measured n_length-vectors; r, b are known constants
theta_1 = np.arccos(1 - z / r)
observed = np.array([Fn, Fdp, Mr]).T
with pm.Model() as model:
    # set parameters to identify (RVs)
    Ks = pm.Uniform("Ks", lower=5e5, upper=5e7)
    n0 = pm.Uniform("n0", 0.2, 0.9)
    n1 = pm.Uniform("n1", 0.0, 0.9)
    phi = pm.Uniform("phi", 25, 45)
    phi_rad = phi / 180 * np.pi
    k = pm.Uniform("k", lower=0.005, upper=0.05)
    c1 = pm.Uniform("c1", lower=0.3, upper=0.7)
    c2 = pm.Uniform("c2", lower=-0.4, upper=0.1)
    sigma_Fn = pm.Uniform("sigma_Fn", lower=1e0, upper=1e4)
    sigma_Fdp = pm.Uniform("sigma_Fdp", lower=1e0, upper=1e4)
    sigma_Mr = pm.Uniform("sigma_Mr", lower=1e-1, upper=1e4)

    # calculate the output(n*3) based on the physical model
    # s, theta_1 are known, which are all n_length-dimensional vectors
    # other parameters are constants
    F_pred = []
    for i in range(len(s)):
        s_i = s[i]
        theta_1_i = theta_1[i]
        thetam_i = (c1 + c2 * s_i) * theta_1_i
        n = n0 + n1 * s_i
        Tan_phi = pm.math.tan(phi_rad)
        sigma_m = (
            Ks
            * at.power(pm.math.cos(thetam_i) - pm.math.cos(theta_1_i), n)
            * at.power(r, n)
        )
        j_m = r * (
            theta_1_i
            - thetam_i
            - (1 - s_i) * (pm.math.sin(theta_1_i) - pm.math.sin(thetam_i))
        )
        tau_m = sigma_m * Tan_phi * (1 - pm.math.exp(-j_m / k))
        A = (pm.math.cos(thetam_i) - 1) / thetam_i + (
            pm.math.cos(thetam_i) - pm.math.cos(theta_1_i)
        ) / (theta_1_i - thetam_i)
        B = pm.math.sin(thetam_i) / thetam_i + (
            pm.math.sin(thetam_i) - pm.math.sin(theta_1_i)
        ) / (theta_1_i - thetam_i)
        C = theta_1_i / 2
        X = r * b * sigma_m
        Y = r * b * tau_m
        Fn_pred_i = A * X + B * Y
        Fdp_pred_i = A * Y - B * X
        Mr_pred_i = r * C * Y
        F_pred.append(at.stack([Fn_pred_i, Fdp_pred_i, Mr_pred_i]))

    _mu = at.stack(F_pred)
    mu = pm.Deterministic("mu", _mu)
    cov = at.diag(at.stack([sigma_Fn, sigma_Fdp, sigma_Mr]))
    likelihood = pm.MvNormal("likelihood", cov=cov, mu=mu, observed=observed)
    trace = sample_numpyro_nuts(draws=10000, tune=2000, chains=4, target_accept=0.9)

The scan version looks like this:

s = at.as_tensor_variable(s)
theta_1 = at.as_tensor_variable(theta_1)

def inner_func(i, Ks, n0, n1, phi, k, c1, c2):
    s_i = s[i]
    theta_1_i = theta_1[i]
    thetam_i = (c1 + c2 * s_i) * theta_1_i
    n = n0 + n1 * s_i
    phi_rad = phi / 180 * np.pi
    Tan_phi = pm.math.tan(phi_rad)
    sigma_m = (
        Ks
        * at.power(pm.math.cos(thetam_i) - pm.math.cos(theta_1_i), n)
        * at.power(r, n)
    )
    j_m = r * (
        theta_1_i
        - thetam_i
        - (1 - s_i) * (pm.math.sin(theta_1_i) - pm.math.sin(thetam_i))
    )
    tau_m = sigma_m * Tan_phi * (1 - pm.math.exp(-j_m / k))
    A = (pm.math.cos(thetam_i) - 1) / thetam_i + (
        pm.math.cos(thetam_i) - pm.math.cos(theta_1_i)
    ) / (theta_1_i - thetam_i)
    B = pm.math.sin(thetam_i) / thetam_i + (
        pm.math.sin(thetam_i) - pm.math.sin(theta_1_i)
    ) / (theta_1_i - thetam_i)
    C = theta_1_i / 2
    X = r * b * sigma_m
    Y = r * b * tau_m
    Fn_pred_i = A * X + B * Y
    Fdp_pred_i = A * Y - B * X
    Mr_pred_i = r * C * Y
    return at.stack([Fn_pred_i, Fdp_pred_i, Mr_pred_i])

result, _ = aesara.scan(
    inner_func,
    sequences=at.arange(num_inter, dtype="int32"),  # num_inter is the number of data points (n_length)
    non_sequences=[Ks, n0, n1, phi, k, c1, c2],
)

_mu = result
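
For reference, since every operation in the loop body is elementwise over s and theta_1, I think the whole loop could be replaced with array operations over the full vectors. A rough, untested sketch of that vectorized version (inside the same model block):

    # Untested vectorized sketch: same math as the loop, applied to whole vectors
    thetam = (c1 + c2 * s) * theta_1
    n = n0 + n1 * s
    tan_phi = pm.math.tan(phi_rad)
    sigma_m = Ks * at.power(pm.math.cos(thetam) - pm.math.cos(theta_1), n) * at.power(r, n)
    j_m = r * (theta_1 - thetam - (1 - s) * (pm.math.sin(theta_1) - pm.math.sin(thetam)))
    tau_m = sigma_m * tan_phi * (1 - pm.math.exp(-j_m / k))
    A = (pm.math.cos(thetam) - 1) / thetam + (pm.math.cos(thetam) - pm.math.cos(theta_1)) / (theta_1 - thetam)
    B = pm.math.sin(thetam) / thetam + (pm.math.sin(thetam) - pm.math.sin(theta_1)) / (theta_1 - thetam)
    C = theta_1 / 2
    X = r * b * sigma_m
    Y = r * b * tau_m
    _mu = at.stack([A * X + B * Y, A * Y - B * X, r * C * Y], axis=1)  # shape (n_length, 3)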

I don’t have a good answer for you but here are some thoughts:

  • AFAIK, there’s no out-of-the-box differentiable root finder for Aesara yet, but you can find one for the JAX / NumPyro ecosystem: jax.scipy.optimize.minimize (see the JAX documentation; a toy sketch follows after this list).

  • You could try wrapping a solver in an Aesara Op, but this would require getting the gradients of the solver’s operations.

  • Supposing your data is of modest dimension (N ≥ 50), 12 minutes for sampling this kind of model doesn’t sound too outrageous. These types of problems are pretty hard in general and I’m guessing the posterior is going to be rather complicated because of the nonlinearities in this model.

  • You may want to search the existing Discourse topics for threads related to “root”, “optimization”, “Newton” or related keywords. Here’s an example of one such thread: Defining grad() for custom Theano Op that solves nonlinear system of equations - #2 by BioGoertz.
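
To make the first bullet concrete, here is a toy sketch of recasting root finding as least-squares minimization with jax.scipy.optimize.minimize. The quadratic fn_model below is just a stand-in for your real Fn(z) relationship, not the actual physical model:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Toy stand-in for the real Fn(z) relationship; replace with the physical model.
def fn_model(z):
    return 3.0 * z**2 + 0.5 * z

def solve_z(fn_target, z0=0.1):
    # Recast "find z such that fn_model(z) == fn_target" as minimizing the squared
    # residual; jax.scipy.optimize.minimize only supports method="BFGS" and
    # expects a 1-D array for x0.
    loss = lambda z: (fn_model(z[0]) - fn_target) ** 2
    res = minimize(loss, jnp.array([z0]), method="BFGS")
    return res.x[0]

z_root = solve_z(fn_target=2.0)

Because everything is written in JAX, the solve can also be differentiated with jax.grad if you need gradients of the root with respect to other inputs.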


I have an implementation of Newton’s Method in Aesara here that might be helpful. I’ve tried including it in PyMC models before, but it’s very slow. One thing I’ve been thinking about recently is that it might be wasteful to have every step of the root finding algorithm on the computational graph. If you can write down derivatives of the root with respect to parameters, you could just wrap scipy.optimize.minimize in an Op and provide your own gradients. Not a general solution, but would likely be much much faster than a scan-based optimizer like the one I linked.
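
For illustration, here is a minimal sketch of that idea: a custom Aesara Op that calls scipy.optimize.brentq (instead of minimize) as the inner solver and supplies its own gradient via the implicit function theorem. The cubic equation is only a placeholder for a real model equation, and the bracket [0, 10] is specific to this toy problem:

import numpy as np
import aesara
import aesara.tensor as at
from aesara.graph.op import Op
from scipy.optimize import brentq

class RootOp(Op):
    """Toy Op that solves f(z, theta) = z**3 + theta*z - 1 = 0 for z."""

    itypes = [at.dscalar]
    otypes = [at.dscalar]

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        f = lambda z: z**3 + theta * z - 1.0
        # brentq needs a sign-changing bracket; [0, 10] works for this toy f with theta >= 0
        outputs[0][0] = np.asarray(brentq(f, 0.0, 10.0), dtype="float64")

    def grad(self, inputs, output_grads):
        (theta,) = inputs
        (g,) = output_grads
        z = self(theta)  # symbolic root, reused in the implicit-function formula
        # Implicit function theorem: dz/dtheta = -(df/dtheta) / (df/dz) at the root
        return [-g * z / (3.0 * z**2 + theta)]

theta = at.dscalar("theta")
z_root = RootOp()(theta)
dz_dtheta = aesara.grad(z_root, theta)  # gradient never touches the solver's internals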

In general though, Scans are a part of the library under very active development, and for now it is what it is. I’ve found it can be quite fast with scalar-valued inputs, but goes very slow on linear algebra operators. There are ways to optimize them, though. There’s a thread here where the Aesara devs walk me through optimizing a scan, it might be of interest.

You can also try using nutpie, which is a Numba-based NUTS sampler. It can offer significant speedups in certain cases, and can compile a PyMC model without any work on your end.
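
If the model above is stored in a variable called model, usage is roughly the following (a sketch; check the nutpie docs for current arguments):

import nutpie

# `model` is the PyMC model defined earlier; nutpie compiles the logp with Numba
# and runs its own NUTS implementation.
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled, draws=2000, tune=1000, chains=4)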

Finally, I second the (somewhat heterodox) opinion of @BioGoertz in the thread that @ckrapu just linked, that sometimes you should think about letting go of NUTS if your model is really too hard to get gradients for. I recently had success using emcee for a modest-dimensional model with a ton of loops, optimizations, and linear approximations.
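
As a rough sketch of the emcee workflow (the Gaussian log_prob below is just a placeholder for your model's actual log posterior):

import numpy as np
import emcee

# Placeholder log posterior; replace with the model's actual priors + likelihood.
def log_prob(params):
    return -0.5 * np.sum(params**2)

ndim, nwalkers = 8, 32
p0 = 0.1 * np.random.randn(nwalkers, ndim)  # initial walker positions
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
sampler.run_mcmc(p0, 5000, progress=True)
samples = sampler.get_chain(discard=1000, flat=True)  # posterior draws, no gradients needed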


Thanks for your suggestions. I’m a beginner with PyMC.
I also found that Metropolis sampling is usually much faster than NUTS, but the posterior is not as good: it may fail to converge or contain a lot of noise.

Yes, definitely don’t use Metropolis if you can avoid it.

If you’re new to all this, check out this interactive gallery that shows how different algorithms produce samples. You will quickly understand why NUTS is king. But if you can’t get gradients…