Implementing BART with CAR prior

I’m trying to implement a BART model with a CAR prior as a spatial model. I’ve filtered a larger dataset down to ~1600 records to test it. I’m working from the PyMC example gallery notebook "Conditional Autoregressive (CAR) Models for Spatial Data" and running in a clean environment on an HPC node. The final model will ideally also have a hierarchical structure, because I have two levels of spatial data; I’m testing that part separately and posted a question in response to an existing topic here.

The aspatial model works (runtime ~6 minutes):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# X (covariate) and y (positive response) are assumed to be defined already

with pm.Model() as independent_model:
    sigma = pm.HalfCauchy("sigma", beta=10)
    beta0 = pm.Normal("beta0", mu=0.0, tau=1.0e-5)
    beta1 = pm.Normal("beta1", mu=0.0, tau=1.0e-5)
    # precision parameter of the independent random effect
    tau_ind = pm.Gamma("tau_ind", alpha=3.2761, beta=1.81)

    # independent random effect (scalar as written; shape=len(y) gives one effect per area)
    theta = pm.Normal("theta", mu=0, tau=tau_ind)
    # mean of the likelihood
    mu = pm.Deterministic("mu", pt.exp(beta0 + beta1 * X + theta))
    # likelihood of the observed data
    y_i = pm.Gamma("y_i", mu=mu, sigma=sigma, observed=y)

    # residual between the observation and the fitted mean for the CBG
    # (y - y_i would be identically zero, because y_i is the observed data)
    res = pm.Deterministic("res", y - mu)

    # sampling the model
    independent_idata = pm.sample(2000, tune=2000)

The spatial CAR model without BART also seems to run okay (runtime ~2 hours, after some updates to the priors and checking that the data are scaled well):

n_obs = epa.shape[0]

with pm.Model() as car_model:
    sigma = pm.HalfCauchy("sigma", beta=10)
    beta0 = pm.Normal("beta0", mu=0.0, sigma=10)
    beta1 = pm.Normal("beta1", mu=0.0, sigma=10)
    # precision parameter of the independent random effect
    tau_ind = pm.Gamma("tau_ind", alpha=3.2761, beta=1.81)
    # precision parameter of the spatially dependent random effects
    tau_spat = pm.Gamma("tau_spat", alpha=10.0, beta=1.0)
    # prior for alpha
    alpha = pm.Beta("alpha", alpha=2, beta=2)

    # area-specific model parameters
    # independent random effect (scalar as written; shape=n_obs gives one effect per area)
    theta = pm.Normal("theta", mu=0, tau=tau_ind)
    # spatially dependent random effect
    phi = pm.CAR("phi", mu=np.zeros(n_obs), tau=tau_spat, alpha=alpha, W=adj_matrix)
    # mean of the likelihood
    mu = pm.Deterministic("mu", pt.exp(beta0 + beta1 * X + theta + phi))

    # likelihood of the observed data
    y_i = pm.Gamma("y_i", mu=mu, sigma=sigma, observed=y)

    # sampling the model
    car_idata = pm.sample(200, tune=200)

However, I can’t get the BART version working. I’ve tried two sampling approaches. First, nutpie gives:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x147b9f48a8a0>, (csr_matrix,)
During: lowering "$454load_global.65 = global(sparse_constant: <Compressed Sparse Row sparse matrix of dtype 'float64'
	with 8754 stored elements and shape (1633, 1633)>

My understanding of the error is that numba doesn’t always play well with sparse matrices. The matrix isn’t huge in this example, but it is much larger for the full dataset.

Second, the default PyMC sampler gives:

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py:713, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    710         auto_nuts_init = False
    712 initial_points = None
--> 713 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    715 if nuts_sampler != "pymc":
    716     if not isinstance(step, NUTS):

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py:237, in assign_step_methods(model, step, methods, step_kwargs)
    229         selected = max(
    230             methods_list,
    231             key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(  # type: ignore
    232                 var, has_gradient
    233             ),
    234         )
    235         selected_steps.setdefault(selected, []).append(var)
--> 237 return instantiate_steppers(model, steps, selected_steps, step_kwargs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py:138, in instantiate_steppers(model, steps, selected_steps, step_kwargs)
    136         args = step_kwargs.get(name, {})
    137         used_keys.add(name)
--> 138         step = step_class(vars=vars, model=model, **args)
    139         steps.append(step)
    141 unused_args = set(step_kwargs).difference(used_keys)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/pymc_bart/pgbart.py:156, in PGBART.__init__(self, vars, num_particles, batch, model)
    153 self.leaves_shape = self.shape if not self.bart.separate_trees else 1
    155 if self.bart.split_prior.size == 0:
--> 156     self.alpha_vec = np.ones(self.X.shape[1])
    157 else:
    158     self.alpha_vec = self.bart.split_prior

IndexError: tuple index out of range

This is the model with BART:

with pm.Model() as bart_car_model:
    # beta, variance parameter of the independent random effect. beta0 incorporated with theta
    w = pmb.BART("w", X=X, Y=y, m=10, shape=(3, n_obs))
    # variance parameter of the spatially dependent random effects
    tau_spat = pm.Gamma("tau_spat", alpha=1.0, beta=1.0)

    # prior for alpha
    alpha = pm.Beta("alpha", alpha=2, beta=2)
    # area-specific model parameters
    # spatially dependent random effect
    phi = pm.CAR("phi", mu=np.zeros(n_obs), tau=tau_spat, alpha=alpha, W=adj_matrix)
    # mean of the likelihood
    mu = pm.Deterministic("mu", pt.exp(w[0] + w[1] * X + phi))

    # likelihood of the observed data
    y_i = pm.Gamma("y_i", mu=mu, sigma=w[2], observed=y)
    
    car_idata_bart = pm.sample()

@aloctavodia

Hi @jfhawkin,

X is expected to have 2 dimensions, with one covariate/feature per column. If your X is 1D, you will need to do something like X[:, None] or np.atleast_2d(X).T.

In earlier versions, we used to catch this earlier and return a nicer message, but it seems we removed that check at some point. I will add it again.
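A minimal sketch of both reshapes, using a small made-up array in place of `df.TotPop.values`. Either way the result is an `(n_obs, 1)` array with one covariate per column, which is the form `pmb.BART` expects for `X`:

```python
import numpy as np

# Hypothetical 1-D covariate standing in for df.TotPop.values
tot_pop = np.array([1200.0, 3400.0, 560.0, 7800.0])

# Both forms turn a 1-D vector of length n into an (n, 1) column
X_a = tot_pop[:, None] / 1000
X_b = np.atleast_2d(tot_pop).T / 1000

print(X_a.shape)  # (4, 1)
```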

Thanks, @aloctavodia.

I updated X to
X = df.TotPop.values[:, None] / 1000
which has shape (1633, 1).

I still get the error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x151e90760350>, (csr_matrix,)
During: lowering "$446load_global.65 = global(sparse_constant: <Compressed Sparse Row sparse matrix of dtype 'float64'
	with 8754 stored elements and shape (1633, 1633)>

The issue still looks to me like it’s with the spatial weight matrix, adj_matrix. It is stored as a sparse matrix, and numba fails when it tries to compile that sparse constant into the logp function.

Reviewing the examples, it looks like I shouldn’t be multiplying w[1] by X in pm.Deterministic, because X already enters through the pmb.BART() call. Is that right? I tried an adjusted model that uses BART only for the mean, as a function of X:

with pm.Model() as bart_car_model:
    # BART sum-of-trees for the log-mean as a function of X (replaces beta0 + beta1 * X)
    w = pmb.BART("w", X=X, Y=y)
    sigma = pm.HalfNormal("sigma", 5)
    # precision parameter of the spatially dependent random effects
    tau_spat = pm.Gamma("tau_spat", alpha=1.0, beta=1.0)

    # prior for alpha
    alpha = pm.Beta("alpha", alpha=2, beta=2)

    # area-specific model parameters
    # spatially dependent random effect
    phi = pm.CAR("phi", mu=np.zeros(n_obs), tau=tau_spat, alpha=alpha, W=adj_matrix)

    # mean of the likelihood
    mu = pm.Deterministic("mu", pt.exp(w + phi))

    # likelihood of the observed data
    y_i = pm.Gamma("y_i", mu=mu, sigma=sigma, observed=y)

Quick update:

Everything runs if I switch my weight matrix to a dense (non-sparse) array. The problem is that this will be too large for the full dataset: my subset is 1633x1633, but the full dataset is about 50 times larger in each dimension (expanding from 1 state to 50).
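Some back-of-the-envelope arithmetic on why the dense route breaks at full scale, assuming float64 values, int32 CSR indices, and that the number of neighbours per area stays roughly constant (so stored elements grow linearly with the number of areas, not quadratically). The 1633 and 8754 figures come from the error message above:

```python
# Rough memory estimate: dense vs CSR storage of the adjacency matrix.
# Assumptions: float64 values, int32 indices, neighbours per area roughly constant.

n_subset = 1633      # rows/cols in the one-state subset
nnz_subset = 8754    # stored elements reported in the numba error
scale = 50           # "about 50 times larger in each dimension"

n_full = n_subset * scale
nnz_full = nnz_subset * scale  # grows linearly with the number of areas

dense_bytes = n_full * n_full * 8                   # every entry stored as float64
csr_bytes = nnz_full * (8 + 4) + (n_full + 1) * 4   # data + column indices + row pointers

print(f"dense: {dense_bytes / 1e9:.1f} GB")  # dense: 53.3 GB
print(f"CSR:   {csr_bytes / 1e6:.1f} MB")    # CSR:   5.6 MB
```

Any dense intermediates the logp builds from W (for example during the CAR normalising term) would add to the dense figure.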

The sampler also estimates that it will take about 12 hours to run with this smaller subset and this simple model using nutpie.

You are right, you don’t want to multiply w by X. Could you share the entire traceback?


The full traceback using the sparse matrix is below. I’d also tried running the non-BART model with car_idata = pm.sample(200, tune=200, nuts_sampler="nutpie", chains=4, blas_cores=16) to see if I could get a speedup. This syntax doesn’t work for the version with BART because the model isn’t continuous (the same reason you can’t run BART in Stan, I assume). Looking at the nutpie source code, am I correct that it doesn’t have a multi-threading option, i.e., something like bart_car_idata = nutpie.sample(compiled_model, chains=4, blas_cores=16)? I figure that would help with the RAM issue from a non-sparse matrix because it would partition the matrix over multiple threads. Maybe there’s a numba workaround for the sparse matrix with a decorator?

/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run AdvancedSetSubtensor's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run Eigvalsh{lower=True}'s perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run StructuredDot's perform method
  warnings.warn(
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[33], line 1
----> 1 compiled_model = nutpie.compile_pymc_model(bart_car_model)
      2 bart_car_idata = nutpie.sample(compiled_model)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:395, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
    392     backend = "numba"
    394 if backend.lower() == "numba":
--> 395     return _compile_pymc_model_numba(model, **kwargs)
    396 elif backend.lower() == "jax":
    397     return _compile_pymc_model_jax(
    398         model, gradient_backend=gradient_backend, **kwargs
    399     )

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
    200 with warnings.catch_warnings():
    201     warnings.filterwarnings(
    202         "ignore",
    203         message="Cannot cache compiled function .* as it uses dynamic globals",
    204         category=numba.NumbaWarning,
    205     )
--> 207     logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
    210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
    211     n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
    212 )

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
    273 if cache:
    274     res.enable_caching()
--> 275 res.compile()
    276 return res

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
     65 cres = self._cache.load_overload(self._sig,
     66                                  self._targetdescr.target_context)
     67 if cres is None:
---> 68     cres = self._compile_uncached()
     69     self._cache.save_overload(self._sig, cres)
     70 else:

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
     79 sig = self._sig
     81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:84, in _FunctionCompiler.compile(self, args, return_type)
     82     return retval
     83 else:
---> 84     raise retval

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
     91     pass
     93 try:
---> 94     retval = self._compile_core(args, return_type)
     95 except errors.TypingError as e:
     96     self._failed_cache[key] = e

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
    104 flags = self._customize_flags(flags)
    106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
    108                               self.targetdescr.target_context,
    109                               impl,
    110                               args=args, return_type=return_type,
    111                               flags=flags, locals=self.locals,
    112                               pipeline_class=self.pipeline_class)
    113 # Check typing error if object mode is used
    114 if cres.typing_error is not None and not flags.enable_pyobject:

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    720 """Compiler entry point
    721 
    722 Parameter
   (...)
    740     compiler pipeline
    741 """
    742 pipeline = pipeline_class(typingctx, targetctx, library,
    743                           args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
    436 self.state.lifted = ()
    437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
    502 """
    503 Populate and run pipeline for bytecode input
    504 """
    505 assert self.state.func_ir is None
--> 506 return self._compile_core()

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:485, in CompilerBase._compile_core(self)
    483         self.state.status.fail_reason = e
    484         if is_final_pipeline:
--> 485             raise e
    486 else:
    487     raise CompilerError("All available pipelines exhausted")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
    470 res = None
    471 try:
--> 472     pm.run(self.state)
    473     if self.state.cr is not None:
    474         break

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
    365 msg = "Failed in %s mode pipeline (step: %s)" % \
    366     (self.pipeline_name, pass_desc)
    367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
    106 """
    107 Type inference and legalization
    108 """
    109 with fallback_context(state, 'Function "%s" failed type inference'
    110                       % (state.func_id.func_name,)):
    111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
    113         state.typingctx,
    114         state.targetctx,
    115         state.func_ir,
    116         state.args,
    117         state.return_type,
    118         state.locals,
    119         raise_errors=self._raise_errors)
    120     state.typemap = typemap
    121     # save errors in case of partial typing

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     91     infer.build_constraint()
     92     # return errors in case of partial typing
---> 93     errs = infer.propagate(raise_errors=raise_errors)
     94     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     96 return _TypingResults(typemap, restype, calltypes, errs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typeinfer.py:1091, in TypeInferer.propagate(self, raise_errors)
   1088 force_lit_args = [e for e in errors
   1089                   if isinstance(e, ForceLiteralArg)]
   1090 if not force_lit_args:
-> 1091     raise errors[0]
   1092 else:
   1093     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x15230dfa8860>, (csr_matrix,)
During: lowering "$446load_global.65 = global(sparse_constant: <Compressed Sparse Row sparse matrix of dtype 'float64'
	with 8754 stored elements and shape (1633, 1633)>
  Coords	Values
  (0, 26)	1.0
  (0, 512)	1.0
  (0, 513)	1.0
  (0, 1504)	1.0
  (0, 1547)	1.0
  (1, 200)	1.0
  (1, 483)	1.0
  (1, 536)	1.0
  (1, 1091)	1.0
  (1, 1108)	1.0
  (1, 1513)	1.0
  (1, 1518)	1.0
  (1, 1546)	1.0
  (1, 1551)	1.0
  (1, 1554)	1.0
  (2, 3)	1.0
  (2, 199)	1.0
  (2, 244)	1.0
  (2, 1622)	1.0
  (3, 2)	1.0
  (3, 11)	1.0
  (3, 199)	1.0
  (3, 249)	1.0
  (3, 1550)	1.0
  (3, 1622)	1.0
  :	:
  (1628, 1629)	1.0
  (1629, 18)	1.0
  (1629, 1456)	1.0
  (1629, 1458)	1.0
  (1629, 1487)	1.0
  (1629, 1627)	1.0
  (1629, 1628)	1.0
  (1630, 195)	1.0
  (1630, 198)	1.0
  (1630, 275)	1.0
  (1630, 276)	1.0
  (1630, 725)	1.0
  (1630, 1004)	1.0
  (1630, 1631)	1.0
  (1631, 193)	1.0
  (1631, 194)	1.0
  (1631, 195)	1.0
  (1631, 275)	1.0
  (1631, 590)	1.0
  (1631, 595)	1.0
  (1631, 596)	1.0
  (1631, 1630)	1.0
  (1632, 193)	1.0
  (1632, 196)	1.0
  (1632, 197)	1.0)" at /tmp/tmpgeij1h_b (25)
During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)

During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)


File ".conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py", line 558:
        def extract_shared(x, user_data_):
            return inner(x)
            ^

I’m now running the BART CAR model with the dense weight matrix and the PyMC sampler, using 10 warmup / 10 draws to check runtime.

It runs in similar time with the sparse matrix (which makes sense, since the PyMC documentation indicates that CAR tries to convert W to a sparse matrix). However, none of the other samplers ('nutpie', 'blackjax', or 'numpyro') works; they all raise ValueError: Model can not be sampled with NUTS alone. Your model is probably not continuous. Compiling with nutpie.compile_pymc_model(bart_car_model) and the sparse matrix gives the error noted before. With the dense matrix it runs in 25 minutes and gives these warnings from numba:

/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run AdvancedSetSubtensor's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run Eigvalsh{lower=True}'s perform method
  warnings.warn(

Sparse support is still quite limited in the numba and JAX backends at the moment.
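Until that improves, densifying W is the practical workaround at the one-state scale. A minimal sketch of a helper that does this and runs basic sanity checks first; `dense_adjacency` is a hypothetical name, and the checks (square, symmetric, zero diagonal, binary) are assumptions about what a CAR/ICAR adjacency matrix should look like rather than checks PyMC itself documents:

```python
import numpy as np
import scipy.sparse as sp


def dense_adjacency(W):
    """Return W as a dense float64 array after basic adjacency sanity checks."""
    W_dense = W.toarray() if sp.issparse(W) else np.asarray(W)
    W_dense = W_dense.astype(np.float64)
    if W_dense.ndim != 2 or W_dense.shape[0] != W_dense.shape[1]:
        raise ValueError("W must be a square 2-D matrix")
    if not np.array_equal(W_dense, W_dense.T):
        raise ValueError("W must be symmetric")
    if np.any(np.diag(W_dense) != 0):
        raise ValueError("W must have a zero diagonal")
    if not np.isin(W_dense, (0.0, 1.0)).all():
        raise ValueError("W must contain only 0s and 1s")
    return W_dense


# Hypothetical 3-area path stored as CSR
W_sparse = sp.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float))
W_dense = dense_adjacency(W_sparse)
print(W_dense.shape, W_dense.dtype)
```

Usage would then be `pm.CAR(..., W=dense_adjacency(adj_matrix))`, which only stays affordable while W is small enough to hold in RAM.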


You might consider using an ICAR to model the spatial structure. It scales much better for large datasets. I did some benchmarking a while back and the comparison was pretty striking.


The weight matrix will need to be dense. It works pretty well with numba/nutpie. Two drawbacks to be aware of:

  • Can’t estimate alpha. Fixing alpha = 1 is what gives the faster logp evaluation, because it eliminates the need for any matrix algebra.
  • Prior predictive sampling will break. ICAR doesn’t have a random draw method. Presently, the only way to draw samples from it is to run MCMC.
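To make the "no matrix algebra" point concrete, here is a minimal NumPy sketch, not PyMC's implementation. With alpha = 1, the ICAR log-density (up to constants) only needs squared differences over an edge list, and that sum equals the quadratic form φᵀ(D − W)φ, where D is the diagonal matrix of neighbour counts. A proper CAR with alpha < 1 additionally needs a log-determinant of D − αW, which appears to be where the Eigvalsh warning in the log above comes from. The 4-area `W` and `phi` values are made up:

```python
import numpy as np


def edge_list(W):
    """(node1, node2) index arrays for each undirected edge, from the lower triangle."""
    node1, node2 = np.nonzero(np.tril(W))
    return node1, node2


def icar_pairwise_term(phi, node1, node2, sigma=1.0):
    """Unnormalised ICAR term: -1/(2 sigma^2) * sum over edges of (phi_i - phi_j)^2."""
    return -0.5 / sigma**2 * np.sum((phi[node1] - phi[node2]) ** 2)


def quadratic_form_term(phi, W, sigma=1.0):
    """The same quantity written with the matrix D - W (alpha = 1)."""
    D = np.diag(W.sum(axis=1))
    return -0.5 / sigma**2 * phi @ (D - W) @ phi


# Hypothetical 4-area map: edges 0-1, 0-2, 1-2, 2-3
W = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
phi = np.array([0.3, -1.2, 0.5, 0.4])

n1, n2 = edge_list(W)
print(icar_pairwise_term(phi, n1, n2), quadratic_form_term(phi, W))
```

The edge-list version costs one indexing operation and a sum per evaluation, proportional to the number of edges, whereas the matrix form touches every entry of W.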

I think the code should be:

  with pm.Model() as bart_car_model:
      # BART sum-of-trees for the log-mean as a function of X
      w = pmb.BART("w", X=X, Y=y)
      sigma = pm.HalfNormal("sigma", 5)

      # scale parameter of the spatial effect (phi is multiplied by 1/tau_spat below)
      tau_spat = pm.Gamma("tau_spat", alpha=1.0, beta=1.0)

      # area-specific model parameters
      # spatially dependent random effect
      phi = pm.ICAR("phi", W=adj_matrix)

      # mean of the likelihood
      mu = pm.Deterministic("mu", pt.exp(w + phi*(1/tau_spat)))
  
      # likelihood of the observed data
      y_i = pm.Gamma("y_i", mu=mu, sigma=sigma, observed=y)

Thanks! I’ll give this a try. I was hoping to test alpha, but I can check for spatial autocorrelation separately with a Moran’s I statistic. I’d read that ICAR was faster for large datasets.
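For the separate check, a minimal sketch of the global Moran’s I, I = (n / S0) · Σᵢⱼ wᵢⱼ zᵢ zⱼ / Σᵢ zᵢ², where z are deviations from the mean and S0 is the sum of all weights. It assumes binary weights and would be applied to something like posterior-mean residuals; it omits the permutation-based inference that dedicated packages such as PySAL’s esda add. `W @ z` and `W.sum()` work for both a dense NumPy array and a SciPy sparse matrix, so the weight matrix does not need to be densified:

```python
import numpy as np


def morans_i(x, W):
    """Global Moran's I for values x and spatial weights W (dense array or SciPy sparse)."""
    x = np.asarray(x, dtype=float)
    z = x - x.mean()      # deviations from the mean
    s0 = W.sum()          # sum of all weights
    num = z @ (W @ z)     # sum_ij w_ij z_i z_j
    den = z @ z           # sum_i z_i^2
    return (len(x) / s0) * num / den


# Hypothetical path of 4 areas: 0 - 1 - 2 - 3
W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
x = np.array([1.0, 1.0, 0.0, 0.0])
print(morans_i(x, W))  # ≈ 1/3: neighbouring areas have similar values
```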


The ICAR specification is much faster! The only obstacle to scaling it is that the class uses the dense representation in the log-likelihood. It looks easy to change, though (see my CHANGE ME comment below).

class ICAR(Continuous):
    ...
    def logp(value, W, sigma, zero_sum_stdev):
        # convert adjacency matrix to edgelist representation
        # An edgelist is a pair of lists.
        # If node i and node j are connected then one list
        # will contain i and the other will contain j at the same
        # index value.
        # We only use the lower triangle here because adjacency
        # is an undirected connection.
        N = pt.shape(W)[-2]
        # CHANGE ME: this already constructs the pairwise representation. We would just
        # need to switch to sparse and/or check whether a sparse matrix was passed
        # (similar to the CAR class).
        node1, node2 = pt.eq(pt.tril(W), 1).nonzero()

        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(value[node1] - value[node2]))
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_stdev * N)
        )

        return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
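Along the lines of the CHANGE ME comment, a minimal sketch (not PyMC code) of extracting the same lower-triangle edge list directly from a SciPy sparse matrix, so the dense W never has to be materialised. `sparse_edge_list` is a hypothetical helper; the resulting integer arrays play the role of `node1`/`node2` in the logp above:

```python
import numpy as np
import scipy.sparse as sp


def sparse_edge_list(W):
    """Undirected edge list from a sparse adjacency matrix, without densifying.

    Keeps the strict lower triangle (k=-1); for a W with zero diagonal this
    selects the same entries as pt.tril(W) in the logp above.
    """
    lower = sp.tril(sp.csr_matrix(W), k=-1).tocoo()
    keep = lower.data == 1  # mirror the pt.eq(..., 1) test for binary weights
    return lower.row[keep], lower.col[keep]


# Hypothetical 4-area path 0 - 1 - 2 - 3 stored as CSR
rows = np.array([0, 1, 1, 2, 2, 3])
cols = np.array([1, 0, 2, 1, 3, 2])
W = sp.csr_matrix((np.ones(6), (rows, cols)), shape=(4, 4))

node1, node2 = sparse_edge_list(W)
print(list(zip(node1.tolist(), node2.tolist())))
```

Because the edge list depends only on W, it could be computed once outside the logp; memory then scales with the number of edges rather than with n².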

This is a really nice suggestion. I’m going to look into adding support for sparse weight matrices.

Do you have any sense of whether passing a dense weight matrix is slowing down your sampler? Or is the problem more that the dense weight matrix is too big to hold in RAM?

I explored a sparse route a little when building the distribution. My hunch was that it wouldn’t matter much, because once you generate the edgelist representation you are basically in sparse territory: all the computation during sampling uses the edgelist, and the weight matrix is left behind. But I could be wrong, so I’m curious what you think.

Thanks for putting together the ICAR model. Whether you specify it as sparse or non-sparse doesn’t have any appreciable effect on runtime. It just affects the ability to run the various checks on shape, etc. Everything is eventually translated into a sparse representation.