PyMC State Space 0.1.4 -- XlaRuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32

Hey State Space Team:

I recently updated to 0.1.4 to check out the regression component. My original model sampled fine on an earlier version, but now it fails.

I get the following error when I try to sample:

XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error dispatching computation: %sCpuCallback error: Traceback (most recent call last):
  File "C:\Users\(my user)\anaconda3\envs\fs_forecast_env3\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2867, in _wrapped_callback
**RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32**

I’m running in a Windows environment with Jupyter notebooks. To try to resolve this issue, I created a new environment, starting with Anaconda 3.11, then pymc experimental 0.1.4 (via pip), and then the m2w64 toolchain and other dependencies.

I ask JAX to run in x64 mode:

import jax
jax.config.update("jax_enable_x64", True)

I verified my data and exogenous data are float with a .dtype check. And this is how I’m calling the sampler:

ss_mod.build_statespace_graph(train_data['Standardized_Sales'].to_frame(), mode='JAX')

with struct_model:
    trace = pm.sample(
        nuts_sampler='numpyro', 
        target_accept=0.95, 
        progressbar=True, 
        draws=1000, 
        tune=2000, 
        chains=2,
        return_inferencedata=True, 
        # idata_kwargs={"log_likelihood": True}
    )

Also, when I try to run it without exogenous data, I’m getting the same exception. Is there an obvious reason why this could be happening? Thank you for your help!

-Roy

CC @jessegrabowski

Can you post a minimum example that reproduces the error?

Without knowing anything else, I’d make sure that the exogenous data you’re passing is all cast to float, even if it’s an indicator. Jax is pretty fussy about datatypes, so it’s less error prone to work all in float.
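
A minimal sketch of that kind of cast, assuming the exogenous data lives in a pandas DataFrame. The column names and values here are made up for illustration and are not taken from the original data:

```python
import pandas as pd

# Hypothetical exogenous data: an integer indicator plus a float regressor
exog = pd.DataFrame(
    {"is_holiday": [0, 1, 0, 0], "temperature": [61.0, 58.5, 70.2, 65.1]},
    index=pd.date_range("2024-01-01", periods=4, freq="D"),
)

# Cast every column to float64 so nothing integer-typed reaches JAX
exog = exog.astype("float64")

# Sanity check: every column should now report float64
all_float = all(str(dt) == "float64" for dt in exog.dtypes)
print(exog.dtypes)
print("all float64:", all_float)
```

Casting the whole frame at once, rather than column by column, makes it harder to miss an indicator column that pandas inferred as integer.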

Thank you for the help!

I can include the model, the full error message, requirements.txt, and some data soon. However, I turned off exogenous data and I’m still getting the error with my time series data - and I wasn’t getting this error with my time series data before.

I think it has something to do with my environment build. I might try rolling Jax back to an earlier version, maybe, in addition to doing a float cast on my time series data.

Thanks,
Roy

I saw you already checked the datatypes, sorry for not reading your OP carefully enough. If you can post a simple example (just using artificial data) that generates the error, that would help a lot. It’s definitely not necessary to show absolutely everything. Indeed, the simpler the better: that will help to isolate the bug (if there is one).
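
For instance, artificial data for this kind of model could be generated roughly like this. It is only a sketch: the series length, weekly pattern, and noise scale are arbitrary choices, and the column name simply mirrors the one used in the sampler call above:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical stand-in for the real sales series: daily data with a
# weekly pattern plus Gaussian noise
n_days = 200
index = pd.date_range("2023-01-01", periods=n_days, freq="D")
weekly = np.sin(2 * np.pi * np.arange(n_days) / 7)
fake_sales = pd.DataFrame(
    {"Standardized_Sales": weekly + rng.normal(scale=0.3, size=n_days)},
    index=index,
)

# Confirm the dtype before handing the frame to the statespace model
print(fake_sales.dtypes)
print(fake_sales.head())
```

If the error still appears with a frame like this in place of the real data, the data files can be ruled out as the cause.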

Hey Jesse:

I’m sure my code isn’t as minimal as it could be, but this should recreate the error. I also included a copy of the error message, the time series data (sales data), the exog data (temps), and my environment.txt (.yml).

Been a busy week; I wish I could have gotten this to you sooner. Thank you for taking a look!

-Roy

import os
os.environ["JAX_ENABLE_X64"] = "True"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
jax.config.update("jax_enable_x64", True)

import numpyro
numpyro.set_host_device_count(8)

import arviz as az
import matplotlib.pyplot as plt
%matplotlib qt
import pandas as pd
import pymc as pm
from pymc_experimental.statespace import structural as st
import pytensor.tensor as pt
import numpy as np
from sklearn.preprocessing import QuantileTransformer
import seaborn as sns
from scipy import stats
import statsmodels.api as sm

# Load your data
file_path = 'sales_data.xlsx'
sales_data = pd.read_excel(file_path, engine='openpyxl')
sales_data['Time'] = pd.to_datetime(sales_data['Time'], format='%m/%d/%Y %I:%M %p')
sales_data.set_index('Time', inplace=True)
sales_data = sales_data.asfreq('D')

# Normalize and then standardize the data.
# Step 1: Apply the Box-Cox Transformation
# Save the NaN positions
nan_positions = sales_data['Sales'].isna()

# Drop NaN values before applying Box-Cox
sales_data_nonan = sales_data.dropna()

# Apply Box-Cox transformation
sales_data_nonan['Transformed_Sales'], lambda_value = stats.boxcox(sales_data_nonan['Sales'])

# Add the NaNs back in their original positions
sales_data['Transformed_Sales'] = np.nan
sales_data.loc[sales_data_nonan.index, 'Transformed_Sales'] = sales_data_nonan['Transformed_Sales']

# Step 2: Standardize the Box-Cox Transformed Data
mean_sales = sales_data['Transformed_Sales'].mean()
std_sales = sales_data['Transformed_Sales'].std()

sales_data['Standardized_Sales'] = (sales_data['Transformed_Sales'] - mean_sales) / std_sales

# Step 3: Update frequency for sales data
sales_data.index.freq = pd.infer_freq(sales_data.index)

# Step 4: Visualize the Distribution with KDE
plt.figure(figsize=(12, 6))
sns.histplot(sales_data['Standardized_Sales'].dropna(), kde=True, color='skyblue', bins=30)
plt.title('Histogram with KDE (Standardized Sales Data)')
plt.xlabel('Standardized Sales')
plt.ylabel('Count')
plt.show()

# Optionally, you can also print out the first few rows to ensure the transformations worked as expected
print(sales_data[['Sales', 'Transformed_Sales', 'Standardized_Sales']].head())

split_point = int(len(sales_data) * 0.9)

# Create the train_data set with 90% of the standardized data
train_data = sales_data.iloc[:split_point]

# Optionally, you can visualize the size of the train_data
print(f"Length of full data: {len(sales_data)}")
print(f"Length of train data: {len(train_data)}")
print(train_data.head())

# Temperature Data - exogenous variable
# Read the CSV file
temperature_data = pd.read_csv('temps.csv', parse_dates=['DATE'], index_col='DATE')

# Select the TMAX column
temperature_data = temperature_data[['TMAX']]

# Ensure your temperature data covers at least the span of your train_data
start_date = train_data.index[0]
end_date = train_data.index[-1]

# Filter the temperature data to match the training period exactly
temp_train_data = temperature_data.loc[start_date:end_date].copy()

# Standardize the TMAX data
temp_train_data['TMAX'] = (temp_train_data['TMAX'] - temp_train_data['TMAX'].mean()) / temp_train_data['TMAX'].std()

# Filter temperature data to include only dates present in sales data
temp_train_data = temp_train_data.loc[train_data.index]

# Ensure no missing values remain in either dataset
train_data = train_data.dropna()
temp_train_data = temp_train_data.dropna()

# Filter temperature data to include only dates present in train_data
temp_train_data = temp_train_data.loc[temp_train_data.index.isin(train_data.index)]

# Update frequency for temperature data
temp_train_data.index.freq = pd.infer_freq(temp_train_data.index)

# Rename columns for compatibility
temp_train_data.index.name = 'time'
X = temp_train_data.rename(columns={'TMAX': 'temperature'})

# Infer and set the frequency for the index of X
X.index.freq = pd.infer_freq(X.index)

# Display the aligned temperature training data
display(X.head())
display(X.tail())

# Verify alignment
print(temp_train_data.index.isin(train_data.index).all())  # Should be True
print(X.shape, train_data.shape)  # Shapes should match in rows

# Define the structural components
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
weekly_seasonality = st.TimeSeasonality(season_length=7, state_names=['Sun', 'Mon', 'Tues', 'Wed', 'Thu', 'Fri', 'Sat'])
quarterly_seasonality = st.FrequencySeasonality(season_length=365, n=2)
exog = st.RegressionComponent(name='temperature', state_names=['temperature'], innovations=False)

# Combine components into a single model
combined_model = level_trend + weekly_seasonality + quarterly_seasonality 
# + exog
# Build the state-space model
ss_mod = combined_model.build()

# Attempt to cast to float to resolve the JAX issue (the result is not assigned back)
train_data['Standardized_Sales'].astype(float).to_frame()

# Define the PyMC model and assign priors
with pm.Model(coords=ss_mod.coords) as struct_model:
  
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=1, dims=['state'])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=['state', 'state_aux'])

    initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.3], dims=["trend_state"])
    # sigma_trend = pm.HalfNormal("sigma_trend", sigma=0.5, dims=["trend_shock"])  # Applied to the level only

    Seasonal_coefs = pm.Normal('Seasonal[s=7]_coefs', mu=0, sigma=.5, dims=['Seasonal[s=7]_state'])
    sigma_Seasonal = pm.Gamma('sigma_Seasonal[s=7]', alpha=.05, beta=10)

    Frequency_coefs = pm.HalfNormal('Frequency[s=365, n=2]', sigma=0.05, dims=['Frequency[s=365, n=2]_state'])
    sigma_Frequency = pm.HalfNormal('sigma_Frequency[s=365, n=2]', sigma=0.05)

    # data_exog = pm.Data('data_temperature', X.values, dims=['time', 'exog_state'])
    # beta_exog = pm.Normal('beta_temperature', mu=0, sigma=3, dims=['exog_state'])

    # Build the statespace graph with the exogenous data
    ss_mod.build_statespace_graph(train_data['Standardized_Sales'], mode='JAX')

    # Sampling from the posterior
with struct_model:
    trace = pm.sample(
        nuts_sampler='numpyro', 
        target_accept=0.95, 
        progressbar=True, 
        draws=1000, 
        tune=2000, 
        chains=2,
        return_inferencedata=True, 
        # idata_kwargs={"log_likelihood": True}
    )

    # Save the results
    # trace.to_netcdf('mcmc_results.nc')


error message.txt (7.1 KB)

temps.csv (18.3 KB)

sales_data.csv (27.0 KB)

environment.txt (3.3 KB)

I ran your script and cannot reproduce the error (that is, the model samples for me). Small changes I made:

  • I got rid of all the jax flag setting and replaced it with just jax.config.update('jax_platform_name', 'cpu')
  • I cast the temperature data to float before normalizing it
  • I used nutpie to sample instead of numpyro:
from pymc.model.transform.optimization import freeze_dims_and_data
with freeze_dims_and_data(struct_model):
    trace = pm.sample(
        nuts_sampler='nutpie', 
        nuts_sampler_kwargs={'backend':'jax', 'gradient_backend':'jax'},
        draws=500, 
        tune=1000, 
        chains=6,
    )

I had to freeze the model to include the exogenous component; it worked as expected otherwise. I think this is actually a bug, and I’ll look into it.

Quick update:

I did the things you recommended and it bypassed my issues with Jax easily. Thank you! Actually, nutpie runs faster, so that’s a nice benefit too.

By the way, the work you’re doing here is amazing. I can’t believe I’m doing (trying to at least) time series forecasting with state space.

Thanks,
Roy

Hi Jesse, I believe this is still a bug for Windows users. I get a very similar error, which also reports “Expected: int64, Actual: int32”, while my friends can run the exact same code on a Mac without any problems. It was fixed by setting nuts_sampler="nutpie" instead of "numpyro". It seems Windows users have problems with the JAX-based numpyro sampler.

What model are you running?

Be aware that Windows JAX support is entirely community supported; the JAX team doesn’t actively work on it. The officially recommended way to run JAX on Windows is via WSL2.

We’ve also seen integer type issues on Windows in the past, because the default integer type for NumPy arrays there was int32, rather than int64 as on Linux/macOS.
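
A quick way to check which default applies on a given machine, and to sidestep it by being explicit, might look like this. It is a sketch only; as far as I know the int32 default applies to NumPy releases before 2.0 on Windows, so the first print may differ between setups:

```python
import numpy as np

# The default integer dtype is platform dependent: this can print int32
# on Windows with older NumPy versions and int64 on Linux/macOS
values = np.array([1, 2, 3])
print("default integer dtype:", values.dtype)

# Requesting int64 explicitly removes the platform dependence
values64 = np.asarray(values, dtype=np.int64)
print("explicit dtype:", values64.dtype)
```

Running this on both the Windows and the Mac machine would show whether the two environments disagree on the default.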