How to use Neural Networks in PyMC

Hi all,

I am trying to build a neural network with Flax, referring to the following article.

The computation runs, but the parameter values do not change at all across draws.
If anyone knows the cause, could you please tell me how to fix it?

I would be thankful if you could help me with this question.

import matplotlib.pyplot as plt
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify

import jax
import jax.numpy as jnp

import pymc as pm
import pymc.sampling.jax

import tensorflow_datasets as tfds

from flax import linen as nn
from flax.core import freeze

def get_datasets():
    """Load MNIST train and test datasets into memory."""

    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = np.float32(train_ds['image']) / 255.
    test_ds['image'] = np.float32(test_ds['image']) / 255.
    return train_ds, test_ds

train, test = get_datasets()
train["image"] = train["image"][:1_000]
train["label"] = train["label"][:1_000]
train["image"].shape
plt.imshow(train["image"][0])
plt.title(train["label"][0]);

class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        # Convolution layer
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Convolution layer
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Dense layer
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)

        # Output layer
        x = nn.Dense(features=10)(x)

        return x

cnn = CNN()

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))

n_params = 0
for layer in params["params"].values():
    n_params += layer["kernel"].size + layer.get("bias", np.array(())).size
n_params
cnn.apply(params, train["image"][0:1])
treedef = jax.tree_util.tree_structure(params)

def cnn_op_jax(flat_params, images):
    # Rebuild the Flax parameter pytree from the flat list, then run the forward pass
    unflat_params = jax.tree_util.tree_unflatten(treedef, flat_params)
    return cnn.apply(unflat_params, images)

jitted_cnn_op_jax = jax.jit(cnn_op_jax)

class CNNOp(Op):
    def make_node(self, *inputs):
        # Convert our inputs to symbolic variables
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # Assume the output to always be a float64 matrix
        outputs = [pt.matrix(dtype="float64")]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # Evaluate the jitted Flax forward pass on concrete numpy inputs
        *flat_params, images = inputs
        result = jitted_cnn_op_jax(flat_params, images)
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        # No gradient implemented, so only gradient-free or JAX-based samplers work
        raise NotImplementedError("PyTensor gradient of CNNOp not implemented")

@jax_funcify.register(CNNOp)
def cnn_op_jax_funcify(op, **kwargs):
    def perform(*inputs):
        *flat_params, images = inputs
        return cnn_op_jax(flat_params, images)
    return perform

cnn_op = CNNOp()
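
As a quick sanity check, the wrapped Op can be evaluated eagerly on a few images (this assumes the flat parameter order from tree_leaves matches treedef; the check itself is just illustrative):

flat_params = [np.asarray(p) for p in jax.tree_util.tree_leaves(params)]
cnn_op(*flat_params, train["image"][:3]).eval()  # expect a (3, 10) float64 array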

with pm.Model() as model:
    images = pm.Data("images", train["image"], mutable=True)
    #print(images.shape.eval()[0])

    weights_prior = []
    for layer_name, layer in params["params"].items():
        for layer_weights_name, layer_weights in sorted(layer.items()):
            prior_name = f"{layer_name}_{layer_weights_name}"
            layer_weights_prior = pm.Normal(prior_name, 0, 1, shape=layer_weights.shape)
            weights_prior.append(layer_weights_prior)

    logitp_classes = cnn_op(*weights_prior, images)
    logitp_classes = pt.specify_shape(logitp_classes, (images.shape.eval()[0], 10))
    label = pm.Categorical("label", logit_p=logitp_classes, observed=train["label"])

pm.model_to_graphviz(model)
with model:
    idata = pm.sampling.jax.sample_numpyro_nuts(draws=100, tune=100, chains=1) # pm.sample(500, tune=500, chains=3, cores=2, nuts_sampler="numpyro")

print(idata.posterior.Dense_0_bias)
array([[[-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245],
        [-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245],
        [-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245],
        ...,
        [-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245],
        [-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245],
        [-0.7735232 , -0.12957669, -0.61163926, ...,  0.8304822 ,
         -0.25725549, -0.42333245]]])

The model was just an example; the priors are likely terrible for inference. You should check model convergence first, or start from a simpler model. All draws being identical from numpyro usually means you had 100% divergences.
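
For example, divergences can be counted with something like:

int(idata.sample_stats["diverging"].sum())

If that number is close to the total number of draws, the identical samples are explained.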

@ricardoV94

Thank you for your reply!
I tried a simpler dataset and model and was able to build the model successfully.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts, sample_blackjax_nuts
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import seaborn as sns
import jax
import jax.numpy as jnp
from flax import linen as nn

from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import scale

%config InlineBackend.figure_format = 'retina'
floatX = pytensor.config.floatX
RANDOM_SEED = 9927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
X, Y = make_moons(noise=0.2, random_state=0, n_samples=1000)
X = scale(X)
X = X.astype(floatX)
Y = Y.astype(floatX)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.5)

fig, ax = plt.subplots()
ax.scatter(X[Y == 0, 0], X[Y == 0, 1], color="C0", label="Class 0")
ax.scatter(X[Y == 1, 0], X[Y == 1, 1], color="C1", label="Class 1")
sns.despine()
ax.legend()
ax.set(xlabel="X", ylabel="Y", title="Toy binary classification data set");

class MLP(nn.Module):
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=5, use_bias=False)(x)
        x = nn.tanh(x)
        
        x = nn.Dense(features=5, use_bias=False)(x)
        x = nn.tanh(x)
        
        x = nn.Dense(features=1, use_bias=False)(x)
        x = nn.sigmoid(x)
        
        return x

mlp = MLP()

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

init_params = mlp.init(rng, jnp.zeros((X_train.shape)))

treedef = jax.tree_util.tree_structure(init_params)
from pytensor.graph.basic import Apply, Variable


def mlp_op_jax(flat_params, x):
    unflat_params = jax.tree_util.tree_unflatten(treedef, flat_params)
    out = mlp.apply(unflat_params, x)
    return out

jitted_mlp_op_jax = jax.jit(mlp_op_jax)

class MLPOp(Op):
    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        outputs = [pt.matrix(dtype="float64")]
        return Apply(self, inputs, outputs)
    
    def perform(self, node, inputs, outputs):
        *flat_params, x = inputs
        result = jitted_mlp_op_jax(flat_params, x)
        outputs[0][0] = np.asarray(result, dtype="float64")
        
    def grad(self, inputs, output_gradients):
        raise NotImplementedError("Not Implemented")
    
@jax_funcify.register(MLPOp)
def mlp_op_jax_funcify(op, **kwargs):
    def perform(*inputs):
        *flat_params, x = inputs
        return mlp_op_jax(flat_params, x)
    return perform

mlp_op = MLPOp()    
with pm.Model() as model:
    weight1 = pm.Normal("Dense_0_kernel", mu=0, sigma=1, shape=(2, 5))
    weight2 = pm.Normal("Dense_1_kernel", mu=0, sigma=1, shape=(5, 5))
    weight3 = pm.Normal("Dense_2_kernel", mu=0, sigma=1, shape=(5, 1))
    
    weight_prior = [weight1, weight2, weight3]
    
    act_out = mlp_op(*weight_prior, X_train)
    act_out = pt.specify_shape(act_out, (len(X_train), 1))
    
    out = pm.Bernoulli(
        "out",
        p=act_out,
        observed=Y_train.reshape(len(Y_train), 1),
    )
    
with model:
    idata = sample_numpyro_nuts(draws=1000, tune=500, chains=3)
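
As a quick convergence check (following the advice above), a trace plot of one of the weight matrices can be drawn, e.g.:

az.plot_trace(idata, var_names=["Dense_0_kernel"]);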
with model:
    weight_prior = [weight1, weight2, weight3]
    
    act_out = mlp_op(*weight_prior, X_test)
    act_out = pt.specify_shape(act_out, (len(X_test), 1))
    
    out = pm.Bernoulli(
        "out_test",
        p=act_out
    )
with model:
    ppc = pm.sample_posterior_predictive(idata, var_names=["out_test"])
    
pred = ppc.posterior_predictive["out_test"].mean(("chain", "draw")).squeeze() > 0.5

fig, ax = plt.subplots()
ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
sns.despine()
ax.set(title="Predicted labels in testing set", xlabel="X", ylabel="Y");

@ricardoV94

However, I do not know how to write an Op with a gradient, so I cannot use any sampler other than the JAX-based ones.
I am new to writing Ops with gradients, and the following code raises "ValueError: MLPOp returned the wrong number of gradient terms".

I would be thankful if you could help me with this question.

class MLP(nn.Module):
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=5, use_bias=False)(x)
        x = nn.tanh(x)
        
        x = nn.Dense(features=5, use_bias=False)(x)
        x = nn.tanh(x)
        
        x = nn.Dense(features=1, use_bias=False)(x)
        x = nn.sigmoid(x)
        
        return x

mlp = MLP()

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

init_params = mlp.init(rng, jnp.zeros((X_train.shape)))

treedef = jax.tree_util.tree_structure(init_params)
flat_params = jax.tree_util.tree_flatten(init_params)[0]
def mlp_op_jax(flat_params, x):
    unflat_params = jax.tree_util.tree_unflatten(treedef, flat_params)
    out = mlp.apply(unflat_params, x)
    return out

jitted_mlp_op_jax = jax.jit(mlp_op_jax)

mlp_op_jax(flat_params, X_train[:3, :])
def vjp_mlp_op_jax(flat_params, x, gz):
    _, vjp_fn = jax.vjp(mlp_op_jax, flat_params, x)
    vjp = vjp_fn(gz)[0]
    return vjp

jitted_vjp_mlp_op_jax = jax.jit(vjp_mlp_op_jax)

vjp_mlp_op_jax(flat_params, X_train[:3, :], jnp.ones((3, 1)))
class MLPOp(Op):
    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        outputs = [pt.matrix(dtype="float64")]
        return Apply(self, inputs, outputs)
    
    def perform(self, node, inputs, outputs):
        *flat_params, x = inputs
        result = jitted_mlp_op_jax(flat_params, x)
        outputs[0][0] = np.asarray(result, dtype="float64")
        
    def grad(self, inputs, output_gradients):
        *flat_params, x = inputs
        (gz,) = output_gradients
        
        result = vjp_mlp_op(flat_params, x, gz)
        return result
    
class VJPMLPOp(Op):
    def make_node(self, *inputs):
        flat_params, x, gz = inputs
        inputs = [pt.as_tensor_variable(inp) for inp in flat_params] + [pt.as_tensor_variable(x)] + [pt.as_tensor_variable(gz)]
        outputs = [param.type() for param in flat_params]
        return Apply(self, inputs, outputs)
    
    def perform(self, node, inputs, outputs):
        *flat_params, x, gz = inputs
        print(flat_params, x, gz)
        result = jitted_vjp_mlp_op_jax(flat_params, x, gz)
       
        for i in range(len(result)):
            outputs[i][0] = np.asarray(result[i], dtype="float64")       
    
@jax_funcify.register(MLPOp)
def mlp_op_jax_funcify(op, **kwargs):
    def perform(*inputs):
        *flat_params, x = inputs
        return mlp_op_jax(flat_params, x)
    return perform

mlp_op = MLPOp()   
vjp_mlp_op = VJPMLPOp() 
pytensor.gradient.verify_grad(mlp_op, (*flat_params, X_train), rng=np.random.default_rng())
ValueError: MLPOp returned the wrong number of gradient terms.

You can write neural networks directly in PyMC's PyTensor backend. That way you don't have to go through the trouble of linking to an external library.
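
For example, a rough sketch of the same two-moons MLP written with plain PyTensor operations (layer sizes mirror the Flax model; the names here are just illustrative), so gradients come for free and any sampler works:

with pm.Model() as pt_model:
    # Same layer sizes as the Flax MLP, expressed as PyTensor matrix products
    w1 = pm.Normal("Dense_0_kernel", 0, 1, shape=(2, 5))
    w2 = pm.Normal("Dense_1_kernel", 0, 1, shape=(5, 5))
    w3 = pm.Normal("Dense_2_kernel", 0, 1, shape=(5, 1))

    h1 = pm.math.tanh(pt.dot(X_train, w1))
    h2 = pm.math.tanh(pt.dot(h1, w2))
    p = pm.math.sigmoid(pt.dot(h2, w3)).flatten()

    pm.Bernoulli("out", p=p, observed=Y_train)

    # PyTensor knows the gradient of every op here, so the default NUTS works
    idata_pt = pm.sample(1000, tune=500, chains=3)

If you do want to keep the custom Op route, the ValueError is because Op.grad has to return exactly one gradient expression per input; the VJP Op only returns terms for the flattened parameters, so the x input still needs an entry. A minimal sketch, assuming x itself never needs a gradient:

    def grad(self, inputs, output_gradients):
        *flat_params, x = inputs
        (gz,) = output_gradients
        # One gradient expression per flattened parameter, from the VJP Op
        param_grads = vjp_mlp_op(flat_params, x, gz)
        # Op.grad must return one term per input, so pad the entry for x
        # (x is data here, so its gradient is never actually requested)
        x_grad = pytensor.gradient.grad_not_implemented(self, len(inputs) - 1, x)
        return list(param_grads) + [x_grad]

With this, verify_grad should be called on the parameters only, since the gradient for x is deliberately left unimplemented.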


Thank you for your reply! I will try that.