Hi all,
I am trying to build a neural network with Flax, following the article referenced below.
The calculation works, but the parameter values do not change at all.
If anyone knows the cause, could you please tell me the solution?
I would appreciate any help with this question.
import matplotlib.pyplot as plt
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import jax
import jax.numpy as jnp
import pymc as pm
import pymc.sampling.jax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.core import freeze
def get_datasets():
    """Download MNIST and return the (train, test) splits as numpy dicts.

    Images are rescaled from uint8 [0, 255] to float32 [0, 1].
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()
    splits = {}
    # Same preprocessing for both splits, so do it in one loop.
    for split_name in ('train', 'test'):
        ds = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        ds['image'] = np.float32(ds['image']) / 255.
        splits[split_name] = ds
    return splits['train'], splits['test']
train, test = get_datasets()

# Keep only a small subset of the training data so Bayesian sampling
# over all network weights stays tractable.
n_subset = 1_000
for key in ("image", "label"):
    train[key] = train[key][:n_subset]

train["image"].shape

# Sanity check: show the first digit together with its label.
plt.imshow(train["image"][0])
plt.title(train["label"][0]);
class CNN(nn.Module):
    """A small CNN: two conv/relu/avg-pool stages, then two dense layers.

    Input is an NHWC image batch; output is a (batch, 10) logit matrix.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical conv -> relu -> average-pool stages (32 then 64 features).
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten the spatial dimensions, then classify.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)
cnn = CNN()

# Split the root key so `rng` can be reused later without reusing the
# initialization randomness.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
# Bug fix: initialize with the dedicated `init_rng` key. The original code
# created `init_rng` and then never used it, passing the post-split `rng`
# to `cnn.init` instead, which defeats the purpose of the split.
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))

# Count trainable parameters across all layers (bias may be absent).
n_params = 0
for layer in params["params"].values():
    n_params += layer["kernel"].size + layer.get("bias", np.array(())).size
n_params

# Smoke-test one forward pass on a single training image.
cnn.apply(params, train["image"][0:1])

# Record the parameter pytree structure so a flat list of tensors can be
# rebuilt into the nested dict that `cnn.apply` expects.
treedef = jax.tree_util.tree_structure(params)
def cnn_op_jax(flat_params, images):
    """Rebuild the flax parameter pytree from a flat list and run the CNN."""
    params_tree = jax.tree_util.tree_unflatten(treedef, flat_params)
    return cnn.apply(params_tree, images)


jitted_cnn_op_jax = jax.jit(cnn_op_jax)
class CNNOp(Op):
    """PyTensor Op wrapping the jitted CNN forward pass.

    Inputs are the flat parameter tensors followed by the image batch;
    the single output is the matrix of class logits.
    """

    def make_node(self, *inputs):
        # Lift raw inputs to symbolic tensor variables.
        symbolic_inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # The output is assumed to always be a float64 matrix.
        symbolic_outputs = [pt.matrix(dtype="float64")]
        return Apply(self, symbolic_inputs, symbolic_outputs)

    def perform(self, node, inputs, outputs):
        # Last input is the image batch; everything before it is a parameter.
        *flat_params, images = inputs
        logits = jitted_cnn_op_jax(flat_params, images)
        outputs[0][0] = np.asarray(logits, dtype="float64")

    def grad(self, inputs, output_gradients):
        # No C/Python gradient: only JAX-based samplers can differentiate this Op.
        raise NotImplementedError("PyTensor gradient of CNNOp not implemented")
@jax_funcify.register(CNNOp)
def cnn_op_jax_funcify(op, **kwargs):
    """Tell PyTensor's JAX backend to lower CNNOp to its pure-JAX function."""

    def perform(*inputs):
        *flat_params, images = inputs
        return cnn_op_jax(flat_params, images)

    return perform
cnn_op = CNNOp()

with pm.Model() as model:
    # Mutable data container so the images can later be swapped (e.g. test set).
    images = pm.Data("images", train["image"], mutable=True)
    #print(images.shape.eval()[0])
    # One standard-normal prior per parameter tensor of the flax network.
    # NOTE(review): sorted(layer.items()) is presumably meant to match the
    # key ordering jax uses when flattening the params pytree (treedef),
    # so the flat prior list lines up with tree_unflatten in cnn_op_jax —
    # confirm the orderings really agree.
    weights_prior = []
    for layer_name, layer in params["params"].items():
        for layer_weights_name, layer_weights in sorted(layer.items()):
            prior_name = f"{layer_name}_{layer_weights_name}"
            layer_weights_prior = pm.Normal(prior_name, 0, 1, shape=layer_weights.shape)
            weights_prior.append(layer_weights_prior)
    # Forward pass through the CNN via the PyTensor Op wrapper.
    logitp_classes = cnn_op(*weights_prior, images)
    # Pin the symbolic output shape: (number of images, 10 classes).
    logitp_classes = pt.specify_shape(logitp_classes, (images.shape.eval()[0], 10))
    # Observed digit labels, modeled as a categorical over the 10 logits.
    label = pm.Categorical("label", logit_p=logitp_classes, observed=train["label"])

pm.model_to_graphviz(model)
with model:
    # Short NUTS run for debugging; increase draws/tune once the model mixes.
    # NOTE(review): CNNOp.grad raises NotImplementedError, so presumably only
    # the JAX-based sampler (which goes through the jax_funcify registration)
    # can obtain gradients here — confirm before switching to pm.sample.
    idata = pm.sampling.jax.sample_numpyro_nuts(draws=100, tune=100, chains=1) # pm.sample(500, tune=500, chains=3, cores=2, nuts_sampler="numpyro")
print(idata.posterior.Dense_0_bias)
array([[[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245],
[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245],
[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245],
...,
[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245],
[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245],
[-0.7735232 , -0.12957669, -0.61163926, ..., 0.8304822 ,
-0.25725549, -0.42333245]]])