As my first PyMC3 project, I am starting from the example "A Hierarchical model for Rugby prediction" in the PyMC3 3.11.2 documentation. At the moment, I am trying to rewrite the likelihood using DensityDist(), but I must be missing something. The model builds and samples fine, but the call fails at the very last step. This is the code to read the data:
import pandas as pd
import pymc3 as pm
df_all = pd.read_csv(pm.get_data("rugby.csv"), index_col=0)
df = df_all[["home_team", "away_team", "home_score", "away_score"]]
teams = df.home_team.unique()
teams = pd.DataFrame(teams, columns=["team"])
teams["i"] = teams.index
df = pd.merge(df, teams, left_on="home_team", right_on="team", how="left")
df = df.rename(columns={"i": "i_home"}).drop(columns="team")
df = pd.merge(df, teams, left_on="away_team", right_on="team", how="left")
df = df.rename(columns={"i": "i_away"}).drop(columns="team")
observed_home_goals = df.home_score.values
observed_away_goals = df.away_score.values
home_team = df.i_home.values
away_team = df.i_away.values
num_teams = len(df.i_home.drop_duplicates())
num_games = len(home_team)
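The merges above only attach an integer code to each team: its row position in `teams`, i.e. its order of first appearance in `home_team`, which the model then uses to index `atts` and `defs`. A minimal pure-Python sketch of the same mapping, using a few hypothetical fixtures rather than the real rugby.csv rows:

```python
# Hypothetical fixtures for illustration; the real data come from rugby.csv.
home_names = ["Wales", "France", "Italy", "England"]
away_names = ["France", "Italy", "England", "Wales"]

# Same rule as teams = df.home_team.unique(): codes follow order of first appearance.
team_index = {}
for name in home_names:
    team_index.setdefault(name, len(team_index))

home_team = [team_index[name] for name in home_names]
# Like the how="left" merge, this assumes every away team also plays at home;
# an unseen name raises KeyError here, whereas the merge would leave NaN.
away_team = [team_index[name] for name in away_names]
num_teams = len(team_index)

print(home_team, away_team, num_teams)
```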
This is the model:
import theano.tensor as tt

with pm.Model() as model:
    # global model parameters
    home = pm.Flat("home")
    sd_att = pm.HalfStudentT("sd_att", nu=3, sigma=2.5)
    sd_def = pm.HalfStudentT("sd_def", nu=3, sigma=2.5)

    # team-specific model parameters
    atts_star = pm.Normal("atts_star", mu=0, sigma=sd_att, shape=num_teams - 1)
    atts_last = -tt.sum(atts_star, keepdims=True) + 1
    atts = pm.Deterministic("atts", pm.math.concatenate([atts_star, atts_last], axis=0))
    defs = pm.Normal("defs_star", mu=0, sigma=sd_def, shape=num_teams)

    # expected home/away scores (log link)
    mu_home = tt.exp(home + atts[home_team] + defs[away_team])
    mu_away = tt.exp(atts[away_team] + defs[home_team])

    # Original likelihood
    # home_points = pm.Poisson("home_points", mu=mu_home, observed=observed_home_goals)
    # away_points = pm.Poisson("away_points", mu=mu_away, observed=observed_away_goals)

    # Re-written likelihood: independent Poisson for home and away scores
    def logp(home_goal, away_goal, mu_home, mu_away):
        factln = pm.distributions.dist_math.factln
        base = (
            home_goal * pm.math.log(mu_home) - mu_home - factln(home_goal)
            + away_goal * pm.math.log(mu_away) - mu_away - factln(away_goal)
        )
        return base.sum()

    res = pm.DensityDist(
        "res",
        logp=logp,
        observed={
            "home_goal": observed_home_goals,
            "away_goal": observed_away_goals,
            "mu_home": mu_home,
            "mu_away": mu_away,
        },
    )

    trace = pm.sample(4000, chains=3, tune=1000, return_inferencedata=True,
                      target_accept=0.85, random_seed=123456)
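For reference, the custom logp is the sum over the G = num_games games of two independent Poisson log-pmfs, which is also what the two commented-out pm.Poisson lines contribute to the model log-probability. The notation is introduced here for illustration: y^h_g and y^a_g are the observed home and away scores of game g, and h(g) and a(g) are its home_team and away_team indices.

```latex
\log p(\mathbf{y} \mid \boldsymbol{\mu})
  = \sum_{g=1}^{G} \Big[ y^{h}_{g} \log \mu^{h}_{g} - \mu^{h}_{g} - \log\big(y^{h}_{g}!\big)
  + y^{a}_{g} \log \mu^{a}_{g} - \mu^{a}_{g} - \log\big(y^{a}_{g}!\big) \Big],
\qquad
\mu^{h}_{g} = \exp\big(\text{home} + \text{att}_{h(g)} + \text{def}_{a(g)}\big),
\quad
\mu^{a}_{g} = \exp\big(\text{att}_{a(g)} + \text{def}_{h(g)}\big)
```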
Everything works when I use the original likelihood (the commented-out lines). However, with the DensityDist() specification, the sampler finishes drawing the samples and only then fails. I get the following error:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 4 jobs)
NUTS: [defs_star, atts_star, sd_def, sd_att, home]
Sampling 3 chains for 1_000 tune and 4_000 draw iterations (3_000 + 12_000 draws total) took 10 seconds.
Traceback (most recent call last):
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-4-131353e97490>", line 55, in <module>
trace = pm.sample(4000, chains=3, tune=1000, return_inferencedata=True, target_accept=0.85, random_seed=123456)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/pymc3/sampling.py", line 639, in sample
idata = arviz.from_pymc3(trace, **ikwargs)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/arviz/data/io_pymc3.py", line 563, in from_pymc3
return PyMC3Converter(
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/arviz/data/io_pymc3.py", line 171, in __init__
self.observations, self.multi_observations = self.find_observations()
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/arviz/data/io_pymc3.py", line 184, in find_observations
multi_observations[key] = val.eval() if hasattr(val, "eval") else val
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/graph/basic.py", line 554, in eval
self._fn_cache[inputs] = theano.function(inputs, self)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/compile/function/__init__.py", line 337, in function
fn = pfunc(
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/compile/function/pfunc.py", line 524, in pfunc
return orig_function(
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/compile/function/types.py", line 1970, in orig_function
m = Maker(
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/compile/function/types.py", line 1584, in __init__
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/compile/function/types.py", line 188, in std_fgraph
fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/graph/fg.py", line 162, in __init__
self.import_var(output, reason="init")
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/graph/fg.py", line 330, in import_var
self.import_node(var.owner, reason=reason)
File "/home/non1987/anaconda3/envs/football_analytics/lib/python3.8/site-packages/theano/graph/fg.py", line 383, in import_node
raise MissingInputError(error_msg, variable=var)
theano.graph.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute AdvancedSubtensor1(defs_star, TensorConstant{[4 5 3 0 5..1 3 4 0 5]}), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
As far as I understand, the logp function I wrote should match the original implementation exactly; I even included the normalization constant (the factln terms). I am having a hard time understanding the error and would appreciate any guidance.
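One way to check that the hand-written logp matches the original likelihood is to evaluate the same expression numerically and compare it with the sum of two Poisson log-pmfs computed straight from the pmf definition. A minimal sketch in plain Python (no PyMC3/Theano involved; math.lgamma(k + 1) stands in for factln(k), and the scores and rates below are made-up values, not taken from rugby.csv):

```python
import math

def poisson_logpmf(k, mu):
    # log of mu**k * exp(-mu) / k!, computed directly from the pmf
    return math.log(mu**k * math.exp(-mu) / math.factorial(k))

def custom_logp(home_goal, away_goal, mu_home, mu_away):
    # Same terms as the DensityDist logp; lgamma(k + 1) == log(k!)
    total = 0.0
    for kh, ka, mh, ma in zip(home_goal, away_goal, mu_home, mu_away):
        total += (kh * math.log(mh) - mh - math.lgamma(kh + 1)
                  + ka * math.log(ma) - ma - math.lgamma(ka + 1))
    return total

# Made-up scores and expected scores, for illustration only
home_goal = [23, 10, 31]
away_goal = [15, 20, 9]
mu_home = [21.5, 14.0, 25.2]
mu_away = [16.1, 17.3, 12.8]

reference = sum(poisson_logpmf(kh, mh) + poisson_logpmf(ka, ma)
                for kh, ka, mh, ma in zip(home_goal, away_goal, mu_home, mu_away))
print(custom_logp(home_goal, away_goal, mu_home, mu_away), reference)
```

Agreement here only shows that the two expressions are equal for these inputs; it does not exercise how DensityDist handles the entries of the observed dict.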