Restrict sampler to only draw from the prior for a particular variable

Hi, I’m looking at a difference-in-differences problem with a linear regression model:

Y \sim N(\beta_0 + \beta_1 a + \beta_2 t + \beta_3 a t, \sigma)

where t is a binary variable: t=0 indicates the pre-treatment measurement time, and t=1 indicates post-treatment.

And a is a binary variable: a=0 indicates assignment to the control (no-treatment) group, and a=1 indicates the observation comes from the treatment group.

I want to look at what happens when the parallel trends assumption is violated. I’m modeling this as

E^{A=1}[Y_1 − Y_0|X, do(A = 0)] = E[Y_1 − Y_0|X, A = 0] + \xi

so \xi is our departure from parallel trends. If \xi \sim \delta(0), then parallel trends hold. In my model \xi \sim N(0, \sigma_{\xi}).

Then \beta_3 = \psi + \xi, where \psi is the ATT.
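
To spell out that step: under the regression model above, the difference-in-differences contrast isolates \beta_3,

(E[Y|a=1,t=1] − E[Y|a=1,t=0]) − (E[Y|a=0,t=1] − E[Y|a=0,t=0]) = (\beta_2 + \beta_3) − \beta_2 = \beta_3

so when the treated group’s counterfactual trend deviates from the control trend by \xi, the contrast absorbs that deviation on top of the ATT, giving \beta_3 = \psi + \xi.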

Now the model is underdetermined, because \xi is not identifiable (we don’t observe any samples assigned to the treatment group but not treated).

I would like to sample the posterior of \psi, with the restriction that samples of \xi always come from the prior. So the distribution of \xi is not updated (because the data can’t inform this parameter). How can I do this with pymc?

I found several related discussions, including Prevent prior from updating? and Two-stage Bayesian regression enforcing a fixed distribution (not Just Hierarchical regression). The first no longer works with v5 (as stated in the comments, it didn’t work with Aesara either), and using the solution from the second gives me:

ValueError                                Traceback (most recent call last)
<ipython-input-17-1fc8630abef7> in <module>
      3 
      4 for xi_variance in xi_variances:
----> 5     results.append(sample_from_model(df, xi_variance))

<ipython-input-16-f21cb08fa7c0> in sample_from_model(df, xi_variance)
     82     with model:
     83         fixed_step = FixedDistSample([_xi], {'loc': 0, 'scale': xi_variance})
---> 84         return pm.sample(step=fixed_step)

~/.local/lib/python3.8/site-packages/pymc/sampling/mcmc.py in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    651 
    652     initial_points = None
--> 653     step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    654 
    655     if nuts_sampler != "pymc":

~/.local/lib/python3.8/site-packages/pymc/sampling/mcmc.py in assign_step_methods(model, step, methods, step_kwargs)
    209     methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    210     selected_steps: Dict[Type[BlockedStep], List] = {}
--> 211     model_logp = model.logp()
    212 
    213     for var in model.value_vars:

~/.local/lib/python3.8/site-packages/pymc/model.py in logp(self, vars, jacobian, sum)
    716         rv_logps: List[TensorVariable] = []
    717         if rvs:
--> 718             rv_logps = transformed_conditional_logp(
    719                 rvs=rvs,
    720                 rvs_to_values=self.rvs_to_values,

~/.local/lib/python3.8/site-packages/pymc/logprob/basic.py in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    629     rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
    630     if rvs_in_logp_expressions:
--> 631         raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
    632 
    633     return logp_terms_list

ValueError: Random variables detected in the logp graph: {normal_rv{0, (0, 0), floatX, False}.out}.
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,
or when not all rvs have a corresponding value variable.

Here is the relevant code:

import numpy as np
import pymc as pm
from pymc.blocking import RaveledVars
from pymc.step_methods.arraystep import ArrayStepShared


def sample_from_model(df, xi_variance):

    def outcome(t, control_intercept, treat_intercept_delta, trend, att, group, treated, xi):
        return control_intercept + (treat_intercept_delta * group) + (t * trend) + ((att + xi) * treated)

    with pm.Model() as model:
        # data
        t = pm.MutableData("t", df["period"].values, dims="obs_idx")
        treated = pm.MutableData("treated", df["treated"].values, dims="obs_idx")
        group = pm.MutableData("locale", df["locale"].values, dims="obs_idx")
        # priors
        _control_intercept = pm.HalfNormal("control_intercept", 5)
        _treat_intercept_delta = pm.Normal("treat_intercept_delta", 0, 1)
        _trend = pm.Normal("trend", 0, 1)
        _att = pm.Normal("att", 0, 1)
        # srng is a PyTensor random stream defined elsewhere in the notebook;
        # this draws a raw random variable into the graph
        _xi = pm.Deterministic('xi', srng.normal(0, np.sqrt(xi_variance)))

        sigma = pm.HalfNormal("sigma", 1)

        # expectation
        mu = pm.Deterministic(
            "mu",
            outcome(t, _control_intercept, _treat_intercept_delta, _trend, _att, group, treated, _xi),
            dims="obs_idx",
        )
        # likelihood
        pm.Normal("obs", mu, sigma, observed=df["sales"].values, dims="obs_idx")

    class NormalProposal:
        def __init__(self, loc, scale):
            self.loc = loc
            self.scale = scale

        def __call__(self, rng=None, size=()):
            if rng is None:
                rng = np.random
            return rng.normal(self.loc, scale=self.scale, size=size)

    class FixedDistSample(ArrayStepShared):
        """Return samples from a fixed proposal distribution."""

        name = "fixed_dist_sample"

        generates_stats = False

        def __init__(self, vars, proposal_kwarg_dict, model=None):
            model = pm.modelcontext(model)
            initial_values = model.initial_point()

            vars = [model.rvs_to_values.get(var, var) for var in vars]
            vars = pm.inputvars(vars)
            initial_values_shape = [initial_values[v.name].shape for v in vars]
            self.ndim = int(sum(np.prod(ivs) for ivs in initial_values_shape))
            self.proposal_dist = NormalProposal(**proposal_kwarg_dict)

            shared = pm.make_shared_replacements(initial_values, vars, model)
            super().__init__(vars, shared)

        def astep(self, q0: RaveledVars) -> RaveledVars:
            point_map_info = q0.point_map_info
            q0 = q0.data
            q = self.proposal_dist(size=self.ndim)

            return RaveledVars(q, point_map_info)

    with model:
        fixed_step = FixedDistSample([_xi], {'loc': 0, 'scale': xi_variance})
        return pm.sample(step=fixed_step)

I’m new to pymc, so I’m not super familiar with its internal workings. How can I get it to do what I want?

Replace _xi by a pm.Normal. It still needs to be a proper model variable for the rest to work.
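
Something like this (a sketch, reusing your names):

_xi = pm.Normal("xi", 0, np.sqrt(xi_variance))

With xi registered as a free random variable, it gets a corresponding value variable, so model.logp() no longer finds a raw random draw in the graph, and your FixedDistSample step can then be assigned to it.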

Small note: you seem to be mixing variance with standard deviation in your custom step sampler.
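
Concretely, NormalProposal forwards scale straight to rng.normal, which expects a standard deviation, so presumably the step should be constructed with the square root of the variance:

fixed_step = FixedDistSample([_xi], {'loc': 0, 'scale': np.sqrt(xi_variance)})

Otherwise the injected xi draws have standard deviation xi_variance rather than sqrt(xi_variance).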

Oh no, I had two notebooks open last night and was looking at the wrong one when making some of those edits.

For anyone else who finds this, I did have to make a change to the step sampler. The astep function should be implemented as follows:

def astep(self, q0: RaveledVars) -> RaveledVars:
    point_map_info = q0.point_map_info
    q0 = q0.data
    q = self.proposal_dist(size=self.ndim)

    return RaveledVars(q, point_map_info), []

So it should return a tuple of the RaveledVars and an empty list. The list is expected by the parent class ArrayStep and should contain a dictionary of “stats” related to the step (for example related to tuning). Here we don’t have anything we want to keep track of, so it’s left empty.
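
As a quick sanity check (a hypothetical usage sketch; the variance value is arbitrary), the sampled xi should simply reproduce its prior, since the data never inform it:

import arviz as az

idata = sample_from_model(df, xi_variance=0.5)  # arbitrary example variance
# xi is only ever proposed from its fixed normal distribution, so its
# marginal "posterior" should match the prior
az.plot_posterior(idata, var_names=["xi"])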