Getting "TypeError: Shapes must be 1D sequences of concrete values of integer type"

Hello!

I have developed the Bayesian Zero Adjusted Dirichlet Regression model (Tsagris and Stewart 2018) for my research using the PyMC3 version and fortunately estimated results with them and currently revising my manuscript. As the pymc has upgraded to version 4, there is an opportunity use the pymc.sampling.sample_blackjax_nuts (or pymc.sampling.sample_numpyro_nuts ) to boost estimation speed. I have updated my module to fit version 4 and success in estimating by using pymc.sample. Thanks to No JAX conversion for the given `Op`: Nonzero · Issue #6026 · pymc-devs/pymc · GitHub, I have solved the nonzero function issue. However, there is another error happened as follows:

Extracting posterior distributions for the model [m22f|y17]

Compiling...
Compilation time =  0:01:23.825692
Sampling...
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [16], in <cell line: 1>()
      7 rdseed = qr.get_data(data_type="uint16", array_length=4)
      8 with md:
----> 9     p3infdb[mdkey] = pymc.sampling_jax.sample_numpyro_nuts(
     10         draws=5500,
     11         # chains=4,
     12         # cores=4,
     13         tune=7500,
     14         target_accept=0.95,
     15         # max_treedepth=15,
     16         model=md,
     17         # random_seed=rdseed,
     18         # return_inferencedata=True,
     19         # pickle_backend="dill",
     20     )

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/pymc/sampling_jax.py:519, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    516 if chains > 1:
    517     map_seed = jax.random.split(map_seed, chains)
--> 519 pmap_numpyro.run(
    520     map_seed,
    521     init_params=init_params,
    522     extra_fields=(
    523         "num_steps",
    524         "potential_energy",
    525         "energy",
    526         "adapt_state.step_size",
    527         "accept_prob",
    528         "diverging",
    529     ),
    530 )
    532 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    534 tic3 = datetime.now()

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
   (...)
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
--> 746     init_state = hmc_init_fn(init_params, rng_key)
    747 else:
    748     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    749     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    750     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    751     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    723         dense_mass = [tuple(sorted(z))] if dense_mass else []
    724     assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    729     step_size=self._step_size,
    730     num_steps=self._num_steps,
    731     inverse_mass_matrix=inverse_mass_matrix,
    732     adapt_step_size=self._adapt_step_size,
    733     adapt_mass_matrix=self._adapt_mass_matrix,
    734     dense_mass=dense_mass,
    735     target_accept_prob=self._target_accept_prob,
    736     trajectory_length=self._trajectory_length,
    737     max_tree_depth=self._max_tree_depth,
    738     find_heuristic_step_size=self._find_heuristic_step_size,
    739     forward_mode_differentiation=self._forward_mode_differentiation,
    740     regularize_mass_matrix=self._regularize_mass_matrix,
    741     model_args=model_args,
    742     model_kwargs=model_kwargs,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
    746     init_state = hmc_init_fn(init_params, rng_key)

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    323 energy = vv_state.potential_energy + kinetic_fn(
    324     wa_state.inverse_mass_matrix, vv_state.r
    325 )
    326 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    270 """
    271 :param z: Position of the particle.
    272 :param r: Momentum of the particle.
   (...)
    275 :return: initial state for the integrator.
    276 """
    277 if potential_energy is None or z_grad is None:
--> 278     potential_energy, z_grad = _value_and_grad(
    279         potential_fn, z, forward_mode_differentiation
    280     )
    281 return IntegratorState(z, r, potential_energy, z_grad)

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 8 frame]

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    108 def logp_fn_wrap(x):
--> 109     return logp_fn(*x)[0]

File /var/folders/z5/4j3nhw7s3w5dg0yd99hjcdh80000gn/T/tmplkadkqrc:1592, in jax_funcified_fgraph(_interval_19, _interval_18, _interval_17, _interval_16, _interval_15, _interval_14, _log_36, _log_35, _log_34, _log_33, _log_32, _log_31, _log_30, _log_29, _log_28, _log_27, _log_26, _log_25, _log_24, _log_23, _log_22, _log_21, _log_20, _log_19, _log_18, _log_17, _log_16, _log_15, _log_14, _log_13, _log_12, _log_11, _log_10, _log_9, _log_8, _log_7, _log_6, _log_5, _log_4, _log_3, _log_2, _log_1, _log_, _83, _82, _84, _81, _85, _80, _77, _76, _78, _75, _79, _74, _71, _70, _72, _69, _73, _68, _65, _64, _66, _63, _67, _62, _59, _58, _60, _57, _61, _56, _53, _52, _54, _51, _55, _50, _interval_12, _interval_10, _interval_8, _interval_6, _interval_4, _interval_2, _interval_, _43, _44, _45, _46, _47, _48, _37, _38, _39, _40, _41, _42, _31, _32, _33, _34, _35, _36, _25, _26, _27, _28, _29, _30, _19, _20, _21, _22, _23, _24, _13, _14, _15, _16, _17, _18, _7, _8, _9, _10, _11, _12, _49, _interval_13, _interval_11, _interval_9, _interval_7, _interval_5, _interval_3, _interval_1, _, _1, _2, _3, _4, _5, _6)
   1590 auto_619461 = elemwise9(auto_619460, auto_619401)
   1591 # Alloc(TensorConstant{0.0}, Elemwise{add,no_inplace}.0, TensorConstant{7})
