Sure 
The code in its present form works with adjacency lists instead of an adjacency matrix, and it assumes the values are already topologically sorted.
So “cleaning up” the code would mean first running a toposort + topological reduction on the adjacency matrix and then converting the result from adjacency-matrix form to the adjacency-list form used at present. It would probably also need a class method to generate initial values, since these already need to conform to the order, just as for the regular ordered transform, and this requires them to be in the same order as the toposort.
from pymc.logprob.transforms import Transform
def get_minval(dtype):
    """Return the most negative finite value representable by ``dtype``.

    The transform below uses this value as a "functional infinity": a real
    ``-inf`` would turn into ``nan`` when multiplied by 0.

    Parameters
    ----------
    dtype
        Anything ``np.issubdtype`` / ``np.iinfo`` / ``np.finfo`` accept
        (a NumPy dtype, a scalar type, or a dtype name such as ``"float64"``).

    Returns
    -------
    ``np.iinfo(dtype).min`` for integer dtypes, ``np.finfo(dtype).min``
    otherwise.
    """
    if np.issubdtype(dtype, np.integer):
        return np.iinfo(dtype).min
    return np.finfo(dtype).min
class DagOrdered(Transform):
    """Transform that forces values to respect the ordering given by a DAG.

    ``backward`` builds each non-start entry as ``max(listed entries) +
    exp(value)``, so it is strictly greater than every entry listed for it in
    ``dag``. Start entries (rows of ``dag`` that contain only -1) are passed
    through unchanged.

    A sentinel slot holding ``get_minval(dtype)`` is appended after the last
    node; the ``-1`` placeholders in ``dag`` index that slot, so they never
    win the ``max``.
    """

    # NOTE(review): ``pt`` is presumably ``pytensor.tensor`` -- it is not
    # imported in the visible snippet, and ``np`` is only imported further
    # down (in the usage example). Confirm both are in scope in the real file.

    name = "dag_ordered"

    def __init__(self, dag):
        """Create a DagOrdered object.

        Assumes values are in topological sort order.

        Parameters
        ----------
        dag : array-like
            Adjacency lists for the dag as an ``n x k`` array with -1 standing
            in for None. Leading batch dimensions are allowed; the last two
            axes are always ``(n, k)``.

        Raises
        ------
        ValueError
            If ``dag`` has fewer than two dimensions.
        """
        self.dag = np.asarray(dag)
        if self.dag.ndim < 2:
            raise ValueError("dag must have at least 2 dimensions (n x k)")
        # A node is a "start" node when all of its k slots are the -1 placeholder.
        self.is_start = np.all(self.dag == -1, axis=-1)

    def _batch_index(self, n_trailing):
        """Build one index array per batch axis of ``dag``.

        Each array is broadcast to ``batch_shape + dag.shape[-n_trailing:]``
        so it lines up element-wise with a slice of ``dag`` used as the index
        along the node axis. Returns an empty tuple when there are no batch
        axes.
        """
        batch_shape = self.dag.shape[:-2]
        full_shape = batch_shape + self.dag.shape[-n_trailing:]
        grids = np.ix_(*[np.arange(s) for s in batch_shape])
        expand = (Ellipsis,) + (None,) * n_trailing
        # np.array(...) materialises the read-only broadcast view.
        return tuple(np.array(np.broadcast_to(g[expand], full_shape)) for g in grids)

    def backward(self, value, *inputs):
        """Map unconstrained ``value`` to values that respect the DAG order.

        Start entries are copied from ``value``; every other entry becomes
        ``max(entries listed in dag) + exp(value)``.
        """
        n_nodes = self.dag.shape[-2]
        batch_shape = self.dag.shape[:-2]
        # Same dtype as ``value`` (as in ``forward``) so the sentinel below
        # does not overflow to -inf in a narrower default dtype.
        x = pt.zeros(batch_shape + (n_nodes + 1,), dtype=value.dtype)
        # Functional infinity, as real inf creates nan when multiplied with 0
        x = pt.set_subtensor(x[..., -1], get_minval(value.dtype))
        # Indices to allow broadcasting the max over the last dimension
        idx = self._batch_index(1)
        # Has to be done stepwise as next steps depend on previous values
        for i in range(n_nodes):
            ist = self.is_start[..., i]
            mval = pt.max(x[(Ellipsis,) + idx + (self.dag[..., i, :],)], axis=-1)
            x = pt.set_subtensor(
                x[..., i],
                ist * value[..., i] + (1 - ist) * (mval + pt.exp(value[..., i])),
            )
        # Drop the sentinel slot.
        return x[..., :-1]

    def forward(self, value, *inputs):
        """Map DAG-ordered ``value`` to the unconstrained space.

        Start entries are copied; every other entry becomes
        ``log(value - max(entries listed in dag))``.
        """
        n_nodes = self.dag.shape[-2]
        # ``value`` extended by the sentinel slot that the -1 placeholders hit.
        vx = pt.zeros(self.dag.shape[:-2] + (n_nodes + 1,), dtype=value.dtype)
        vx = pt.set_subtensor(vx[..., :-1], value)
        vx = pt.set_subtensor(vx[..., -1], get_minval(value.dtype))
        # Indices to allow broadcasting the max over the last dimension
        idx = self._batch_index(2)
        parent_max = pt.max(vx[(Ellipsis,) + idx + (self.dag,)], axis=-1)
        return self.is_start * value + (1 - self.is_start) * pt.log(value - parent_max)

    def log_jac_det(self, value, *inputs):
        """Return the sum of ``value`` over the non-start entries (last axis).

        ``backward`` passes exactly those entries through ``exp``.
        """
        return pt.sum(value * (1 - self.is_start), axis=-1)
And as a current usage example:
# Basic test of DagOrdered transform
import numpy as np
import pymc as pm

# Two DAGs over four nodes each, as adjacency lists padded with -1:
#   first:  node 0 is listed for nodes 1 and 2, which are both listed for node 3
#   second: a simple chain 0, 1, 2, 3
dag = np.array(
    [
        [[-1, -1, -1], [0, -1, -1], [0, -1, -1], [1, 2, -1]],
        [[-1, -1, -1], [0, -1, -1], [1, -1, -1], [2, -1, -1]],
    ]
)
dord = DagOrdered(dag)

with pm.Model() as model:
    # The initial values must already satisfy the ordering of both DAGs.
    x = pm.Normal(
        "x",
        size=(2, 4),
        transform=dord,
        initval=np.array([[0, 0.1, 0.2, 0.3], [0, 0.1, 0.2, 0.3]]),
    )
    # Sampling has to happen inside the model context.
    idata = pm.sample()

# Posterior mean per entry (displayed when run in a notebook cell).
idata.posterior.x.mean(["chain", "draw"])