Implementing Crank-Nicolson Method to Solve 1D Heat Diffusion

Hi all,

For my current project, I am trying to numerically simulate the 1D heat diffusion equation. I have previously done this using the Crank-Nicolson method in Python (minimal code example):

import numpy as np

# Crank-Nicolson weight
d_factor = Diffusion * dt / (2 * ds**2)

# Tridiagonal matrices of the implicit (A_n) and explicit (B_n) half-steps
A_n = np.diagflat([-d_factor for b in range(steps-1)], -1) +\
      np.diagflat([1.+d_factor]+[1.+2.*d_factor for b in range(steps-2)]+[1.+d_factor]) +\
      np.diagflat([-d_factor for b in range(steps-1)], 1)

B_n = np.diagflat([d_factor for b in range(steps-1)], -1) +\
      np.diagflat([1.-d_factor]+[1.-2.*d_factor for b in range(steps-2)]+[1.-d_factor]) +\
      np.diagflat([d_factor for b in range(steps-1)], 1)

N = N0

for k in range(1, len(t)-1):
    N_new = np.linalg.solve(A_n, B_n.dot(N))

    # Dirichlet boundary conditions: keep the endpoints fixed
    N_new[0] = N0[0]
    N_new[-1] = Nd

    N = N_new

where dt and ds are the step sizes in time and space, Diffusion is the diffusion coefficient, N0 is the initial state, and Nd is the fixed value at the right boundary.
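For reference, the snippet above assumes definitions along these lines (the values here are placeholders, not my actual ones):

import numpy as np

# Placeholder values, just to make the snippet above self-contained
steps = 200                    # number of spatial grid points
ds = 1.0 / (steps - 1)         # spatial step size
dt = 1e-4                      # time step size
Diffusion = 1.0                # diffusion coefficient
t = np.arange(0.0, 0.1, dt)    # time grid

N0 = np.zeros(steps)           # initial profile
N0[0] = 1.0                    # fixed value at the left boundary
Nd = 0.0                       # fixed value at the right boundary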

Is there a smart way that I can ‘translate’ this code into PyTensor so that I can use it in a pm.Model? The issue is that Diffusion is a random variable, so A_n and B_n need to be recalculated at each sampling step.
Any help would be highly appreciated.

Hi,

You want to use pytensor.scan, which is a differentiable loop. Docs here; some discussion and an example using a scan together with a CustomDist (which you will want to do) here.
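If it helps, the general pattern looks like this (a toy running product, just to show the sequences/outputs_info mechanics):

import pytensor
import pytensor.tensor as at

x0 = at.dscalar("x0")
rates = at.dvector("rates")

# One step of the recurrence: next_state = f(sequence_element, previous_state)
def step(rate, prev):
    return prev * rate

states, _ = pytensor.scan(fn=step,
                          sequences=[rates],
                          outputs_info=[x0])

f = pytensor.function([x0, rates], states)
print(f(2.0, [1.0, 2.0, 3.0]))  # [ 2.  4. 12.]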

I don’t think this example needs a CustomDist, it’s just a deterministic scan, no?

It depends on the error structure of the model, but you’re right. You can just start with a scan.

Hi,

Could you provide a short code example of what you mean by that?

So I implemented this in PyTensor using scan like this:

import os
os.environ['OMP_NUM_THREADS'] = '1'

import pymc as pm
import pytensor
import pytensor.tensor as at

def X_n_maker(d_factor, size):
    # Build the tridiagonal Crank-Nicolson matrix for a state vector `size`
    n = size.shape[0]
    x_size = at.zeros(shape=(n, n))

    Xn_1 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, -1)

    Xn_2a = at.extra_ops.fill_diagonal_offset(x_size, 1. - 2.*d_factor, 0)
    Xn_2a1 = at.set_subtensor(Xn_2a[0, 0], 1. - d_factor)
    Xn_2 = at.set_subtensor(Xn_2a1[-1, -1], 1. - d_factor)

    Xn_3 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, 1)

    return Xn_1 + Xn_2 + Xn_3


def function_inner_loop(dt, x, ds, param):
    # Function 1: one Crank-Nicolson diffusion step
    d_factor = param[0]
    A_n = X_n_maker(-d_factor, x)
    B_n = X_n_maker(d_factor, x)

    x_1 = at.linalg.solve(A_n, at.dot(B_n, x), assume_a='sym')

    # Boundary conditions at both ends
    x_2 = at.set_subtensor(x_1[0], x_1[1] / (param[1]*ds + 1))
    x_3 = at.set_subtensor(x_2[-1], x_2[-2] / (param[2]*ds - 1))

    # Function 2: reaction step
    x_next = function2(x_3, dt, param)

    return x_next


def function_outer_loop(dt, x, ds, param):
    # x (the previous sample's trajectory) is not used here;
    # the initial state x_0 comes from the enclosing scope
    result_inner, _ = pytensor.scan(fn=function_inner_loop,
                                    sequences=[dt],
                                    outputs_info=[x_0],
                                    non_sequences=[ds, param])

    # Prepend the initial state to the computed trajectory
    x_0_row = x_0.dimshuffle('x', 0)
    x_calc = at.concatenate([x_0_row, result_inner], axis=0)

    return x_calc


### Part of the main code
result, _ = pytensor.scan(fn=function_outer_loop,
                          sequences=[samples],
                          outputs_info=[at.zeros(shape=(time.shape[0], x_0.shape[0]))],
                          non_sequences=[ds, param])

x_result = at.concatenate([x_0, result], axis=0)

Here, x_0 is an array of length 200, param contains PyMC random variables, and samples is an array of length 4, so that I can loop over all four samples.
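To clarify the parameter setup, param is assembled from the random variables roughly like this (placeholder priors and names, not my actual ones):

with pm.Model() as model:
    # Placeholder priors, just to illustrate how param is assembled;
    # my actual priors are different
    d_factor = pm.HalfNormal("d_factor", sigma=1.0)
    k_left = pm.HalfNormal("k_left", sigma=1.0)
    k_right = pm.HalfNormal("k_right", sigma=1.0)
    rate = pm.Normal("rate", mu=0.0, sigma=1.0)

    param = at.stack([d_factor, k_left, k_right, rate])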

This implementation runs fine, but it is quite slow: if I run the same code but skip everything in # Function 1, a chain of 100 samples takes about 1 s, whereas with Function 1 included it takes about 10 min and all CPUs are maxed out.
This reminded me of a post from @aseyboldt from a while ago.
I have therefore started using os.environ['OMP_NUM_THREADS'] = '1' as suggested here.

This brings the run time down to 2 min for 100 samples, but the single CPU that is used is still maxed out (if I set it to ‘5’, then all 5 are maxed out). I feel like there may be something else wrong with this code, or with the way that PyTensor parallelizes at.linalg.solve and at.dot.
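One thing I am not sure about: as far as I understand, OMP_NUM_THREADS is not respected by every BLAS build, so I am also considering setting the OpenBLAS/MKL-specific variables (they have to be set before NumPy/PyTensor are imported):

import os

# Must be set before numpy/pytensor are imported; which variable takes
# effect depends on the BLAS library NumPy is linked against
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'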

Any help is highly appreciated!

What is function2? And how are you sampling?

Function 2 solves the differential equation $\frac{dx}{dt} = \mathrm{param}[3]^2 \cdot x$.
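A minimal sketch of such a step, assuming the exact exponential update over one time step, would be:

def function2(x, dt, param):
    # Exact solution of dx/dt = param[3]**2 * x over a step of length dt
    return x * at.exp(param[3]**2 * dt)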

At the moment I am using pm.Metropolis to sample.