Constructing a jaxified transform function

Hey! I’m trying to generate a jaxified log probability for the constrained variables, along with a jaxified forward transform (I may also try to jaxify the backward transform and the Jacobian if that is easy). So far I have the following for the eight schools problem:

``````treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
[15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

with pm.Model() as model:
avg_effect = pm.Normal('avg_effect', 0., 10.)
avg_stddev = pm.HalfNormal('avg_stddev', 10.)
school_effects = pm.Normal('school_effects', shape=8)
pm.Normal('observed',
avg_effect + avg_stddev * school_effects,
treatment_stddevs,
observed=treatment_effects)

logp = pm_jax.get_jaxified_logp(
model=pm.model.transform.conditioning.remove_value_transforms(model))
``````

So far, so good! I can evaluate this `logp` and take its gradients.

``````def get_forward_transform(model):
def identity(x):
return x

transforms = {k: identity if v is None else v.backward for k, v in model.rvs_to_transforms.items()}

def forward_transform(pt):
return [transforms[k](v) for k, v in zip(model.free_RVs, pt)]

return forward_transform

jax_transform = pm_jax.get_jaxified_graph(inputs=model.free_RVs,
outputs=forward_transform(model.free_RVs))
``````

This throws:

``````ValueError: Graph contains shared RandomType variables which cannot be safely replaced
``````

Any tips on computing this value?

Update: got it, thanks to @ricardoV94! I needed to use the `value_vars` of the model returned by `remove_value_transforms` (the untransformed model) instead of `free_RVs`.

``````def get_jaxified_outputs(model):
uc_model = pm.model.transform.conditioning.remove_value_transforms(model)

logp = pm_jax.get_jaxified_logp(model=pm.model.transform.conditioning.remove_value_transforms(model))

def identity(x):
return x

# We have different ideas of forward and backward!
fwd_transforms = {k.name: identity if v is None else v.backward for k, v in model.rvs_to_transforms.items()}

def forward_transform(pt):
return [fwd_transforms[k.name](v) for k, v in zip(uc_model.value_vars, pt)]

# We have different ideas of forward and backward!
bwd_transforms = {k.name: identity if v is None else v.forward for k, v in model.rvs_to_transforms.items()}

def backward_transform(pt):
return [bwd_transforms[k.name](v) for k, v in zip(uc_model.value_vars, pt)]

def none(x):
return 0.

ildjs = {k.name: none if v is None else v.log_jac_det for k, v in model.rvs_to_transforms.items()}

def ildj(pt):
tot = 0.
return -pm.math.log(pm.math.sum([ildjs[k.name](v) for k, v in zip(uc_model.value_vars, pt)]))

fwd = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars, outputs=forward_transform(uc_model.value_vars))
bwd = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars, outputs=backward_transform(uc_model.value_vars))
ildj = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars, outputs=[ildj(uc_model.value_vars)])
def ildj_wrap(args):
return ildj(*args)[0]
return logp, fwd, bwd, ildj_wrap

logp, fwd, bwd, ildj = get_jaxified_outputs(model)
``````