Hello!
I am running into some issues with a model on PyMC v5.
We are building a model and are first testing it on synthetic data, so that we can assess bias and convergence. I have some questions:
1) Is it possible to get any other chain element in the same manner (a dictionary) as `model.initial_values()` returns? (See the first sketch below the questions.)
2) If the argument `initval` is passed within the model, then both the tuning steps AND the draws are initialized at those values, right?
3) Is there any way to have something like `discard_tuned_samples=False` with `pm.sampling_jax.sample_numpyro_nuts`? (See the second sketch below.)
4) If the sampling moves away from the given true values used for initialization, would that always be an indication of poor modelling, or am I missing something? Since the priors were defined to be informative around the true values, I am leaning towards something being amiss.
5) In this case, the model has a Dirichlet prior. When this happens, the initial values report a `dirichlet_simplex__` variable, and that is what is needed in order to run `model.compile_logp`. How can I go from the variable's values on the simplex to the transformed value that is needed (for example, for any other of the chains' values, as asked in 1)? (See the third sketch below.)
6) I am analyzing the log-probability at different stages (see the last sketch below).
On one hand, I have the inference data's `sample_stats.lp`, from which I can take the mean over all chains and draws.
Then, I have the initial value (which is in fact the true value with which the data was generated), and I get its logp using `model.compile_logp`.
Finally, I use `pm.find_MAP(model)` in order to get another dictionary and evaluate those values with `model.compile_logp`.
I have two issues in this case: the initial (true) values have a lower logp than the chains' mean, and the MAP values have the lowest logp of all. Also, the logp reported while running `find_MAP` differs from the one returned by `compile_logp()(MAP_point)`.
From this, I have two questions:
6a) To ensure I am reading them right: the logp is maximised, meaning it is desired to be as large (as positive) as possible?
6b) Are all the logp values reported in these places computed in the same way, up to some constant, or are they obtained in different manners?
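To make the questions concrete, here are some minimal sketches on a toy model (the model, variable names, and numbers below are placeholders, not my real setup).

For question 1, this is how I am currently pulling a single chain element out of the InferenceData; I am asking whether there is a built-in way to get it in the same dictionary format as the initial values:

```python
import pymc as pm

# Toy stand-in model, just so the sketch is self-contained.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])
    idata = pm.sample(draws=200, tune=200, chains=2, progressbar=False)

# What I am after: a dict like the initial values, but for an arbitrary
# chain element, e.g. chain 0, draw 57.
point = {
    name: idata.posterior[name].sel(chain=0, draw=57).to_numpy()
    for name in idata.posterior.data_vars
}
print(point)
```

Note that the posterior holds values on the original (constrained) space, which is part of what question 5 is about.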
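For questions 2 and 3, this is how I am passing `initval` and calling the JAX sampler. As far as I can tell from the docstring, `sample_numpyro_nuts` has no `discard_tuned_samples` argument, which is why I am asking:

```python
import pymc as pm
import pymc.sampling_jax

with pm.Model() as model:
    # Question 2: are both the tuning AND the drawing phases started at 0.5?
    mu = pm.Normal("mu", 0.0, 1.0, initval=0.5)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])

    # With the default sampler I can keep the tuning steps:
    idata = pm.sample(draws=200, tune=200, discard_tuned_samples=False)

    # Question 3: I see no equivalent option here:
    idata_jax = pm.sampling_jax.sample_numpyro_nuts(draws=200, tune=200)
```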
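For question 5, this is my attempt at converting a point on the simplex into the transformed `p_simplex__` value that `compile_logp` expects. I am not sure that calling the transform's `forward` method directly is the intended API, which is exactly what I am asking:

```python
import numpy as np
import pymc as pm
from pymc.distributions.transforms import simplex

with pm.Model() as model:
    p = pm.Dirichlet("p", a=np.array([2.0, 2.0, 2.0]))
    pm.Multinomial("obs", n=50, p=p, observed=np.array([20, 15, 15]))

# The initial point reports the transformed variable "p_simplex__",
# which is what compile_logp expects:
print(model.initial_point())  # {'p_simplex__': array([...])}

# My attempt: transform a point on the simplex (e.g. the true value or a
# posterior draw) to the unconstrained space.
p_value = np.array([0.4, 0.35, 0.25])  # sums to 1
p_transformed = simplex.forward(p_value).eval()

logp_fn = model.compile_logp()
print(logp_fn({"p_simplex__": p_transformed}))
```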
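Finally, for question 6, this is how I am computing the three logp values that I am comparing. On my real model, (b) comes out lower than (a), (c) is the lowest of all, and the logp printed by `find_MAP` while it runs does not match (c):

```python
import numpy as np
import pymc as pm

# Stand-in for my real setup: synthetic data generated at a known true value,
# with an informative prior around it.
true_mu = 1.0
data = np.random.default_rng(42).normal(true_mu, 1.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 1.0, 0.5, initval=true_mu)
    pm.Normal("obs", mu, 1.0, observed=data)
    idata = pm.sample(draws=500, tune=500, progressbar=False)

logp_fn = model.compile_logp()

# (a) mean logp over all chains and draws, from the trace
lp_mean = idata.sample_stats["lp"].mean().item()

# (b) logp at the initial point (the true values)
lp_init = logp_fn(model.initial_point())

# (c) logp at the MAP estimate
map_point = pm.find_MAP(model=model)
lp_map = logp_fn(map_point)

print(lp_mean, lp_init, lp_map)
```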
Thanks in advance!
Versions:
numpy==1.24.1
pymc==5.0.1 (I also use pymc.sampling_jax)
arviz==0.14.0
graphviz==0.20.1