Constructing a jaxified transform function

Hey! I’m trying to generate a jaxified log probability for the constrained variables, along with a forward transform (I may try to jaxify the backward transform and Jacobian if it’s easy!). So far I have the following for an eight-schools problem:

import numpy as np
import pymc as pm
import pymc.sampling.jax as pm_jax  # provides get_jaxified_logp / get_jaxified_graph

treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

with pm.Model() as model:
  avg_effect = pm.Normal('avg_effect', 0., 10.)
  avg_stddev = pm.HalfNormal('avg_stddev', 10.)
  school_effects = pm.Normal('school_effects', shape=8)
  pm.Normal('observed',
            avg_effect + avg_stddev * school_effects,
            treatment_stddevs,
            observed=treatment_effects)
            
logp = pm_jax.get_jaxified_logp(
    model=pm.model.transform.conditioning.remove_value_transforms(model))

So far so good! I can evaluate and take gradients for this logp!
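For reference, a minimal sketch of how I’m evaluating it (the test point, and the assumption that the jaxified logp takes a list of value arrays in the order the free variables were defined, are mine):

import jax
import jax.numpy as jnp

# One array per value variable: avg_effect, avg_stddev, school_effects.
# Transforms were removed, so avg_stddev has to be strictly positive here.
point = [jnp.array(0.), jnp.array(1.), jnp.zeros(8)]
print(logp(point))            # scalar log probability
print(jax.grad(logp)(point))  # list of gradients, one per value variable

Next, my attempt at the forward transform: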

def get_forward_transform(model):
  def identity(x):
    return x

  # PyMC's `backward` maps unconstrained -> constrained,
  # which is what I'm calling "forward" here.
  transforms = {k: identity if v is None else v.backward
                for k, v in model.rvs_to_transforms.items()}

  def forward_transform(pt):
    return [transforms[k](v) for k, v in zip(model.free_RVs, pt)]

  return forward_transform

forward_transform = get_forward_transform(model)
jax_transform = pm_jax.get_jaxified_graph(
    inputs=model.free_RVs,
    outputs=forward_transform(model.free_RVs))

This throws:

ValueError: Graph contains shared RandomType variables which cannot be safely replaced

Any tips on computing this value?

Update: got it, thanks to @ricardoV94. I needed to use the value_vars from the model with the value transforms removed: the free_RVs graphs still carry the shared RandomType variables, whereas the value variables are plain inputs that can be replaced.

def get_jaxified_outputs(model):
  uc_model = pm.model.transform.conditioning.remove_value_transforms(model)

  logp = pm_jax.get_jaxified_logp(model=uc_model)

  def identity(x):
    return x

  # We have different ideas of forward and backward!
  # PyMC's `backward` maps unconstrained -> constrained.
  fwd_transforms = {k.name: identity if v is None else v.backward
                    for k, v in model.rvs_to_transforms.items()}

  def forward_transform(pt):
    return [fwd_transforms[k.name](v) for k, v in zip(uc_model.value_vars, pt)]

  # ... and PyMC's `forward` maps constrained -> unconstrained.
  bwd_transforms = {k.name: identity if v is None else v.forward
                    for k, v in model.rvs_to_transforms.items()}

  def backward_transform(pt):
    return [bwd_transforms[k.name](v) for k, v in zip(uc_model.value_vars, pt)]

  def zero(x):
    return 0.

  ildjs = {k.name: zero if v is None else v.log_jac_det
           for k, v in model.rvs_to_transforms.items()}

  def ildj(pt):
    # Negated sum of the per-variable log Jacobian determinants.
    return -pm.math.sum(
        [pm.math.sum(ildjs[k.name](v)) for k, v in zip(uc_model.value_vars, pt)])

  fwd = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars,
                                  outputs=forward_transform(uc_model.value_vars))
  bwd = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars,
                                  outputs=backward_transform(uc_model.value_vars))
  ildj_graph = pm_jax.get_jaxified_graph(inputs=uc_model.value_vars,
                                         outputs=[ildj(uc_model.value_vars)])

  def ildj_wrap(args):
    return ildj_graph(*args)[0]

  return logp, fwd, bwd, ildj_wrap

  
logp, fwd, bwd, ildj = get_jaxified_outputs(model)
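
And a quick sanity check of how the pieces fit together (a sketch only; the test point, the positional-argument convention for fwd/bwd, and the sign convention are my own assumptions):

import jax.numpy as jnp

z = [jnp.array(0.), jnp.array(0.), jnp.zeros(8)]  # point in unconstrained space
x = fwd(*z)                                       # constrained values (jaxified graphs take positional args)
z_back = bwd(*x)                                  # should round-trip back to z
print(logp(list(x)), ildj(z))

If I’ve got the conventions right, the unconstrained-space log density should then be logp(fwd(z)) - ildj(z).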