Most ArviZ plotting functions (possibly all) accept an ax argument that takes a single matplotlib axis or a collection of axes. You can create the axis yourself, pass it in, and then modify it however you like once the plot has been drawn. For example:
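import arviz as az
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoLocator

# idata is an existing InferenceData object with a prior group
fig, ax = plt.subplots()
az.plot_posterior(idata.prior, var_names=['mu'], ax=ax)
# Make every spine visible and restore automatic y-axis ticks
for spine in ax.spines.values():
    spine.set_visible(True)
ax.yaxis.set_major_locator(AutoLocator())
plt.show()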
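
The same approach should extend to a collection of axes. The following is a minimal sketch, not a definitive recipe: it assumes an ArviZ version in which az.plot_posterior accepts a plain dict of arrays shaped (chain, draw) and an array holding one axis per variable. The synthetic draws and the variable names mu and sigma are placeholders standing in for idata.prior.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoLocator

# Synthetic draws shaped (chain, draw); a stand-in for idata.prior (assumption).
rng = np.random.default_rng(0)
draws = {
    "mu": rng.normal(size=(2, 500)),
    "sigma": np.abs(rng.normal(size=(2, 500))),
}

# Create one axis per variable up front and hand the whole array to ArviZ.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
az.plot_posterior(draws, var_names=["mu", "sigma"], ax=axes)

# Customize each axis after ArviZ has drawn on it.
for ax in axes:
    for spine in ax.spines.values():
        spine.set_visible(True)
    ax.yaxis.set_major_locator(AutoLocator())
```

Call plt.show() or fig.savefig(...) afterwards to display or save the figure. Because you own the axes, any other matplotlib customization (titles, limits, shared axes) can be applied in the same loop.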