-> 1592 auto_619462 = alloc1(auto_23635, auto_619461, auto_642007)
   1593 # AdvancedIncSubtensor{inplace=True,  set_instead_of_inc=True}(Alloc.0, Elemwise{mul,no_inplace}.0, TensorConstant{[ True  Tr..lse False]})
   1594 auto_672825 = advancedincsubtensor(auto_619462, auto_619344, auto_641748)

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:271, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
    270 def alloc(x, *shape):
--> 271     res = jnp.broadcast_to(x, shape)
    272     return res

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/_src/numpy/util.py:375, in _broadcast_to(arr, shape)
    373 if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    374   shape = (shape,)
--> 375 shape = core.canonicalize_shape(shape)  # check that shape is concrete
    376 arr_shape = np.shape(arr)
    377 if core.symbolic_equal_shape(arr_shape, shape):

File ~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/core.py:1790, in canonicalize_shape(shape, context)
   1788 except TypeError:
   1789   pass
-> 1790 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>, DeviceArray(7, dtype=int64)).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

I have tried a solution where suggested at Pm.sampling_jax to sample a MvNormal(), and still, the error message occurred. Is there any idea what could be my problem?

My module for this is as follows:

#!/usr/bin/env python3
# Copyright 2022 The ZADR developer (Lee, Meongsu)
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the license.
#
#  ....
#  Module: PyMC model for ZADR
#  ....
#  Creator: Meongsu Lee, Ph.D.
#  Date:    09/29/2020
#  Update:  08/02/2022
#

import pandas as pd
import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm
from typing import Union
from itertools import product
from aesara.tensor.subtensor import set_subtensor
from zadr.distributions.zad_distribution import ZeroAdjDirichlet
from zadr.models.hyper_means import hyper_mean
from zadr.models.hyper_precs import hyper_prec


