BNN with missing data imputation

I’m trying to classify a dataset where I have all the possible problems at once: my input data are uncertain, some values are missing not at random, some are censored (with potentially different censoring thresholds for different points), and I also have relatively few labelled examples. The dataset is not enormous, but it has several million rows with on the order of 100 features, each with an associated uncertainty. As a result, I figured this might be a good excuse to play around with a BNN to classify them.

It seemed foolhardy to try to tackle all of these problems at once, so I have started by treating just the input uncertainty and the imputation, with a subset of 13 features. However, the model doesn’t train: I get advi.hist = [inf inf inf inf ... ] and Average loss = nan.
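
For context, the fit call that produces this is roughly the following (simplified; the iteration count is just a placeholder):

# Assumes `neural_network` is the model returned by the function below
with neural_network:
    advi = pm.fit(n=30_000, method="advi")  # reports Average loss = nan
print(advi.hist[:10])  # -> [inf inf inf ...]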

I imagine I’m probably doing something particularly stupid in my model definition; see below. I have standardised all the inputs and uncertainties (sketched below), and am currently using only the labelled data. Any suggestions would be most welcome, and if any extra info is needed I’ll update asap. Thanks!
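
For concreteness, the standardisation is along these lines (X and X_err are placeholder names for the raw feature matrix and its 1-sigma uncertainties):

import numpy as np

# nan-aware statistics, since some entries are missing
mu = np.nanmean(X, axis=0)
sd = np.nanstd(X, axis=0)

X_std = (X - mu) / sd
X_err_std = X_err / sd  # scale only, so the uncertainties stay positive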

import numpy as np
import pymc as pm
import pytensor

floatX = pytensor.config.floatX
rng = np.random.default_rng(42)  # seed is arbitrary


def construct_nn_with_uncertainty(ann_input, ann_output, input_uncertainties):
    n_samples, n_features = ann_input.shape
    n_outputs = len(np.unique(ann_output))

    n_hidden = 5

    # Initialize random weights between each layer
    init_1 = rng.standard_normal(size=(n_features, n_hidden)).astype(floatX)
    init_2 = rng.standard_normal(size=(n_hidden, n_hidden)).astype(floatX)
    init_out = rng.standard_normal(size=(n_hidden, n_outputs)).astype(floatX)  # maps to n_outputs classes

    coords = {
        "hidden_layer_1": np.arange(n_hidden),
        "hidden_layer_2": np.arange(n_hidden),
        "outputs": np.arange(n_outputs),
        "train_cols": np.arange(n_features),
        "obs_id": np.arange(n_samples),
    }
    
    with pm.Model(coords=coords) as neural_network:
        # Measurement-error / imputation block: both arrays are passed as
        # observed, so any np.nan entries are automatically imputed by PyMC
        true_sigma = pm.Exponential('true_sigma', lam=1,
                                    dims=("obs_id", "train_cols"),
                                    observed=input_uncertainties)
        true_inputs = pm.Normal('true_inputs', mu=0, sigma=true_sigma,
                                dims=("obs_id", "train_cols"),
                                observed=ann_input)
        
        # Priors for the network weights (there are no bias terms in this model)
        # Weights from input to 1st layer
        weights_in_1 = pm.Normal('w_in_1', mu=0, sigma=1, initval=init_1, dims=("train_cols", "hidden_layer_1"))

        # Weights from 1st to 2nd layer
        weights_1_2 = pm.Normal(
            "w_1_2", 0, sigma=1, initval=init_2, dims=("hidden_layer_1", "hidden_layer_2")
        )
        
        # Weights from 2nd layer to the class outputs
        weights_2_out = pm.Normal('w_2_out', mu=0, sigma=1, initval=init_out, dims=("hidden_layer_2", "outputs"))

        # Build the neural network with tanh activations; the latent
        # (partially imputed) true_inputs feed the first layer
        act_1 = pm.math.tanh(pm.math.dot(true_inputs, weights_in_1))
        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))
        # Softmax over the class axis gives (n_samples, n_outputs) probabilities
        act_out = pm.math.softmax(pm.math.dot(act_2, weights_2_out), axis=-1)

        # total_size is meant for minibatch scaling and should be the number
        # of rows in the full dataset, not the number of classes
        pm.Categorical('out', p=act_out, observed=ann_output,
                       total_size=n_samples,
                       dims="obs_id",
                       )

    return neural_network
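
For completeness, the model is then built from the standardised arrays like this, with missing values left as np.nan so that PyMC’s automatic imputation picks them up:

# X_std, X_err_std and y as above; missing entries are np.nan in both
# the features and their uncertainties
neural_network = construct_nn_with_uncertainty(
    X_std.astype(floatX), y, X_err_std.astype(floatX)
)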