Hello,
I am trying to understand the behaviour of a Deterministic Variable when using sample_posterior_predictive.
In the following example, the variable out is deterministically calculated as the sum of in_1 and in_2, both of which are observed RVs.
I am a little surprised that the outputs of sample_posterior_predictive for in_1 and in_2 are not equal to the observed values, while the output for out is exactly the sum of the observations of in_1 and in_2.
Understand deterministic behaviour
# pymc3 library
import pymc3 as pm
import theano.tensor as tt
# standard libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from matplotlib.backends.backend_pdf import PdfPages
env: MKL_THREADING_LAYER=GNU
def Model_1(meas_in_1, meas_in_2, meas_out=None):
    """Build a model with two independent Normal inputs and their sum.

    Parameters
    ----------
    meas_in_1, meas_in_2 : array-like
        Observed measurements for ``in_1`` and ``in_2``.
    meas_out : array-like, optional
        Not used by the model; accepted so existing call sites keep working.

    Returns
    -------
    pm.Model
        Model with free RVs ``mu_in_1``, ``sd_in_1``, ``mu_in_2``,
        ``sd__in_2``, observed RVs ``in_1`` / ``in_2`` and the Deterministic
        ``out``.

    Notes
    -----
    ``in_1`` and ``in_2`` are *observed*, so in the model graph they stand
    for the data ``meas_in_1`` / ``meas_in_2``. ``out`` is therefore the
    fixed value ``meas_in_1 + meas_in_2``. It is identical on every
    posterior draw and in ``sample_posterior_predictive``. To get the
    predictive distribution of the sum, add the predictive draws instead:
    ``ppc['in_1'] + ppc['in_2']``.
    """
    with pm.Model() as Assy_model_1:
        mu_in_1 = pm.Normal('mu_in_1', -5., 5.)
        sigma_in_1 = pm.HalfCauchy('sd_in_1', 5.)
        mu_in_2 = pm.Normal('mu_in_2', -5., 5.)
        # The double underscore in 'sd__in_2' is kept so that existing
        # trace / posterior-predictive keys do not change.
        sigma_in_2 = pm.HalfCauchy('sd__in_2', 5.)
        in_1 = pm.Normal('in_1', mu_in_1, sigma_in_1, observed=meas_in_1)
        in_2 = pm.Normal('in_2', mu_in_2, sigma_in_2, observed=meas_in_2)
        # Sum (not a difference) of the two observed inputs; see Notes.
        out_sum = in_1 + in_2
        pm.Deterministic('out', out_sum)
    return Assy_model_1