def zadr_pymcmd(
    x: Union[np.ndarray, dict], y: np.ndarray, hpr_mu: dict, hpr_tau: dict, **kwarg
) -> pm.Model:
    """
    zadr_pymcmd:
    pymc model for ZADR

    Parameters
    ----------
    x: np.ndarray or dict
       independent variables (or feature variables)
    y: np.array
       dependent (or target) variable
    hpr_mu: dict
       dictionary of mean hyper priors
    hpr_tau: dict
       dictionary of precision hyper priors
    """

    M = y.shape[1]  # Total number of compositions
    m = M - 1  # the dimensionality of the simplex
    z = kwarg.get("latent", None)  # Get the latent covariates
    fixed = kwarg.get("fixed", None)  # Get the fixed covariates

    if isinstance(x, dict):
        p = x[list(x.keys())[0]].shape[1]
    else:
        p = x.shape[1]  # Number of explanatory variables

    with pm.Model() as ret_model:

        # Covariates
        if isinstance(x, dict):
            cap_xs = [
                pm.MutableData("X" + "_" + key[3:], db.copy()) for key, db in x.items()
            ]
        else:
            cap_x = pm.MutableData("X", x.copy())

        # Dependent variables
        cap_y = pm.MutableData("Y", y.copy())

        # Latent variables
        if z is not None:
            cap_z = pm.MutableData("Z", z.copy())
        else:
            pass

        # Fixed variables
        if fixed is not None:
            cap_fx = pm.MutableData("FX", fixed.copy())
        else:
            pass

        # rloc_nz = at.nonzero(at.all(at.neq(cap_y, 0.0), axis=-1))
        # rloc_z = at.nonzero(at.any(at.eq(cap_y, 0.0), axis=-1))
        rloc_nz = at.all(at.neq(cap_y, 0.0), axis=-1)
        rloc_z = at.any(at.eq(cap_y, 0.0), axis=-1)

        # Separate zero, non-zero components
        y2 = cap_y[rloc_z].copy()

        if isinstance(x, dict):
            x1 = at.stack([cap_x[rloc_nz].copy() for cap_x in cap_xs], axis=-1)
            x2 = at.stack([cap_x[rloc_z].copy() for cap_x in cap_xs], axis=-1)
            w1 = x1[:, 0, 0].reshape((-1, 1)).copy()
            w2 = x2[:, 0, 0].reshape((-1, 1)).copy()

        else:
            x1 = cap_x[rloc_nz].copy()
            x2 = cap_x[rloc_z].copy()
            w1 = x1[:, 0].reshape((-1, 1)).copy()
            w2 = x2[:, 0].reshape((-1, 1)).copy()

        # Make ln_y1, ln_y2
        ly2 = at.log(y2)

        # Hyper priors
        # --------------------------------------
        #
        τβs = hyper_prec("β", hpr_tau["β"])
        τγs = hyper_prec("γ", hpr_tau["γ"])

        μβs = hyper_mean("β", hpr_mu["β"])
        μγs = hyper_mean("γ", hpr_mu["γ"])

        # Precision candidates of ζ
        if z is not None:
            Σ = hyper_prec("ζ", hpr_tau["ζ"])
        else:
            pass

        if fixed is not None:
            # τηs = 1 / pm.Exponential(
            #     "τηs", lam=1e-3, shape=(𝔽𝕏.shape[1].eval(),)
            # )
            # 1
            # / pm.Gamma(
            #     "τη" + "{0}".format(chr(0x2080 + i)),
            #     alpha=0.1,
            #     beta=0.01,
            #   )

            τηs = at.stack(
                [
                    1
                    / pm.Uniform(
                        "τη" + "{0}".format(chr(0x2080 + i)),
                        lower=0,
                        upper=1000,
                    )
                    for i in range(cap_fx.get_value().shape[1])
                ]
            )

        # Priors, βⱼ = (β₀ⱼ,β₁ⱼ,...,β_{mj})
        # (11/21/20) Comment if..else statements
        β = at.stack(
            [
                at.stack(
                    [
                        pm.Normal(
                            "β"
                            + "{0}".format(chr(0x2080 + i))
                            + "{0}".format(chr(0x2080 + j)),
                            mu=μβs[i, j],
                            tau=τβs[i, j],
                        )
                        for j in range(m)
                    ]
                )
                for i in range(p)
            ]
        )
        β = at.concatenate([β, at.zeros((p, 1))], axis=1)

        γ = at.stack(
            [
                pm.Normal(
                    "γ" + "{0}".format(chr(0x2080 + i)),
                    mu=μγs[i],
                    tau=τγs[i],
                )
                for i in range(hpr_mu["γ"][0][0])
            ]
        )

        # ζ: coefficients of latent factor
        if z is not None and fixed is None:
            ζ = (
                pm.MvNormal(
                    "ζ",
                    mu=at.zeros((cap_z.get_value().shape[1],)),
                    chol=Σ[0],
                    shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
                    # shape=(ℤ.shape[1].eval(),),
                )
                if hpr_tau["ζ"][0] == "LKJ"
                else pm.MvNormal(
                    "ζ",
                    mu=at.zeros((cap_z.get_value().shape[1],)),
                    cov=at.diag(at.pow(Σ, 2)),
                    shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
                    # shape=(ℤ.shape[1].eval(),),
                )
            )

            λ = at.dot(cap_z, at.diag(ζ).reshape((-1, 1)))
            # λ = at.dot(ℤ, ζ.reshape((-1, 1)))

            # Parameters α₋₀ and α₀ with latent factors
            if isinstance(x, dict):
                nrtn_nz = at.concatenate(
                    [
                        at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_nz = at.exp(nrtn_nz + λ[rloc_nz])
                α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))

                nrtn_z = at.concatenate(
                    [
                        at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_z = at.exp(nrtn_z + λ[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
            else:
                α_neg0 = at.exp(at.dot(x1, β) + λ[rloc_nz]) / at.sum(
                    at.exp(at.dot(x1, β) + λ[rloc_nz]), axis=1
                ).reshape((-1, 1))
                lc_z = at.exp(at.dot(x2, β) + λ[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
        elif fixed is not None and z is None:
            μηs = at.stack(
                [
                    pm.Uniform(
                        "μη" + "{0}".format(chr(0x2080 + i)),
                        lower=-1.0,
                        upper=1.0,
                    )
                    for i in range(cap_fx.get_value().shape[1])
                ]
            )
            η = at.stack(
                [
                    pm.Normal(
                        "η" + "{0}".format(chr(0x2080 + i)),
                        mu=μηs[i],
                        tau=τηs[i],
                    )
                    for i in range(cap_fx.get_value().shape[1])
                ]
            )
            ο = at.dot(cap_fx, η.reshape((-1, 1)))

            # Parameters α₋₀ and α₀ with fixed factors
            if isinstance(x, dict):
                nrtn_nz = at.concatenate(
                    [
                        at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_nz = at.exp(nrtn_nz + ο[rloc_nz])
                α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))

                nrtn_z = at.concatenate(
                    [
                        at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_z = at.exp(nrtn_z + ο[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
            else:
                α_neg0 = at.exp(at.dot(x1, β) + ο[rloc_nz]) / at.sum(
                    at.exp(at.dot(x1, β) + ο[rloc_nz]), axis=1
                ).reshape((-1, 1))
                lc_z = at.exp(at.dot(x2, β) + ο[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
        elif z is not None and fixed is not None:
            ζ = (
                pm.MvNormal(
                    "ζ",
                    mu=at.zeros((cap_z.get_value().shape[1],)),
                    chol=Σ[0],
                    shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
                    # shape=(ℤ.shape[1].eval(),),
                )
                if hpr_tau["ζ"][0] == "LKJ"
                else pm.MvNormal(
                    "ζ",
                    mu=at.zeros((cap_z.get_value().shape[1],)),
                    cov=at.diag(at.pow(Σ, 2)),
                    shape=(cap_z.get_value().shape[1], cap_z.get_value().shape[1]),
                    # shape=(ℤ.shape[1].eval(),),
                )
            )

            λ = at.dot(cap_z, at.diag(ζ).reshape((-1, 1)))
            μη = pm.Uniform("μη", lower=-1.0, upper=1.0)
            η = at.stack(
                [
                    pm.Normal("η" + "{0}".format(chr(0x2080 + i)), mu=μη, tau=τηs[i])
                    for i in range(cap_fx.get_value().shape[1])
                ]
            )
            ο = at.dot(cap_fx, η.reshape((-1, 1)))
            # Parameters α₋₀ and α₀ with fixed factors
            if isinstance(x, dict):
                nrtn_nz = at.concatenate(
                    [
                        at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_nz = at.exp(nrtn_nz + λ[rloc_nz] + ο[rloc_nz])
                α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))

                nrtn_z = at.concatenate(
                    [
                        at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_z = at.exp(nrtn_z + λ[rloc_nz] + ο[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))
            else:
                α_neg0 = at.exp(at.dot(x1, β) + λ[rloc_nz] + ο[rloc_nz]) / at.sum(
                    at.exp(at.dot(x1, β) + λ[rloc_nz] + ο[rloc_nz]), axis=1
                ).reshape((-1, 1))
                lc_z = at.exp(at.dot(x2, β) + λ[rloc_z] + ο[rloc_z])
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))

        else:
            # Parameters α₋₀ and α₀ without latent factors
            if isinstance(x, dict):
                nrtn_nz = at.concatenate(
                    [
                        at.tensordot(x1[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )
                lc_nz = at.exp(nrtn_nz)
                α_neg0 = lc_nz / at.sum(lc_nz, axis=1).reshape((-1, 1))
                nrtn_z = at.concatenate(
                    [
                        at.tensordot(x2[:, :, i], β[:, i], axes=[[1], [0]]).reshape(
                            (-1, 1)
                        )
                        for i in range(M)
                    ],
                    axis=1,
                )

                lc_z = at.exp(nrtn_z)
                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))

            else:
                α_neg0 = at.exp(at.dot(x1, β)) / at.sum(
                    at.exp(at.dot(x1, β)), axis=1
                ).reshape((-1, 1))

                lc_z = at.exp(at.dot(x2, β))

                alpha_z = at.where(at.isinf(ly2), 0.0, lc_z)
                α_0 = alpha_z / at.sum(alpha_z, axis=1).reshape((-1, 1))

        # Parameters ϕ
        ϕ_neg0 = at.exp(at.dot(w1, γ.reshape((-1, 1))))
        ϕ_0 = at.exp(at.dot(w2, γ.reshape((-1, 1))))

        # Grouped concentration parameters
        a_0 = α_0 * ϕ_0
        a_neg0 = α_neg0 * ϕ_neg0

        # Unified concentration parameters
        a = at.zeros((a_0.shape[0] + a_neg0.shape[0], a_0.shape[1]))
        a = set_subtensor(a[rloc_nz], a_neg0)
        a = set_subtensor(a[rloc_z], a_0)

        # Likelihood
        # {"nz": rloc_nz, "z": rloc_z}
        obs = ZeroAdjDirichlet("obs", a, observed=cap_y)

    return ret_model

Sorry for the long lines of the module. The model and distribution are still under development so it could be messy.

1 Like