My simple polynomial fit is too slow (<20 draws/s). What can I do?

Hey,
I have an N×K data matrix Y that I want to fit with the following model:
Y_{i,j} = \sum_{k=1}^{P} a_{i, k} \cdot j^{k} (plus Gaussian noise). That means the polynomial grows or decays with increasing column number; basically, I fit a separate polynomial to each row. Here is the data-generation code:
My general idea was to create a 3-d array, multiply in the coefficients along the first and third dimensions, and then contract.
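As a standalone toy sketch of that contraction (the shapes here are made up, just to illustrate the `einsum` pattern used below):

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.normal(size=(3, 5, 2))   # (row, column, power)
c = rng.normal(size=(3, 2))      # (row, power)

# Contract away the power axis, per row: out[i, j] = sum_k Z[i, j, k] * c[i, k]
out = np.einsum("ijk, ik -> ij", Z, c)

# The same thing written as an explicit loop over rows
loop = np.stack([Z[i] @ c[i] for i in range(3)])
assert np.allclose(out, loop)
```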

import numpy as np
from theano import tensor as T
import pymc3 as pm
import multiprocessing
np.random.seed(1234)

n_poly = 2
n_row = 17
n_col = 11

# Exponent array E with shape (n_row, n_col, n_poly): E[i, j, k] = k + 1
temp1 = np.linspace(1, n_poly, n_poly)
temp2 = np.repeat(temp1[:, np.newaxis], n_row, axis=1)
temp3 = np.repeat(temp2[:, :, np.newaxis], n_col, axis=2)
E = temp3.transpose(1, 2, 0)

# Base array a with the same shape: a[i, j, k] = j, the column index
temp1 = np.linspace(0, n_col - 1, n_col)
temp2 = np.repeat(temp1[:, np.newaxis], n_row, axis=1).T
a = np.repeat(temp2[:, :, np.newaxis], n_poly, axis=2)

Z = a**E  # Z[i, j, k] = j**(k + 1)

coefs = np.random.normal(0, 0.3, size=(n_row, n_poly))
noise = np.random.normal(0, 0.5, size=(n_row, n_col))

poly = np.einsum("ijk, ik -> ij", Z, coefs) + noise
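Side note: the nested `repeat`/`transpose` calls can be replaced by plain broadcasting. A sketch reusing the shapes from the snippet above (this is only an equivalent construction, not part of the original code):

```python
import numpy as np

n_poly, n_row, n_col = 2, 17, 11

# repeat/transpose construction of the exponent array, as above
t1 = np.linspace(1, n_poly, n_poly)
t2 = np.repeat(t1[:, np.newaxis], n_row, axis=1)
E_rep = np.repeat(t2[:, :, np.newaxis], n_col, axis=2).transpose(1, 2, 0)

# Broadcasting builds the same (n_row, n_col, n_poly) array in one line
E = np.broadcast_to(np.arange(1.0, n_poly + 1), (n_row, n_col, n_poly))
assert np.array_equal(E, E_rep)

# Same trick for the base: a[i, j, k] = j, the column index
a = np.broadcast_to(np.arange(n_col, dtype=float)[:, None], (n_row, n_col, n_poly))
Z = a ** E  # Z[i, j, k] = j ** (k + 1)
```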

My pymc3 model looks like this:

with pm.Model() as poly_model:
    poly_coef = pm.Normal("coefficients", 0, 0.3, shape=(n_row, n_poly))
    mean_mat = T.batched_tensordot(Z, poly_coef, axes=[[2], [1]])
    sig = pm.HalfNormal("sigma", 1, shape=(n_row, n_col))

    y_obs = pm.Normal("y_obs", mu=mean_mat, sd=sig,
                      shape=(n_row, n_col), observed=poly)

    n_cores = multiprocessing.cpu_count()
    n_chains = n_cores
    n_samp = 100 * (n_row * n_poly)
    trace = pm.sample(n_samp, tune=5000, chains=n_chains, cores=n_cores)

It’s horribly slow. I guess my contraction idea is not very smart, so I wanted to ask whether there is a better way to fit this model.
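One thing that might matter for speed: the base in `Z` depends only on the column index, so every row slice `Z[i]` is the same `(n_col, n_poly)` power matrix `V` with `V[j, k-1] = j**k`, and the batched contraction collapses to one ordinary matrix product. Here is a numpy sketch of that equivalence (whether swapping in `pm.math.dot(poly_coef, V.T)` in the model actually helps is my own untested assumption):

```python
import numpy as np

n_poly, n_row, n_col = 2, 17, 11

# V[j, k-1] = j**k -- identical for every row of Z
V = np.arange(n_col, dtype=float)[:, None] ** np.arange(1, n_poly + 1)
Z = np.broadcast_to(V, (n_row, n_col, n_poly))  # row-independent Z, same as the repeat construction

coefs = np.random.default_rng(0).normal(0, 0.3, size=(n_row, n_poly))
batched = np.einsum("ijk, ik -> ij", Z, coefs)  # what batched_tensordot computes
single = coefs @ V.T                            # candidate replacement via pm.math.dot
assert np.allclose(batched, single)
```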

Cheers,
Mister-Knister

PS: In case you are not convinced that Theano's batched_tensordot and np.einsum deliver the same result:

import numpy as np
import theano
from theano import tensor as T

n_row = 17
n_col = 31
n_poly = 10
a = np.random.normal(0, 1, (n_row, n_col, n_poly))
b = np.random.normal(10, 1, (n_row, n_poly))
m1 = np.linspace(1, n_poly, n_poly)
m2 = np.repeat(m1[:, np.newaxis], n_row, axis=1)
m3 = np.repeat(m2[:, :, np.newaxis], n_col, axis=2)
E = m3.transpose(1, 2, 0)

Z = a**E
R = np.einsum("ijk, ik -> ij", Z, b)

ta = T.tensor3('a')
tb = T.matrix('b')
tc = T.batched_tensordot(ta, tb, axes=[[2], [1]])

# note: compiling the graph is theano.function, not T.function
f = theano.function([ta, tb], tc)

print(np.allclose(f(Z, b), R))