Get predictions from a model using the MAP values instead of sampling the posterior

Yes, I know: I should predict using the posterior. Why would we build a Bayesian model if we want a deterministic prediction? :man_shrugging: :man_shrugging: Anyway, here is the question:

I have a fairly complex structural model, and I would like to get predictions from it using the MAP values. Is there an easy way to do that? One alternative is to replicate the model's functional structure in a separate function, but then every model change would have to be made and maintained in both places, which breaks DRY.

Additional context: why do I want predictions with the parameters fixed at their MAP values?
I am building a prediction model that will later be used inside an optimization model. The predictions must be convex for the optimization to work effectively. Fortunately, for most of the model, choosing appropriate priors naturally produces convexity. However, a few variables may not be convex, for example when they are affected by other random variables with high variance.
I know that the MAP values satisfy the convexity condition. Since the remaining sources of noise have a fixed, known variance once the parameters are fixed at the MAP, my predictions will be convex and compliant.
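The argument can be written out under one simplifying assumption that the question does not state explicitly: the remaining noise enters additively, with zero mean and fixed variance. With parameters fixed at the MAP estimate, the prediction and its mean are

$$
\hat{\theta} = \arg\max_{\theta} \, p(\theta \mid \mathcal{D}), \qquad
\tilde{y}(x) = f(x; \hat{\theta}) + \varepsilon, \quad \mathbb{E}[\varepsilon] = 0, \ \operatorname{Var}[\varepsilon] = \sigma^2
$$

$$
\mathbb{E}\left[\tilde{y}(x)\right] = f(x; \hat{\theta})
$$

so convexity of the predicted mean in $x$ only requires $f(\cdot\,; \hat{\theta})$ to be convex, whereas a prediction made from a posterior draw $\theta$ is only guaranteed convex if $f(\cdot\,; \theta)$ is convex for that particular draw.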

Reading the documentation of the `sample_posterior_predictive` method, I found this:

> `pymc.sample_posterior_predictive(trace, …)`
>
> Generate posterior predictive samples from a model given a trace.
>
> **trace**: backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
>
> Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()), or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)

So yes: you just need to use the MAP values when making the prediction, passing them as the `trace` argument instead of the posterior samples. Change this:

```python
with model:
    trace = pm.sample(draws=draws, ...)
    post_idata = pm.sample_posterior_predictive(trace, ...)
```

to this, passing a list that contains the MAP point:

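```python
with model:
    # Avoid naming this `map`, which shadows Python's built-in map().
    map_estimate = pm.find_MAP()  # dict: variable name -> MAP value
    post_idata = pm.sample_posterior_predictive([map_estimate], ...)
```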

That was all.