Hi everyone,
I am new to PyTensor/PyMC, so apologies in advance if this is a basic question.
Background: I am trying to port a simulation-based optimization workflow from TensorFlow to PyTensor.
The Problem: I need to learn the parameters of a distribution by minimizing a discrepancy loss (e.g., MMD) via SGD. This requires computing gradients through the sampling step.
- In TensorFlow Probability, e.g., tfd.Normal(...).sample() automatically applies the reparameterization trick, allowing gradients to flow back to the parameters.
- In PyMC/PyTensor, my understanding is that creating a random variable breaks the gradient graph.
My Workflow: Below I pasted a code example showing how I would implement this in TensorFlow. The general idea:
- Define a parametric target distribution (parameters to be learned).
- Sample from a ground-truth distribution.
- Compute a discrepancy loss between the target samples and the true samples.
- Use mini-batch SGD to update the target parameters.
My Question: What is the recommended pattern to achieve “differentiable sampling” in PyTensor/PyMC? Can I use pm.Normal (or .dist) objects directly for this, or do I need to manually implement the reparameterization trick for every distribution using raw PyTensor ops?
Any advice would be greatly appreciated!
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from tqdm import tqdm
from elicito.losses import MMD2

tfd = tfp.distributions

epochs = 100
B = 2**7  # batch size
S = 200   # sample size

mmd2 = MMD2(kernel="energy")
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

# samples from the distribution that we want to recover
true_dist = tfd.Normal(loc=2.5, scale=1.0).sample([B, S])


class Model(tf.Module):
    def __init__(self):
        # parameters
        self.mu = tf.Variable(1.0, name="mu", trainable=True)
        self.log_sigma = tf.Variable(tf.math.log(2.7), name="log_sigma", trainable=True)

    def __call__(self):
        # sample from the parametric distribution
        return tfd.Normal(loc=self.mu, scale=tf.exp(self.log_sigma)).sample([B, S])


get_dist_samples = Model()

# training loop
res = dict(mu=[], sigma=[], loss=[])
for epoch in tqdm(range(epochs)):
    with tf.GradientTape() as tape:
        # generate samples
        dist_samples = get_dist_samples()
        # compute discrepancy loss
        loss = mmd2(true_dist, dist_samples)
    # extract trainable variables
    trainable_vars = get_dist_samples.trainable_variables
    # compute gradients of loss wrt parameters
    gradients = tape.gradient(loss, trainable_vars)
    # update parameters
    optimizer.apply_gradients(zip(gradients, trainable_vars))
    # store results (access attributes directly; tf.Module sorts
    # trainable_variables alphabetically, so indexing is fragile)
    res["mu"].append(get_dist_samples.mu.numpy())
    res["sigma"].append(tf.exp(get_dist_samples.log_sigma).numpy())
    res["loss"].append(loss.numpy())

plt.plot(res["mu"], "-", label="mu")
plt.plot(res["sigma"], "-", label="sigma")
plt.legend()

plt.figure()
plt.plot(res["loss"], "-")