# Synthetic measurements: in_1 uniform on [2, 3) and in_2 uniform on [5, 6),
# 100 values each. meas_out is their element-wise sum.
meas_in_1 = np.random.uniform(0., 1., size=100) + 2
meas_in_2 = np.random.uniform(0., 1., size=100) + 5
meas_out = np.add(meas_in_1, meas_in_2)
# Plot each measurement against its index. The labels and legend make the
# three series distinguishable; the original plot had neither.
_, ax = plt.subplots()
ax.plot(meas_in_1, 'bo', label='meas_in_1')
ax.plot(meas_in_2, 'rx', label='meas_in_2')
ax.plot(meas_out, 'g.', label='meas_out = meas_in_1 + meas_in_2')
ax.set_xlabel('measurement index')
ax.legend()
[<matplotlib.lines.Line2D at 0xbca6ac8>]
# Build the model from the synthetic measurements and draw posterior samples
# with PyMC3's default sampler settings. The model is passed explicitly
# instead of being entered as a context manager.
Model = Model_1(meas_in_1=meas_in_1, meas_in_2=meas_in_2, meas_out=meas_out)
trace = pm.sample(model=Model)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sd__in_2, mu_in_2, sd_in_1, mu_in_1]
Sampling 2 chains: 100%|██████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 207.30draws/s]
The acceptance probability does not match the target. It is 0.8803458879453975, but should be close to 0.8. Try to increase the number of tuning steps.
# Flatten the trace into one row per draw. The vector-valued `out` is spread
# over the columns out__0 ... out__99.
df_trace = pm.trace_to_dataframe(trace=trace)
df_trace.head(n=5)
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
mu_in_1 | mu_in_2 | sd_in_1 | sd__in_2 | out__0 | out__1 | out__2 | out__3 | out__4 | out__5 | ... | out__90 | out__91 | out__92 | out__93 | out__94 | out__95 | out__96 | out__97 | out__98 | out__99 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2.526969 | 5.478829 | 0.267955 | 0.273721 | 8.153856 | 8.14412 | 7.441531 | 7.739894 | 7.865402 | 7.746085 | ... | 7.946815 | 7.636004 | 7.322069 | 8.434482 | 7.9841 | 7.772285 | 7.913376 | 7.681261 | 7.54143 | 7.359818 |
1 | 2.518501 | 5.469738 | 0.266575 | 0.337548 | 8.153856 | 8.14412 | 7.441531 | 7.739894 | 7.865402 | 7.746085 | ... | 7.946815 | 7.636004 | 7.322069 | 8.434482 | 7.9841 | 7.772285 | 7.913376 | 7.681261 | 7.54143 | 7.359818 |
2 | 2.568457 | 5.445675 | 0.277477 | 0.251927 | 8.153856 | 8.14412 | 7.441531 | 7.739894 | 7.865402 | 7.746085 | ... | 7.946815 | 7.636004 | 7.322069 | 8.434482 | 7.9841 | 7.772285 | 7.913376 | 7.681261 | 7.54143 | 7.359818 |
3 | 2.547500 | 5.495479 | 0.275151 | 0.290518 | 8.153856 | 8.14412 | 7.441531 | 7.739894 | 7.865402 | 7.746085 | ... | 7.946815 | 7.636004 | 7.322069 | 8.434482 | 7.9841 | 7.772285 | 7.913376 | 7.681261 | 7.54143 | 7.359818 |
4 | 2.540248 | 5.434334 | 0.270955 | 0.311199 | 8.153856 | 8.14412 | 7.441531 | 7.739894 | 7.865402 | 7.746085 | ... | 7.946815 | 7.636004 | 7.322069 | 8.434482 | 7.9841 | 7.772285 | 7.913376 | 7.681261 | 7.54143 | 7.359818 |
5 rows × 104 columns
# Posterior-predictive samples for every Deterministic and basic RV.
# NOTE: `out` is a Deterministic of the *observed* in_1 / in_2, so every row
# of ppc['out'] equals meas_in_1 + meas_in_2. The predictive distribution of
# the sum is ppc['in_1'] + ppc['in_2'] instead.
ppc = pm.sample_posterior_predictive(
    trace=trace,
    model=Model,
    vars=Model.deterministics + Model.basic_RVs,
)
ppc.keys()
100%|██████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 880.97it/s]
dict_keys(['sd_in_1', 'sd__in_2', 'out', 'mu_in_1', 'sd_in_1_log__', 'mu_in_2', 'sd__in_2_log__', 'in_1', 'in_2'])
# Predictive draws of the observed RV in_1. Each is a fresh sample from
# Normal(mu_in_1, sd_in_1) for a posterior draw, so the values scatter around
# the data instead of reproducing meas_in_1 exactly.
ppc['in_1']
array([[2.3046228 , 2.51313731, 2.63050388, ..., 2.12266168, 2.51652621,
2.23607623],
[2.81470368, 2.107067 , 2.17810206, ..., 2.72700908, 2.37048934,
2.62947594],
[2.29541819, 2.15975066, 2.23415919, ..., 2.43940786, 3.0174447 ,
2.77409926],
...,
[2.44175511, 2.30148973, 2.76123438, ..., 3.04548749, 2.59122029,
2.56332023],
[2.23192466, 2.61834592, 2.16864666, ..., 2.37697546, 2.22281658,
2.64598264],
[2.66647161, 2.5922423 , 1.85761683, ..., 2.53917296, 2.26989612,
2.28435433]])
# Element-wise sum of the predictive draws: the posterior-predictive
# distribution of in_1 + in_2. Unlike ppc['out'], it varies between rows.
np.add(ppc['in_1'], ppc['in_2'])
array([[7.67497426, 8.02139751, 8.10059517, ..., 7.75776431, 7.62839001,
7.77139687],
[8.51122304, 7.25616768, 7.91580245, ..., 8.29146957, 7.88926232,
8.22490669],
[7.45079876, 7.73870336, 7.35301301, ..., 7.87123766, 8.13027904,
8.21990877],
...,
[7.85121949, 7.37005349, 8.31343194, ..., 8.59174903, 7.76359936,
7.81540863],
[7.44235084, 7.6824003 , 8.05634915, ..., 7.53345239, 7.91250437,
8.85873824],
[8.12081644, 8.14599994, 7.42489894, ..., 8.11697222, 8.26146199,
7.58653474]])
# The Deterministic `out` evaluated during posterior-predictive sampling.
# It is built from the observed in_1 / in_2, so every row is the same
# observed sum (compare with meas_in_1 + meas_in_2 below).
ppc['out']
array([[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789],
[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789],
[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789],
...,
[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789],
[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789],
[8.15385585, 8.1441204 , 7.4415312 , ..., 7.68126088, 7.54142967,
7.35981789]])
# Observed sum of the two measurement vectors; it matches each row of
# ppc['out'] shown above.
np.add(meas_in_1, meas_in_2)
array([8.15385585, 8.1441204 , 7.4415312 , 7.73989375, 7.8654016 ,
7.74608471, 8.48040112, 8.81317155, 7.77495918, 8.49819 ,
8.18345198, 8.21189372, 7.98008477, 8.29767972, 8.67260615,
8.12818269, 8.63885564, 8.03442404, 8.11057144, 8.26208294,
8.36144186, 8.71137184, 8.32990704, 7.96227166, 8.10491104,
8.07091849, 7.20086461, 8.28202971, 8.25363121, 7.66939645,
8.08572065, 7.97524454, 8.39970787, 8.95612655, 8.03883754,
7.86347917, 8.33435643, 7.94535083, 7.40095602, 8.05540814,
7.80407147, 7.99487119, 8.01619831, 8.41293399, 7.69228935,
7.91484023, 8.18598319, 7.96839924, 7.54092267, 7.95161861,
8.362282 , 8.0168221 , 7.9195567 , 7.57040455, 7.68608696,
7.20945103, 8.19924339, 7.94047761, 7.24024496, 7.90641995,
7.77363673, 7.7329568 , 7.31899988, 8.17895453, 8.02107085,
7.9520007 , 7.97263951, 8.08014528, 8.64726758, 7.4888866 ,
7.87502672, 8.1399327 , 8.37915103, 8.24066968, 8.48276241,
8.08999262, 7.99412051, 7.41187795, 8.02883382, 8.3499598 ,
8.52884763, 7.87123229, 8.1308587 , 7.90066182, 8.0634844 ,
7.60986783, 7.16097218, 8.30135231, 8.26288221, 8.53414287,
7.94681548, 7.63600402, 7.32206889, 8.43448178, 7.98409968,
7.77228459, 7.91337593, 7.68126088, 7.54142967, 7.35981789])