The symbolic structure for doing recursive computation is pytensor.scan. You can read about its usage here. You just have to change how you're thinking about the problem. First, you will need to store the whole population of nodes in a single flat vector. At each step, you compute the parent locations, mutate the parents, then store the children. Here's a binary tree to illustrate what I mean:
import numpy as np
n_levels = 7
# This will hold the whole population of nodes
binary_tree = np.zeros(2 ** n_levels - 1)
for level in range(1, n_levels):
    parent_start = 2 ** (level - 1) - 1
    parent_end = 2 ** level - 1
    parent_idx = np.arange(parent_start, parent_end)

    child_start = 2 ** level - 1
    child_stop = 2 ** (level + 1) - 1
    child_idx = np.arange(child_start, child_stop)

    # Here the mutation is just to add one to the left node and subtract one from the right node
    binary_tree[child_idx] = binary_tree[parent_idx].repeat(2) + np.tile(np.array([1, -1]), len(parent_idx))
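To make the indexing trick on that last line concrete, here is the interleaving written out for one level with two made-up parent values (the numbers are purely illustrative):

parents = np.array([5.0, 8.0])                      # values of the two parents at some level
offsets = np.tile(np.array([1, -1]), len(parents))  # [ 1, -1,  1, -1]
children = parents.repeat(2) + offsets              # [ 6.,  4.,  9.,  7.]
# i.e. [left child of parent 1, right child of parent 1, left child of parent 2, right child of parent 2]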
Visualize the results:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
geneology = np.r_[-1, np.arange(2 ** (n_levels-1) - 1).repeat(2)]
edges = pd.DataFrame(np.c_[geneology, np.arange(2 ** n_levels - 1)], columns=['parent', 'child'])
edges['weight'] = binary_tree
G = nx.from_pandas_edgelist(edges.iloc[1:], 'parent', 'child', edge_attr='weight', create_using=nx.DiGraph)
pos = nx.drawing.nx_pydot.graphviz_layout(G, prog='dot', root=0)
fig, ax = plt.subplots(figsize=(14,4))
nx.draw_networkx(G, pos, node_color=edges.weight, labels=edges.weight, ax=ax)
Looks right to me, so we can go ahead and implement this in pytensor using a scan and random variables:
import pytensor
import pytensor.tensor as pt
import pymc as pm
def grow_tree(level, tree, sigma):
    parent_start = 2 ** (level - 1) - 1
    parent_end = 2 ** level - 1
    parent_idx = pt.arange(parent_start, parent_end)

    child_start = 2 ** level - 1
    child_stop = 2 ** (level + 1) - 1
    child_idx = pt.arange(child_start, child_stop)

    # Shape of children is (2, n_parents) because each parent has 2 offspring
    children = pm.Normal.dist(mu=tree[parent_idx],
                              sigma=sigma,
                              size=(2, parent_idx.shape[0]))

    # Need to transpose children because ravel operates row-major. Children is currently shape
    # (child, parent), which will ravel to [child 1 of parent 1, child 1 of parent 2, child 1 of parent 3,
    # ...], but we want [child 1 of parent 1, child 2 of parent 1, child 1 of parent 2, ...]
    tree = pt.set_subtensor(tree[child_idx], children.T.ravel())

    return tree
n_levels = 7
init_tree = pt.zeros(2 ** n_levels - 1)
init_tree = pt.set_subtensor(init_tree[0], pm.Normal.dist(0, 1))
res, _ = pytensor.scan(grow_tree,
                       sequences=[pt.arange(1, n_levels)],
                       outputs_info=[init_tree],
                       non_sequences=[1])

final_tree = res[-1]
lowest_level = final_tree[2 ** (n_levels - 1) - 1:2 ** n_levels - 1]
You can do lowest_level.eval() to check what you get out. If you want to use this inside a PyMC model, you will have some additional challenges, because you are missing some boilerplate related to scan and CustomDist; see here for guidance. More importantly, I don't think PyMC can infer the logp of models with slicing – @ricardoV94? Luckily, your model is just a really strange random walk with drift, so it should be possible to work out the logp yourself and implement it.
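For what it's worth, here is a minimal sketch of what that hand-written logp could look like, assuming the generative story above (root drawn from Normal(0, 1), each child drawn from Normal(parent, sigma)). The function name and signature are just illustrative, and you would still need to hook it up to a CustomDist yourself:

def tree_logp(value, sigma):
    # In the heap ordering used above, the parent of node i (for i >= 1) sits at index (i - 1) // 2
    parent_idx = (pt.arange(1, value.shape[0]) - 1) // 2
    # Root node, plus each child conditioned on its parent
    root_logp = pm.logp(pm.Normal.dist(0, 1), value[0])
    child_logp = pm.logp(pm.Normal.dist(mu=value[parent_idx], sigma=sigma), value[1:])
    return root_logp + child_logp.sum()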
to check what you get out. If you want to use this inside a pymc model, you will have some additional challenges, because you are missing some boilerplate related to scan and CustomDistributions, see here for guidance. More importantly, I don’t think PyMC can infer the logp of models with slicing – @ricardoV94 ? Luckily your model is just a really strange random walk with drift, so it should be possible to work out the logp yourself and implement it.