Hey! I’m trying to generate a jaxified log probability for the constrained variables, along with a forward transform (I may also try to jaxify the backward transform and the Jacobian if that turns out to be easy). So far I have the following for an eight schools problem:
# Imports are not shown in the original post; the pm_jax alias is an assumption.
import numpy as np
import pymc as pm
import pymc.sampling.jax as pm_jax

treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

with pm.Model() as model:
    avg_effect = pm.Normal('avg_effect', 0., 10.)
    avg_stddev = pm.HalfNormal('avg_stddev', 10.)
    school_effects = pm.Normal('school_effects', shape=8)
    pm.Normal('observed',
              avg_effect + avg_stddev * school_effects,
              treatment_stddevs,
              observed=treatment_effects)

# Log probability in the constrained space (value transforms removed)
logp = pm_jax.get_jaxified_logp(
    model=pm.model.transform.conditioning.remove_value_transforms(model))
So far so good! I can evaluate this logp and take gradients of it!
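As a point of comparison, here is a minimal hand-written sketch of the same constrained-space log density in plain JAX, evaluated and differentiated at one point. It is not the function PyMC returns: `manual_logp` and `point` are hypothetical names, the default `Normal` prior on `school_effects` is taken to be standard normal, and the result has not been checked against the jaxified `logp` above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

treatment_effects = jnp.array([28., 8., -3., 7., -1., 1., 18., 12.])
treatment_stddevs = jnp.array([15., 10., 16., 11., 9., 11., 10., 18.])


def manual_logp(avg_effect, avg_stddev, school_effects):
    """Hand-written eight schools log density in the constrained space."""
    # avg_effect ~ Normal(0, 10)
    lp = norm.logpdf(avg_effect, 0.0, 10.0)
    # avg_stddev ~ HalfNormal(10): twice the Normal(0, 10) density on avg_stddev > 0
    lp += jnp.log(2.0) + norm.logpdf(avg_stddev, 0.0, 10.0)
    # school_effects ~ Normal(0, 1), assumed default parameters
    lp += jnp.sum(norm.logpdf(school_effects, 0.0, 1.0))
    # Likelihood of the observed treatment effects
    mean = avg_effect + avg_stddev * school_effects
    lp += jnp.sum(norm.logpdf(treatment_effects, mean, treatment_stddevs))
    return lp


# An arbitrary point in the constrained space (avg_stddev must be positive)
point = (jnp.array(1.0), jnp.array(2.0), jnp.zeros(8))

value = manual_logp(*point)
# Gradients with respect to all three arguments
grads = jax.grad(manual_logp, argnums=(0, 1, 2))(*point)
```

Here `grads` is a tuple with one entry per argument, with the same shapes as the inputs.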
def get_forward_transform(model):
    def identity(x):
        return x

    # Use each transform's backward (unconstrained -> constrained),
    # or the identity when the RV has no transform.
    transforms = {k: identity if v is None else v.backward
                  for k, v in model.rvs_to_transforms.items()}

    def forward_transform(pt):
        return [transforms[k](v) for k, v in zip(model.free_RVs, pt)]

    return forward_transform

forward_transform = get_forward_transform(model)
jax_transform = pm_jax.get_jaxified_graph(inputs=model.free_RVs,
                                          outputs=forward_transform(model.free_RVs))
This throws:
ValueError: Graph contains shared RandomType variables which cannot be safely replaced
Any tips on computing the transformed values?
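To make the target concrete: in this model the only variable with a constraint is `avg_stddev`, which PyMC puts on the log scale by default. A hand-written sketch of the map between unconstrained and constrained values, plus the log-Jacobian term, could look like the following. The function names are hypothetical and nothing here is generated by PyMC; it only shows what the jaxified transform would be expected to compute for this particular model.

```python
import jax
import jax.numpy as jnp


def to_constrained(unconstrained):
    """Unconstrained -> constrained: exponentiate the log-scale stddev."""
    avg_effect, log_avg_stddev, school_effects = unconstrained
    return avg_effect, jnp.exp(log_avg_stddev), school_effects


def to_unconstrained(constrained):
    """Constrained -> unconstrained: take the log of the stddev."""
    avg_effect, avg_stddev, school_effects = constrained
    return avg_effect, jnp.log(avg_stddev), school_effects


def log_abs_det_jacobian(unconstrained):
    """log|d constrained / d unconstrained|; only the exp term contributes."""
    _, log_avg_stddev, _ = unconstrained
    # d exp(u) / du = exp(u), so the log of the Jacobian is u itself
    return log_avg_stddev


u = (jnp.array(0.5), jnp.array(-1.0), jnp.zeros(8))
c = to_constrained(u)
roundtrip = to_unconstrained(c)
log_jac = log_abs_det_jacobian(u)
```

Adding `log_jac` to the constrained-space log density evaluated at `c` gives the log density in the unconstrained space.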