Recommendations for implementing differentiable sampling (reparameterization trick) in PyTensor?

Hi everyone,

I am new to PyTensor/PyMC, so apologies in advance if this is a basic question.

Background: I am trying to port a simulation-based optimization workflow from TensorFlow to PyTensor.

The Problem: I need to learn the parameters of a distribution by minimizing a discrepancy loss (e.g., MMD) via SGD. This requires computing gradients through the sampling step.

  • In TensorFlow Probability, e.g., tfd.Normal(...).sample() automatically applies the reparameterization trick, allowing gradients to flow back to the parameters.

  • In PyMC/PyTensor, my understanding is that creating a random variable breaks the gradient graph.

My Workflow: Below is a code example showing how I implement this in TensorFlow. The general idea:

  1. Define a parametric target distribution (parameters to be learned).

  2. Sample from a ground-truth distribution.

  3. Compute a discrepancy loss between the target samples and true samples.

  4. Use mini-batch SGD to update the target parameters.

My Question: What is the recommended pattern to achieve “differentiable sampling” in PyTensor/PyMC? Can I use pm.Normal (or .dist) objects directly for this, or do I need to manually implement the reparameterization trick for every distribution using raw PyTensor ops?

Any advice would be greatly appreciated!

import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

from tqdm import tqdm
from elicito.losses import MMD2

tfd = tfp.distributions

epochs=100
B=2**7  # batch size
S=200   # sample size

mmd2 = MMD2(kernel="energy")
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)
# distribution that we want to recover
true_dist = tfd.Normal(loc=2.5, scale=1.0).sample([B, S])

class Model(tf.Module):
    def __init__(self):
        # parameters
        self.mu = tf.Variable(1.0, name="mu", trainable=True)
        self.log_sigma = tf.Variable(tf.math.log(2.7), name="log_sigma", trainable=True)
    def __call__(self):
        # parametric distribution
        return tfd.Normal(loc=self.mu, scale=tf.exp(self.log_sigma)).sample([B, S])

get_dist_samples = Model()

# training loop
res = dict(mu=[], sigma=[], loss=[])
for epoch in tqdm(range(epochs)):
    with tf.GradientTape() as tape:
        # generate reparameterized samples
        dist_samples = get_dist_samples()
        # compute discrepancy loss
        loss = mmd2(true_dist, dist_samples)
    # extract trainable variables (outside the tape: only the loss needs recording)
    trainable_vars = get_dist_samples.trainable_variables
    # compute gradients of loss wrt parameters
    gradients = tape.gradient(loss, trainable_vars)
    # update parameters
    optimizer.apply_gradients(zip(gradients, trainable_vars))
    # store results as plain numpy values
    res["mu"].append(trainable_vars[0].numpy())
    res["sigma"].append(tf.exp(trainable_vars[1]).numpy())
    res["loss"].append(loss.numpy())

plt.plot(res["mu"], "-", label="mu")
plt.plot(res["sigma"], "-", label="sigma")
plt.legend()

plt.figure()
plt.plot(res["loss"], "-")

For now you need to apply the reparameterization trick manually. You may also want to weigh in on Stochastic gradients in pytensor · pymc-devs/pytensor · Discussion #1424 · GitHub


Great, thanks for the quick reply and for pointing me to the discussion thread. I will then move any follow-up ideas over to that thread instead.

In terms of how to actually do it, you just need to write your model in the re-parameterized way. For example:

import pytensor.tensor as pt
import pymc as pm

mu, sigma = pt.tensor('mu', shape=()), pt.tensor('sigma', shape=())
z = pm.Normal.dist(0, 1, shape=())  # base noise, independent of the parameters
x = mu + sigma * z                  # reparameterized sample
dx_dtheta = pt.grad(x, [mu, sigma])  # dx/dmu = 1, dx/dsigma = z

sample_fn = pm.compile([mu, sigma], dx_dtheta)
sample_fn(mu=1.1, sigma=0.8)
# [array(1.), array(-1.31518985)]  <- the second entry is the random draw of z

We did something like this in this VI refactor PR.