So I implemented this in PyTensor using scan, like this:
import os
os.environ['OMP_NUM_THREADS'] = '1'  # must be set before PyTensor/BLAS is loaded

import pymc as pm
import pytensor
import pytensor.tensor as at
def X_n_maker(d_factor, size):
    # Build an n x n tridiagonal matrix, n = len(size): d_factor on the
    # off-diagonals, 1 - 2*d_factor on the main diagonal, and
    # 1 - d_factor in the two corner entries
    n = size.shape[0]
    x_size = at.zeros(shape=(n, n))
    Xn_1 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, -1)
    Xn_2a = at.extra_ops.fill_diagonal_offset(x_size, 1 - 2.*d_factor, 0)
    Xn_2a1 = at.set_subtensor(Xn_2a[0, 0], 1 - d_factor)
    Xn_2 = at.set_subtensor(Xn_2a1[-1, -1], 1 - d_factor)
    Xn_3 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, 1)
    return Xn_1 + Xn_2 + Xn_3
def function_inner_loop(dt, x, ds, param):
    # Function 1: one implicit step, solving A_n @ x_1 = B_n @ x
    d_factor = param[0]
    A_n = X_n_maker(-d_factor, x)
    B_n = X_n_maker(d_factor, x)
    x_1 = at.linalg.solve(A_n, at.dot(B_n, x), assume_a='sym')
    # Overwrite the boundary entries at both ends of the grid
    x_2 = at.set_subtensor(x_1[0], x_1[1] / (param[1]*ds + 1))
    x_3 = at.set_subtensor(x_2[-1], x_2[-2] / (param[2]*ds - 1))
    # Function 2 (defined elsewhere)
    x_next = function2(x_3, dt, param)
    return x_next
def function_outer_loop(dt, x, ds, param):
    # dt: current element of `samples`; x: result of the previous outer
    # step (required in the signature by scan, unused here)
    result_inner, _ = pytensor.scan(fn=function_inner_loop,
                                    sequences=[dt],
                                    outputs_info=[x_0],
                                    non_sequences=[ds, param])
    # Prepend the initial state x_0 (from the enclosing scope) as the
    # first row of the result
    x_0_row = x_0.dimshuffle('x', 0)
    x_calc = at.concatenate([x_0_row, result_inner], axis=0)
    return x_calc
### Part of main code
# outputs_info is the initial value for the outer scan; its shape must
# match the x_calc returned by function_outer_loop
result, _ = pytensor.scan(fn=function_outer_loop,
                          sequences=[samples],
                          outputs_info=[at.zeros(shape=(time.shape[0], x_0.shape[0]))],
                          non_sequences=[ds, param])
x_result = at.concatenate([x_0, result], axis=0)
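For reference, here is a plain-NumPy sketch of the matrix that X_n_maker is intended to build (the helper name and the n and d_factor values are made up for illustration):

import numpy as np

def x_n_maker_numpy(d_factor, n):
    # d_factor on both off-diagonals, 1 - 2*d_factor on the main
    # diagonal, and 1 - d_factor in the two corner entries
    m = np.zeros((n, n))
    np.fill_diagonal(m[1:, :], d_factor)       # sub-diagonal
    np.fill_diagonal(m, 1.0 - 2.0 * d_factor)  # main diagonal
    np.fill_diagonal(m[:, 1:], d_factor)       # super-diagonal
    m[0, 0] = m[-1, -1] = 1.0 - d_factor
    return m

print(x_n_maker_numpy(0.1, 4))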
Here, x is an array of length 200, param holds PyMC random variables, and samples is an array of length 4, so that I can loop over all four samples.
This implementation runs fine, but it is quite slow: if I run the same code but skip everything in # Function 1, a chain of 100 samples takes about 1 sec., whereas if I include Function 1, it takes about 10 min. and all the CPUs are maxed out.
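For what it's worth, PyTensor's built-in profiler can show which ops dominate the run time. A minimal sketch, assuming the graph is compiled by hand (the input/output names and the *_val values are placeholders for whatever the real graph uses):

f = pytensor.function([samples, ds, param], x_result, profile=True)
f(samples_val, ds_val, param_val)  # run once with concrete inputs
f.profile.summary()                # per-op timing breakdown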
This reminded me of a post from @aseyboldt from a while ago.
I have therefore started setting os.environ['OMP_NUM_THREADS'] = '1' as suggested here.
This brings the run time down to about 2 min. for 100 samples, but the single CPU that is used is still maxed out (if I set it to '5', then all 5 CPUs are maxed out). I feel like there may be something else wrong with this code, or with the way PyTensor parallelizes at.linalg.solve and at.dot.
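As an aside, the BLAS/OpenMP thread cap can also be applied at run time with the threadpoolctl package rather than through the environment variable. A minimal sketch, assuming the hand-compiled f from the profiling snippet above:

from threadpoolctl import threadpool_limits

# Cap every detected BLAS/OpenMP thread pool at one thread for the
# duration of the call
with threadpool_limits(limits=1):
    f(samples_val, ds_val, param_val)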
Any help is highly appreciated!