NotImplementedError: No JAX conversion for the given `Op`: AllocDiag{offset=0, axis1=0, axis2=1}

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3.sampling_jax
import theano
import seaborn as sns

# Plot styling shared by every figure in this script.
sns.set_style(
    style='darkgrid',
    rc={'axes.facecolor': '.9', 'grid.color': '.8'},
)
sns.set_palette(palette='deep')
sns_c = sns.color_palette(palette='deep')

plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 100

# The package exposes its version as `__version__`; `pm.version` does not exist
# (the underscores were likely lost to markdown formatting).
print(f"Running on PyMC3 v{pm.__version__}")

# Training data; read by generate_data() for the `time` and `v_error` columns.
dri7 = pd.read_csv('./dri7_train_data.csv')

def generate_data(n, source=None):
    """Build the modelling frame from the training data.

    Parameters
    ----------
    n : int or None
        Number of leading rows of ``source`` to keep. ``None`` keeps all rows.
        The previous version accepted ``n`` but silently ignored it.
    source : pandas.DataFrame, optional
        Frame providing ``time`` and ``v_error`` columns. Defaults to the
        module-level ``dri7`` frame.

    Returns
    -------
    pandas.DataFrame
        Columns, in order: ``t`` (copied from ``time``), ``v_error``, and
        ``y`` (the target, a copy of ``v_error``).
    """
    if source is None:
        source = dri7
    rows = source.iloc[:n]
    data_df = pd.DataFrame({'t': rows['time'], 'v_error': rows['v_error']})
    # The regression target is the velocity error itself.
    return data_df.assign(y=data_df['v_error'])

# Model every row of the loaded training file.
n = dri7.shape[0]
data_df = generate_data(n=n)

# Plot the full target series against time.
fig, ax = plt.subplots()
sns.lineplot(x='t', y='y', data=data_df, color=sns_c[0], label='y', ax=ax)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set(title='Sample Data', xlabel='t', ylabel='');

# Column vectors of shape (n, 1); X is passed to pm.gp as a 2-D input matrix.
# reshape(n, 1) also fails loudly if data_df does not have exactly n rows.
x = data_df['t'].values.reshape(n, 1)
y = data_df['y'].values.reshape(n, 1)

# Chronological split: the first 70% of rows train, the remainder test.
prop_train = 0.7
n_train = round(prop_train * n)

x_train = x[:n_train]
y_train = y[:n_train]

x_test = x[n_train:]
y_test = y[n_train:]

# Plot the train/test split.
fig, ax = plt.subplots()
sns.lineplot(x=x_train.flatten(), y=y_train.flatten(), color=sns_c[0], label='y_train', ax=ax)
sns.lineplot(x=x_test.flatten(), y=y_test.flatten(), color=sns_c[1], label='y_test', ax=ax)
# Dashed vertical line at the last training time marks the split point.
# ('--' is matplotlib's dashed style; the en dash '–' is not a valid linestyle.)
ax.axvline(x=x_train.flatten()[-1], color=sns_c[7], linestyle='--', label='train-test-split')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set(title='y train-test split ', xlabel='t', ylabel='');

with pm.Model() as model:
    # Length-scale prior for the squared-exponential (ExpQuad) kernel on 1-D input.
    l_s = pm.Gamma(name='l_s', alpha=2.0, beta=1.0)
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=l_s))

    # Observation noise.
    sigma = pm.HalfNormal(name='sigma', sigma=3)

    # Likelihood: GP marginal likelihood of the training targets.
    y_pred = gp.marginal_likelihood('y_pred', X=x_train, y=y_train.flatten(), noise=sigma)

    # Sample. Under Theano v1.1.2 the JAX backend has no conversion for the
    # AllocDiag Op in this model's graph, so sample_numpyro_nuts raises
    # NotImplementedError while building the JAX function (before any sampling).
    # Fall back to the default NUTS sampler with the same draws/tune/chains, and
    # ask it for InferenceData so `trace` has the same type on both paths.
    try:
        trace = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, chains=2)
    except NotImplementedError as err:
        print(f"JAX sampling unavailable ({err}); falling back to pm.sample.")
        trace = pm.sample(draws=2000, tune=2000, chains=2, return_inferencedata=True)

NotImplementedError Traceback (most recent call last)
in
11 #sample
12 #trace = pm.sample(draws=2000, chains=2, tune=500, cores = 8)
----> 13 trace = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, chains=2)

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar, keep_untransformed)
135
136 fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
→ 137 fns = jax_funcify(fgraph)
138 logp_fn_jax = fns[0]
139

~/miniconda3/envs/pymc3/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
→ 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in jax_funcify_FunctionGraph(fgraph)
674
675 out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
→ 676 jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
677
678 return jax_funcs

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in (.0)
674
675 out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
→ 676 jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
677
678 return jax_funcs

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
153 # This input is the output of another node, so we need to
154 # generate a JAX-able function for its subgraph
→ 155 input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
156
157 if i.owner.nout > 1:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
119 return memo[out_node]
120
→ 121 jax_return_func = jax_funcify(out_node.op)
122
123 # We create a list of JAX-able functions that produce the values of each

~/miniconda3/envs/pymc3/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
→ 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in jax_funcify(op)
195 def jax_funcify(op):
196 """Create a JAX "perform" function for a Theano Variable and its Op."""
→ 197 raise NotImplementedError(f"No JAX conversion for the given Op: {op}")
198
199

NotImplementedError: No JAX conversion for the given Op: AllocDiag{offset=0, axis1=0, axis2=1}

Versions and main components
Theano version: v1.1.2
Python version: 3.7.10
pymc3 version: 3.11.2
Operating system: Ubuntu 20.04

Can anyone please help with this error?
Thank you very much.

I think support for this Op was added in the most recent Aesara version, so you can try updating to PyMC3 v4 (the GitHub main branch).

Hi @twiecki ,

when I try to update to the main branch, nothing happens after running the commands shown in the figure.

Besides that, I tried the hierarchical GLM JAX example from the PyMC3 documentation:
https://docs.pymc.io/pymc-examples/examples/samplers/GLM-hierarchical-jax.html

However, I get the following error when I run sampling with JAX:
TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type

Could you please check whether you are able to run the example on your side?

Thank you very much.

Best Regards,
GMCobraz

Hi @twiecki ,

Good day.
I managed to solve both issues.
I waited for a while and updated to PyMC3 v4, although it is still unstable for now.

Regarding the problem I mentioned with the example, it was actually an issue with my setup. With numpyro v0.6.0 and jax v0.2.10 installed, the example runs without problems.

I will still wait for the official release of PyMC3 v4.
Thank you very much.

1 Like