Linear regression using DensityDist

I’m playing around with the linear regression example in the PyMC3 getting started documentation to test out DensityDist. To test it, I have swapped out the original normal likelihood for a DensityDist that uses my own implementation of the normal log-density (see the code below). The problem is that arviz.plot_trace fails when I try to plot the trace after sampling; see the stack trace below.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
# import aesara.tensor as tt
import theano.tensor as tt

# %%
# %config InlineBackend.figure_format = 'retina'
# Initialize random number generator
RANDOM_SEED = 8927
np.random.seed(RANDOM_SEED)
az.style.use("arviz-darkgrid")


# %% Generate data
# True parameter values
alpha, sigma = 1, 1
beta = [1, 2.5]

# Size of dataset
size = 100

# Predictor variable
X1 = np.random.randn(size)
X2 = np.random.randn(size) * 0.2

# Simulate outcome variable
Y = alpha + beta[0] * X1 + beta[1] * X2 + np.random.randn(size) * sigma

# %% Plot the data
fig, axes = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
axes[0].scatter(X1, Y, alpha=0.6)
axes[1].scatter(X2, Y, alpha=0.6)
axes[0].set_ylabel("Y")
axes[0].set_xlabel("X1")
axes[1].set_xlabel("X2")

# %% Do inference with DensityDist
densitydist_basic_model = pm.Model()


def logp(value, mu_logp, sigma_logp):
    # Normal log-density with mean mu_logp and standard deviation sigma_logp
    return -0.5 * ((value - mu_logp) / sigma_logp) ** 2 - tt.log(sigma_logp) - 0.5 * np.log(2.0 * np.pi)


with densitydist_basic_model:
    # Priors for unknown model parameters
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10, shape=2)
    sigma = pm.HalfNormal("sigma", sigma=1)
    mu = alpha + beta[0] * X1 + beta[1] * X2
    Y_obs = pm.DensityDist("Y_obs", logp, observed=dict(value=Y, mu_logp=mu, sigma_logp=sigma))

with densitydist_basic_model:
    trace = pm.sample(500, return_inferencedata=False)
    az.plot_trace(trace)

The stack trace:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, beta, alpha]
Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 1 seconds.
Traceback (most recent call last):
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-e7a04ef7daec>", line 17, in <module>
    trace = pm.sample(500, return_inferencedata=False)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/pymc3/sampling.py", line 639, in sample
    idata = arviz.from_pymc3(trace, **ikwargs)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/arviz/data/io_pymc3_3x.py", line 580, in from_pymc3
    return PyMC3Converter(
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/arviz/data/io_pymc3_3x.py", line 181, in __init__
    self.observations, self.multi_observations = self.find_observations()
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/arviz/data/io_pymc3_3x.py", line 194, in find_observations
    multi_observations[key] = val.eval() if hasattr(val, "eval") else val
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/graph/basic.py", line 554, in eval
    self._fn_cache[inputs] = theano.function(inputs, self)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/compile/function/__init__.py", line 337, in function
    fn = pfunc(
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/compile/function/pfunc.py", line 524, in pfunc
    return orig_function(
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/compile/function/types.py", line 1970, in orig_function
    m = Maker(
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/compile/function/types.py", line 1584, in __init__
    fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/compile/function/types.py", line 188, in std_fgraph
    fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/graph/fg.py", line 162, in __init__
    self.import_var(output, reason="init")
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/graph/fg.py", line 330, in import_var
    self.import_node(var.owner, reason=reason)
  File "/home/gw/.pyenv/versions/anaconda3-2021.05/envs/pymc/lib/python3.9/site-packages/theano/graph/fg.py", line 383, in import_node
    raise MissingInputError(error_msg, variable=var)
theano.graph.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute Subtensor{int64}(beta, Constant{1}), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
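Separately from the ArviZ error, it is worth checking the parameterization of a hand-written log-density before trusting a trace. The minimal NumPy sketch below (the function names are hypothetical, and NumPy stands in for theano.tensor) compares the expression (-1 / s * (value - mu) ** 2 + log(1 / s / pi / 2)) / 2 with the textbook normal log-density. It matches when s is the variance, not the standard deviation, so passing a standard-deviation parameter such as sigma into it silently changes the model.

```python
import numpy as np


def handwritten_logp(value, mu, s):
    # The expression (-1/s * (value - mu)**2 + log(1/s/pi/2)) / 2, in NumPy
    return (-1 / s * (value - mu) ** 2 + np.log(1 / s / np.pi / 2.0)) / 2.0


def normal_logp_sd(value, mu, sd):
    # Textbook Normal(mu, sd) log-density, with sd the standard deviation
    return -0.5 * ((value - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)


x = np.linspace(-3, 3, 7)
sd = 1.7
# Matches when the third argument is the variance sd**2 ...
print(np.allclose(handwritten_logp(x, 0.5, sd**2), normal_logp_sd(x, 0.5, sd)))  # True
# ... but not when it is the standard deviation itself
print(np.allclose(handwritten_logp(x, 0.5, sd), normal_logp_sd(x, 0.5, sd)))  # False
```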

@ricardoV94 Could you take a look at this?

I am not familiar with ArviZ, where the problem seems to originate. I suggest updating ArviZ to the latest version if you are not already using it and, if that does not fix it, setting return_inferencedata=True when you call pm.sample.

I get the same problem with return_inferencedata=True, and I am already on the latest ArviZ version, 0.11.4.

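It is a known problem. The issue is that you are passing random variables (mu and sigma) in the observed dictionary, so when the trace is converted to InferenceData and ArviZ tries to populate the observed_data group, it calls .eval() on those variables. They depend on free random variables such as beta, which have no value outside of sampling, so the evaluation fails with the MissingInputError shown in your traceback.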

To fix it, pass return_inferencedata=True, idata_kwargs=dict(density_dist_obs=False) to pm.sample. This tells ArviZ not to interpret the observed argument of the DensityDist as observed_data. It is also always a good idea to update to the latest ArviZ, as Ricardo suggests (I think this has been around for a few versions, but I don’t remember exactly).
