I developed a Bayesian Zero-Adjusted Dirichlet Regression model (Tsagris and Stewart 2018) for my research using PyMC3,
obtained estimation results with it, and am currently revising my manuscript. Now that PyMC
has been upgraded to version 4, there is an opportunity to use pymc.sampling_jax.sample_blackjax_nuts
(or pymc.sampling_jax.sample_numpyro_nuts
) to speed up estimation. I have updated my module for version 4, and it estimates successfully with pymc.sample
. Thanks to the GitHub issue "No JAX conversion for the given `Op`: Nonzero" (Issue #6026, pymc-devs/pymc), I have solved the nonzero-function issue. However, another error now occurs, as follows:
Extracting posterior distributions for the model [m22f|y17]
Compilation time = 0:01:23.825692
TypeError Traceback (most recent call last)
Input In [16], in <cell line: 1>()
7 rdseed = qr.get_data(data_type="uint16", array_length=4)
8 with md:
----> 9 p3infdb[mdkey] = pymc.sampling_jax.sample_numpyro_nuts(
10 draws=5500,
11 # chains=4,
12 # cores=4,
13 tune=7500,
14 target_accept=0.95,
15 # max_treedepth=15,
16 model=md,
17 # random_seed=rdseed,
18 # return_inferencedata=True,
19 # pickle_backend="dill",
20 )
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/pymc/sampling_jax.py:519, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
516 if chains > 1:
517 map_seed = jax.random.split(map_seed, chains)
--> 519 pmap_numpyro.run(
520 map_seed,
521 init_params=init_params,
522 extra_fields=(
523 "num_steps",
524 "potential_energy",
525 "energy",
526 "adapt_state.step_size",
527 "accept_prob",
528 "diverging",
529 ),
530 )
532 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
534 tic3 = datetime.now()
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
--> 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
749 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
750 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
751 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
723 dense_mass = [tuple(sorted(z))] if dense_mass else []
724 assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
729 step_size=self._step_size,
730 num_steps=self._num_steps,
731 inverse_mass_matrix=inverse_mass_matrix,
732 adapt_step_size=self._adapt_step_size,
733 adapt_mass_matrix=self._adapt_mass_matrix,
734 dense_mass=dense_mass,
735 target_accept_prob=self._target_accept_prob,
736 trajectory_length=self._trajectory_length,
737 max_tree_depth=self._max_tree_depth,
738 find_heuristic_step_size=self._find_heuristic_step_size,
739 forward_mode_differentiation=self._forward_mode_differentiation,
740 regularize_mass_matrix=self._regularize_mass_matrix,
741 model_args=model_args,
742 model_kwargs=model_kwargs,
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
746 init_state = hmc_init_fn(init_params, rng_key)
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r
325 )
326 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
270 """
271 :param z: Position of the particle.
272 :param r: Momentum of the particle.
275 :return: initial state for the integrator.
276 """
277 if potential_energy is None or z_grad is None:
--> 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )
281 return IntegratorState(z, r, potential_energy, z_grad)
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
--> 246 return value_and_grad(f)(x)
[... skipping hidden 8 frame]
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
108 def logp_fn_wrap(x):
--> 109 return logp_fn(*x)[0]
File /var/folders/z5/4j3nhw7s3w5dg0yd99hjcdh80000gn/T/tmplkadkqrc:1592, in jax_funcified_fgraph(_interval_19, _interval_18, _interval_17, _interval_16, _interval_15, _interval_14, _log_36, _log_35, _log_34, _log_33, _log_32, _log_31, _log_30, _log_29, _log_28, _log_27, _log_26, _log_25, _log_24, _log_23, _log_22, _log_21, _log_20, _log_19, _log_18, _log_17, _log_16, _log_15, _log_14, _log_13, _log_12, _log_11, _log_10, _log_9, _log_8, _log_7, _log_6, _log_5, _log_4, _log_3, _log_2, _log_1, _log_, _83, _82, _84, _81, _85, _80, _77, _76, _78, _75, _79, _74, _71, _70, _72, _69, _73, _68, _65, _64, _66, _63, _67, _62, _59, _58, _60, _57, _61, _56, _53, _52, _54, _51, _55, _50, _interval_12, _interval_10, _interval_8, _interval_6, _interval_4, _interval_2, _interval_, _43, _44, _45, _46, _47, _48, _37, _38, _39, _40, _41, _42, _31, _32, _33, _34, _35, _36, _25, _26, _27, _28, _29, _30, _19, _20, _21, _22, _23, _24, _13, _14, _15, _16, _17, _18, _7, _8, _9, _10, _11, _12, _49, _interval_13, _interval_11, _interval_9, _interval_7, _interval_5, _interval_3, _interval_1, _, _1, _2, _3, _4, _5, _6)
1590 auto_619461 = elemwise9(auto_619460, auto_619401)
1591 # Alloc(TensorConstant{0.0}, Elemwise{add,no_inplace}.0, TensorConstant{7})
-> 1592 auto_619462 = alloc1(auto_23635, auto_619461, auto_642007)
1593 # AdvancedIncSubtensor{inplace=True, set_instead_of_inc=True}(Alloc.0, Elemwise{mul,no_inplace}.0, TensorConstant{[ True Tr..lse False]})
1594 auto_672825 = advancedincsubtensor(auto_619462, auto_619344, auto_641748)
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:271, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
270 def alloc(x, *shape):
--> 271 res = jnp.broadcast_to(x, shape)
272 return res
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/_src/numpy/util.py:375, in _broadcast_to(arr, shape)
373 if not isinstance(shape, tuple) and np.ndim(shape) == 0:
374 shape = (shape,)
--> 375 shape = core.canonicalize_shape(shape) # check that shape is concrete
376 arr_shape = np.shape(arr)
377 if core.symbolic_equal_shape(arr_shape, shape):
File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/core.py:1790, in canonicalize_shape(shape, context)
1788 except TypeError:
1789 pass
-> 1790 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>, DeviceArray(7, dtype=int64)).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
I have tried the solution suggested in the thread "Pm.sampling_jax to sample a MvNormal()", but the error still occurs. Does anyone have an idea what the problem could be?
My module for this is as follows:
#!/usr/bin/env python3
# Copyright 2022 The ZADR developer (Lee, Meongsu)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ....
# Module: PyMC model for ZADR
# ....
# Creator: Meongsu Lee, Ph.D.
# Date: 09/29/2020
# Update: 08/02/2022
import pandas as pd
import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm
from typing import Union
from itertools import product
from aesara.tensor.subtensor import set_subtensor
from zadr.distributions.zad_distribution import ZeroAdjDirichlet
from zadr.models.hyper_means import hyper_mean
from zadr.models.hyper_precs import hyper_prec
def zadr_pymcmd(
x: Union[np.ndarray, dict], y: np.ndarray, hpr_mu: dict, hpr_tau: dict, **kwarg
) -> pm.Model:
# Builds and returns a PyMC model (`ret_model`) whose likelihood is a
# ZeroAdjDirichlet over `y`. Optional "latent" and "fixed" covariates are read
# from **kwarg.
# NOTE(review): this listing appears to have lost its indentation and several
# lines when it was pasted (docstring delimiters, `else:` branches, opening and
# closing brackets, some prior constructors). The gaps are flagged below rather
# than reconstructed.
# NOTE(review): the next lines read like the body of the original docstring.
pymc model for ZADR
x: np.ndarray or dict
independent variables (or feature variables)
y: np.array
dependent (or target) variable
hpr_mu: dict
dictionary of mean hyper priors
hpr_tau: dict
dictionary of precision hyper priors
M = y.shape[1] # Total number of compositions
m = M - 1 # the dimensionality of the simplex
z = kwarg.get("latent", None) # Get the latent covariates
fixed = kwarg.get("fixed", None) # Get the fixed covariates
if isinstance(x, dict):
p = x[list(x.keys())[0]].shape[1]
# NOTE(review): as pasted, this assignment always overwrites the one above and
# would fail for a dict `x`; presumably an `else:` preceded it — confirm.
p = x.shape[1] # Number of explanatory variables
with pm.Model() as ret_model:
# Covariates
if isinstance(x, dict):
# One MutableData container per dict entry, named "X_" + key[3:].
# NOTE(review): the closing `]` and an `else:` seem to be missing before
# the single-array "X" container below.
cap_xs = [
pm.MutableData("X" + "_" + key[3:], db.copy()) for key, db in x.items()
cap_x = pm.MutableData("X", x.copy())
# Dependent variables
cap_y = pm.MutableData("Y", y.copy())
# Latent variables
if z is not None:
cap_z = pm.MutableData("Z", z.copy())
# Fixed variables
if fixed is not None:
cap_fx = pm.MutableData("FX", fixed.copy())
# Boolean row masks over Y: `rloc_nz` is true for rows with no zero entry,
# `rloc_z` for rows with at least one zero entry. The commented-out lines are
# the earlier `at.nonzero` (integer index) versions.
# rloc_nz = at.nonzero(at.all(at.neq(cap_y, 0.0), axis=-1))
# rloc_z = at.nonzero(at.any(at.eq(cap_y, 0.0), axis=-1))
rloc_nz = at.all(at.neq(cap_y, 0.0), axis=-1)
rloc_z = at.any(at.eq(cap_y, 0.0), axis=-1)
# Separate zero, non-zero components
y2 = cap_y[rloc_z].copy()
# x1/w1 are taken from the rows without zeros, x2/w2 from the rows with zeros;
# w1/w2 are the first covariate column reshaped to a column vector.
# NOTE(review): an `else:` seems to be missing between the dict variant and
# the plain-array variant below.
if isinstance(x, dict):
x1 = at.stack([cap_x[rloc_nz].copy() for cap_x in cap_xs], axis=-1)
x2 = at.stack([cap_x[rloc_z].copy() for cap_x in cap_xs], axis=-1)
w1 = x1[:, 0, 0].reshape((-1, 1)).copy()
w2 = x2[:, 0, 0].reshape((-1, 1)).copy()
x1 = cap_x[rloc_nz].copy()
x2 = cap_x[rloc_z].copy()
w1 = x1[:, 0].reshape((-1, 1)).copy()
w2 = x2[:, 0].reshape((-1, 1)).copy()
# Make ln_y1, ln_y2
# log(0) is -inf, so `at.isinf(ly2)` later identifies the zero components.
ly2 = at.log(y2)
# Hyper priors
# --------------------------------------
τβs = hyper_prec("β", hpr_tau["β"])
τγs = hyper_prec("γ", hpr_tau["γ"])
μβs = hyper_mean("β", hpr_mu["β"])
μγs = hyper_mean("γ", hpr_mu["γ"])
# Precision candidates of ζ
if z is not None:
Σ = hyper_prec("ζ", hpr_tau["ζ"])
if fixed is not None:
# τηs = 1 / pm.Exponential(
# "τηs", lam=1e-3, shape=(𝔽𝕏.shape[1].eval(),)
# )
# 1
# / pm.Gamma(
# "τη" + "{0}".format(chr(0x2080 + i)),
# alpha=0.1,
# beta=0.01,
# )
# One τη per column of the fixed covariates; chr(0x2080 + i) appends a
# Unicode subscript digit to the variable name.
# NOTE(review): the numerator, the Uniform bounds and the closing brackets of
# this expression are missing from the paste.
τηs = at.stack(
/ pm.Uniform(
"τη" + "{0}".format(chr(0x2080 + i)),
for i in range(cap_fx.get_value().shape[1])
# Priors, βⱼ = (β₀ⱼ,β₁ⱼ,...,β_{mj})
# (11/21/20) Comment if..else statements
# One scalar prior per (i, j) with mean μβs[i, j] and precision τβs[i, j].
# NOTE(review): the distribution constructor and name prefix are missing here.
β = at.stack(
+ "{0}".format(chr(0x2080 + i))
+ "{0}".format(chr(0x2080 + j)),
mu=μβs[i, j],
tau=τβs[i, j],
for j in range(m)
for i in range(p)
# Append a column of zeros along axis 1 (presumably the coefficients of the
# M-th, reference component — confirm).
β = at.concatenate([β, at.zeros((p, 1))], axis=1)
# NOTE(review): the γ prior constructor is missing from the paste.
γ = at.stack(
"γ" + "{0}".format(chr(0x2080 + i)),
for i in range(hpr_mu["γ"][0][0])
# ζ: coefficients of latent factor
if z is not None and fixed is None:
# ζ is chosen by hpr_tau["ζ"][0]: one prior when it equals "LKJ" (constructor
# missing from the paste), otherwise an MvNormal with diagonal covariance Σ².
ζ = (
shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
# shape=(ℤ.shape[1].eval(),),
if hpr_tau["ζ"][0] == "LKJ"
else pm.MvNormal(
cov=at.diag(at.pow(Σ, 2)),
shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
# shape=(ℤ.shape[1].eval(),),
# Only the diagonal of ζ is used as the latent-factor coefficient vector.
λ = at.dot(cap_z, at.diag(ζ).reshape((-1, 1)))
# λ = at.dot(ℤ, ζ.reshape((-1, 1)))
# Parameters α₋₀ and α₀ with latent factors
if isinstance(x, dict):
# Per-component linear predictors, concatenated column-wise, then
# exponentiated and row-normalised so each row of α sums to 1.
nrtn_nz = at.concatenate(
at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_nz = at.exp(nrtn_nz + λ[rloc_nz])
α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))
nrtn_z = at.concatenate(
at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_z = at.exp(nrtn_z + λ[rloc_z])
# Zero the weight of components whose observed value is zero before
# normalising, so the remaining components share the total mass.
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# NOTE(review): an `else:` (plain-array `x`) seems to be missing here.
α_neg0 = at.exp(at.dot(x1, β) + λ[rloc_nz]) / at.sum(
at.exp(at.dot(x1, β) + λ[rloc_nz]), axis=1
).reshape((-1, 1))
lc_z = at.exp(at.dot(x2, β) + λ[rloc_z])
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
elif fixed is not None and z is None:
# NOTE(review): the μη and η prior constructors are missing from the paste.
μηs = at.stack(
"μη" + "{0}".format(chr(0x2080 + i)),
for i in range(cap_fx.get_value().shape[1])
η = at.stack(
"η" + "{0}".format(chr(0x2080 + i)),
for i in range(cap_fx.get_value().shape[1])
# ο is the fixed-covariate contribution, one value per row of FX.
ο = at.dot(cap_fx, η.reshape((-1, 1)))
# Parameters α₋₀ and α₀ with fixed factors
if isinstance(x, dict):
nrtn_nz = at.concatenate(
at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_nz = at.exp(nrtn_nz + ο[rloc_nz])
α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))
nrtn_z = at.concatenate(
at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_z = at.exp(nrtn_z + ο[rloc_z])
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# NOTE(review): an `else:` (plain-array `x`) seems to be missing here.
α_neg0 = at.exp(at.dot(x1, β) + ο[rloc_nz]) / at.sum(
at.exp(at.dot(x1, β) + ο[rloc_nz]), axis=1
).reshape((-1, 1))
lc_z = at.exp(at.dot(x2, β) + ο[rloc_z])
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
elif z is not None and fixed is not None:
# Both latent (λ) and fixed (ο) contributions are added to the predictor.
ζ = (
shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
# shape=(ℤ.shape[1].eval(),),
if hpr_tau["ζ"][0] == "LKJ"
else pm.MvNormal(
cov=at.diag(at.pow(Σ, 2)),
shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
# shape=(ℤ.shape[1].eval(),),
λ = at.dot(cap_z, at.diag(ζ).reshape((-1, 1)))
# A single shared mean μη; each η has its own precision τηs[i].
μη = pm.Uniform("μη", lower=-1.0, upper=1.0)
η = at.stack(
pm.Normal("η" + "{0}".format(chr(0x2080 + i)), mu=μη, tau=τηs[i])
for i in range(cap_fx.get_value().shape[1])
ο = at.dot(cap_fx, η.reshape((-1, 1)))
# Parameters α₋₀ and α₀ with fixed factors
if isinstance(x, dict):
nrtn_nz = at.concatenate(
at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_nz = at.exp(nrtn_nz + λ[rloc_nz] + ο[rloc_nz])
α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))
nrtn_z = at.concatenate(
at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
# NOTE(review): this zero-row predictor indexes λ with `rloc_nz`, while ο uses
# `rloc_z` and every parallel branch uses `λ[rloc_z]`; it looks like a
# copy-paste slip that mixes rows (and row counts) — confirm and fix.
lc_z = at.exp(nrtn_z + λ[rloc_nz] + ο[rloc_z])
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# NOTE(review): an `else:` (plain-array `x`) seems to be missing here.
α_neg0 = at.exp(at.dot(x1, β) + λ[rloc_nz] + ο[rloc_nz]) / at.sum(
at.exp(at.dot(x1, β) + λ[rloc_nz] + ο[rloc_nz]), axis=1
).reshape((-1, 1))
lc_z = at.exp(at.dot(x2, β) + λ[rloc_z] + ο[rloc_z])
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# NOTE(review): presumably the final `else:` of the chain (neither latent nor
# fixed covariates) was here — confirm.
# Parameters α₋₀ and α₀ without latent factors
if isinstance(x, dict):
nrtn_nz = at.concatenate(
at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_nz = at.exp(nrtn_nz)
α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))
nrtn_z = at.concatenate(
at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
(-1, 1)
for i in range(M)
lc_z = at.exp(nrtn_z)
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# NOTE(review): an `else:` (plain-array `x`) seems to be missing here.
α_neg0 = at.exp(at.dot(x1, β)) / at.sum(
at.exp(at.dot(x1, β)), axis=1
).reshape((-1, 1))
lc_z = at.exp(at.dot(x2, β))
alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
# Parameters ϕ
# ϕ = exp(w · γ), computed separately for non-zero rows (w1) and zero rows (w2).
ϕ_neg0 = at.exp(at.dot(w1, γ.reshape((-1, 1))))
ϕ_0 = at.exp(at.dot(w2, γ.reshape((-1, 1))))
# Grouped concentration parameters
a_0 = α_0 * ϕ_0
a_neg0 = α_neg0 * ϕ_neg0
# Unified concentration parameters
# The two row groups are scattered back into one matrix in the row order of Y.
# NOTE(review): the traceback in this post fails at
# `Alloc(TensorConstant{0.0}, Elemwise{add}, TensorConstant{7})` followed by an
# `AdvancedIncSubtensor{set_instead_of_inc=True}` — the same pattern as the
# zeros(...)/set_subtensor sequence below. The row count here is the sum of two
# symbolic shapes of mask-indexed tensors, which JAX cannot use as a concrete
# shape under tracing. Presumably a static row count (e.g. `y.shape[0]` from the
# NumPy input) or `at.zeros_like(cap_y)` would avoid it — confirm.
a = at.zeros((a_0.shape[0] + a_neg0.shape[0], a_0.shape[1]))
a = set_subtensor(a[rloc_nz], a_neg0)
a = set_subtensor(a[rloc_z], a_0)
# Likelihood
# {"nz": rloc_nz, "z": rloc_z}
obs = ZeroAdjDirichlet("obs", a, observed=cap_y)
return ret_model
Sorry for the length of the module. The model and distribution are still under development, so the code may be messy.