Implementing hierarchical priors encoded in tree structures

The symbolic structure for doing recursive computation is pytensor.scan; you can read about its usage here. You just have to change how you’re thinking about the problem. First, you will need to store the whole population of nodes in a single flat vector. At each step, you compute the indices of the parents, mutate the parents, then store the results in the children’s slots. Here’s a binary tree to illustrate what I mean:

import numpy as np
n_levels = 7

# This will hold the whole population of nodes
binary_tree = np.zeros(2 ** n_levels - 1)

for level in range(1, n_levels):
    parent_start = 2 ** (level - 1) - 1
    parent_end = 2 ** level - 1
    parent_idx = np.arange(parent_start, parent_end)
    
    child_start = 2 ** level - 1
    child_stop = 2 ** (level + 1) - 1
    child_idx = np.arange(child_start, child_stop)
    
    # Here the mutation is just to add one to the left node and subtract one from the right node
    binary_tree[child_idx] = binary_tree[parent_idx].repeat(2) + np.tile(np.array([1, -1]), len(parent_idx))
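
This flat layout is just the standard implicit (heap-style) indexing for a complete binary tree: in level order, node i’s parent sits at index (i - 1) // 2, and its children at 2i + 1 and 2i + 2. A quick sanity check that the construction above is consistent with that formula:

# Every node should differ from its level-order parent by +1 (left child) or -1 (right child)
for child in range(1, 2 ** n_levels - 1):
    parent = (child - 1) // 2
    offset = 1 if child % 2 == 1 else -1  # odd indices are left children
    assert binary_tree[child] == binary_tree[parent] + offset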

Visualize the results:

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Parent of each node in level order; the root gets a dummy parent of -1
genealogy = np.r_[-1, np.arange(2 ** (n_levels - 1) - 1).repeat(2)]
edges = pd.DataFrame(np.c_[genealogy, np.arange(2 ** n_levels - 1)], columns=['parent', 'child'])
edges['weight'] = binary_tree

G = nx.from_pandas_edgelist(edges.iloc[1:], 'parent', 'child', edge_attr='weight', create_using=nx.DiGraph)
pos = nx.drawing.nx_pydot.graphviz_layout(G, prog='dot', root=0)

fig, ax = plt.subplots(figsize=(14, 4))
nx.draw_networkx(G, pos, node_color=edges.weight, labels=edges.weight, ax=ax)
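
As an aside, graphviz_layout goes through pydot, so it needs both pydot and a Graphviz installation. If you don’t have those, a hand-rolled level-order layout works fine for a complete binary tree; a minimal sketch, where the coordinates are just my own convention:

# Fallback layout: one row per level, nodes spread evenly across [0, 1]
pos = {}
for node in range(2 ** n_levels - 1):
    level = int(np.log2(node + 1))          # depth of the node in level-order indexing
    idx_in_level = node - (2 ** level - 1)  # position within its level
    pos[node] = ((idx_in_level + 0.5) / 2 ** level, -level)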

Looks right to me, so we can go ahead and implement this in pytensor using a scan and random variables:

import pytensor
import pytensor.tensor as pt
import pymc as pm

def grow_tree(level, tree, sigma):
    parent_start = 2 ** (level - 1) - 1
    parent_end = 2 ** level - 1
    parent_idx = pt.arange(parent_start, parent_end)
    
    child_start = 2 ** level - 1
    child_stop = 2 ** (level + 1) - 1
    child_idx = pt.arange(child_start, child_stop)
    
    # Shape of children is (2, n_parents) because each parent has 2 offspring
    children = pm.Normal.dist(mu=tree[parent_idx],
                              sigma=sigma,
                              size=(2, parent_idx.shape[0]))
    
    # Need to transpose children because ravel operates row-major. children is currently shape
    # (child, parent), which would ravel to [child 1 of parent 1, child 1 of parent 2,
    # child 1 of parent 3, ...], but we want [child 1 of parent 1, child 2 of parent 1, child 1 of parent 2, ...]
    tree = pt.set_subtensor(tree[child_idx], children.T.ravel())
    
    return tree

n_levels = 7
init_tree = pt.zeros(2 ** n_levels - 1)
init_tree = pt.set_subtensor(init_tree[0], pm.Normal.dist(0, 1))

res, _ = pytensor.scan(grow_tree,
                       sequences=[pt.arange(1, n_levels)],
                       outputs_info=[init_tree],
                       non_sequences=[1])  # sigma = 1 for the child noise

final_tree = res[-1]
lowest_level = final_tree[2**(n_levels - 1) - 1:2 ** n_levels - 1]
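
You can do lowest_level.eval() to check what you get out. Note, though, that repeated .eval() calls reuse the cached function and its RNG state, so you will keep seeing the same draw; for fresh, independent trees, pm.draw should take care of the RNG update boilerplate for you. A usage sketch, assuming the snippet above ran as written:

# Draw 5 independent trees at once; pm.draw collects the RNG updates for us
samples = pm.draw(lowest_level, draws=5)
print(samples.shape)  # (5, 64), i.e. (draws, 2 ** (n_levels - 1)) leaf values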

If you want to use this inside a PyMC model, you will run into some additional challenges, because you are missing some boilerplate related to scan and CustomDist; see here for guidance. More importantly, I don’t think PyMC can infer the logp of models with slicing – @ricardoV94? Luckily, your model is just a rather unusual random walk with drift, so it should be possible to work out the logp yourself and implement it.
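
For that last step, here is a sketch of what the hand-rolled logp could look like, using the same level-order indexing as above (the parent of node i sits at index (i - 1) // 2). This is just my reading of the density the generative code implies, not something I have tested inside a full model:

def tree_logp(value, sigma):
    # value is the flat vector of all 2 ** n_levels - 1 node values in level order.
    # Root ~ Normal(0, 1); every other node ~ Normal(parent_value, sigma).
    parent_idx = (pt.arange(1, value.shape[0]) - 1) // 2
    root_logp = pm.logp(pm.Normal.dist(0.0, 1.0), value[0])
    child_logp = pm.logp(pm.Normal.dist(mu=value[parent_idx], sigma=sigma), value[1:])
    return root_logp + child_logp.sum()

From there you could hand this function to pm.CustomDist via its logp argument, but the wiring details are covered in the guidance linked above.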