Simple generative model, but lots of divergences

Hi all,
I’m a little way into my PyMC3 journey but am having trouble with what seems to be a simple model. Thinking generatively, I can construct test data that replicates the problem (i.e., it’s me, not the data!). What’s strange is that the prior predictive checks look good, yet I get lots of divergences, even after reducing the step size and giving good testval hints. I’ve already centered the problem data. I guess I am in the funnel of doom?

Appreciate any help 🙂

Generate test data:

    import numpy as np
    import pandas as pd

    n_samples = 1000
    # The inference targets are W, x and y
    # The observed data are A1 to A4
    W = np.random.normal(4, 0.05, n_samples)
    # A random split
    x = np.random.normal(0.5, 0.01, n_samples)
    # Derived quantities - level 1
    B1 = W * (1-x)
    B2 = W * x
    # Two further random splits
    y = np.random.normal(0.5, 0.01, (n_samples, 2))
    # Now generate what will be the observations
    A1 = B1 * (1-y[:,0])
    A2 = B1 * y[:,0]
    A3 = B2 * (1-y[:,1])
    A4 = B2 * y[:,1]
    # and collect into a dataframe
    df_test = pd.DataFrame({
        'A1': A1,
        'A2': A2,
        'A3': A3,
        'A4': A4})
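Because every split is multiplicative, the simulated observations satisfy exact identities: A1 + A2 gives back B1, A3 + A4 gives back B2, all four sum to W, and x and y are ratios of the observations. The NumPy sketch below regenerates the data with the same process and checks this row by row. It assumes a fixed seed and uses `np.random.default_rng` instead of the legacy `np.random.normal` calls purely for reproducibility; the `*_rec` names are hypothetical helpers, not part of the original code.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary fixed seed, for reproducibility only
n_samples = 1000

# Same generative process as the test data above
W = rng.normal(4, 0.05, n_samples)
x = rng.normal(0.5, 0.01, n_samples)
B1 = W * (1 - x)
B2 = W * x
y = rng.normal(0.5, 0.01, (n_samples, 2))
A1 = B1 * (1 - y[:, 0])
A2 = B1 * y[:, 0]
A3 = B2 * (1 - y[:, 1])
A4 = B2 * y[:, 1]

# Recover each level from the observations alone
B1_rec = A1 + A2          # B1*(1-y0) + B1*y0 = B1
B2_rec = A3 + A4          # B2*(1-y1) + B2*y1 = B2
W_rec = B1_rec + B2_rec   # W*(1-x) + W*x = W
x_rec = B2_rec / W_rec    # x = B2 / W
y0_rec = A2 / B1_rec      # y[:, 0] = A2 / B1
y1_rec = A4 / B2_rec      # y[:, 1] = A4 / B2

print(np.allclose(B1_rec, B1), np.allclose(W_rec, W), np.allclose(x_rec, x),
      np.allclose(y0_rec, y[:, 0]), np.allclose(y1_rec, y[:, 1]))
# prints: True True True True True
```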

The model:

    import pymc3 as pm

    with pm.Model() as test_model:
        W = pm.Normal("W", 4.0, 1.0, testval=4.0)
        x = pm.Normal("x", 0.05, 0.005, testval=0.05)
        σ_B = pm.HalfNormal("σ_B", 0.01, shape=2, testval=0.05)
        B1 = pm.Normal("B1", W*x, σ_B[0], testval=2.0)
        B2 = pm.Normal("B2", W*(1-x), σ_B[1], testval=2.0)
        y = pm.Normal("y", 0.5, 0.01, shape=2, testval=0.5)
        σ_a = pm.HalfNormal("σ_a", 0.05, shape=4, testval=0.05)
        A1 = pm.Normal("A1", B1*y[0], σ_a[0], observed=df_test["A1"].values)
        A2 = pm.Normal("A2", B1*(1-y[0]), σ_a[1], observed=df_test["A2"].values)
        A3 = pm.Normal("A3", B2*y[1], σ_a[2], observed=df_test["A3"].values)
        A4 = pm.Normal("A4", B2*(1-y[1]), σ_a[3], observed=df_test["A4"].values)


    with test_model:
        # Only use one core to avoid Windows pipe error which will crash the kernel
        test_trace = pm.sample(1000, tune=1000, cores=1, target_accept=.95)
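For reference, the prior predictive check of this model can be roughly mimicked outside PyMC3 by drawing each parameter from its prior and then drawing A1 to A4 from the likelihood. The NumPy sketch below copies the priors from `test_model` as written; HalfNormal draws are taken as absolute values of zero-mean normals (same distribution), testval plays no role in prior draws, and the number of draws and the seed are arbitrary assumptions. It is a stand-in for `pm.sample_prior_predictive`, not a replacement for it.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed
n_draws = 5000                  # number of prior draws (arbitrary)

# Priors, copied from test_model
W = rng.normal(4.0, 1.0, n_draws)
x = rng.normal(0.05, 0.005, n_draws)
sigma_B = np.abs(rng.normal(0.0, 0.01, (n_draws, 2)))   # HalfNormal(0.01), shape=2
B1 = rng.normal(W * x, sigma_B[:, 0])
B2 = rng.normal(W * (1 - x), sigma_B[:, 1])
y = rng.normal(0.5, 0.01, (n_draws, 2))
sigma_a = np.abs(rng.normal(0.0, 0.05, (n_draws, 4)))   # HalfNormal(0.05), shape=4

# Likelihood: one simulated value of each A per prior draw
A_prior = np.column_stack([
    rng.normal(B1 * y[:, 0], sigma_a[:, 0]),        # A1
    rng.normal(B1 * (1 - y[:, 0]), sigma_a[:, 1]),  # A2
    rng.normal(B2 * y[:, 1], sigma_a[:, 2]),        # A3
    rng.normal(B2 * (1 - y[:, 1]), sigma_a[:, 3]),  # A4
])

# Summaries to compare against the observed columns of df_test
print("prior predictive means:", A_prior.mean(axis=0))
print("5th/95th percentiles:\n", np.percentile(A_prior, [5, 95], axis=0))
```

Comparing these summaries column by column with `df_test.mean()` is one quick way to confirm that the priors and the simulated data describe the same scale.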