Forecasting electricity prices with Amazon Chronos

Chronos is a foundation model for zero-shot probabilistic forecasting of univariate time series [1]. The model converts a time series into a sequence of tokens through scaling and quantization. The scaling step divides the time series by its mean absolute value, and the quantization step maps the scaled values to a discrete set of tokens using uniform binning.
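
The scaling and quantization steps can be sketched in a few lines of NumPy. This is a minimal illustration, not Chronos' actual tokenizer: the function names, the number of bins (`num_bins`) and the bin range (`low`, `high`) are assumptions chosen for readability.

```python
import numpy as np

def tokenize(series, num_bins=100, low=-15.0, high=15.0):
    """Mean-scale a series and map it to integer tokens (illustrative settings)."""
    series = np.asarray(series, dtype=float)
    # scale by the mean absolute value of the series
    scale = np.mean(np.abs(series))
    scaled = series / scale
    # uniform bin edges and bin centers over [low, high]
    edges = np.linspace(low, high, num_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    # the token is the index of the bin that contains each scaled value
    tokens = np.clip(np.digitize(scaled, edges) - 1, 0, num_bins - 1)
    return tokens, scale, centers

def detokenize(tokens, scale, centers):
    """Invert the quantization (bin center) and then the scaling."""
    return centers[tokens] * scale

# toy values, not actual electricity prices
series = np.array([0.10, 0.12, 0.11, 0.15])
tokens, scale, centers = tokenize(series)
reconstructed = detokenize(tokens, scale, centers)
```

The reconstruction is only accurate up to half a bin width (in scaled units), which is the price of representing real values with a finite vocabulary.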

The tokenized time series is then passed to a large language model (LLM). The LLM takes a sequence of tokens as input and returns a predicted distribution over the next token, from which the next token is sampled. Subsequent tokens are generated autoregressively by appending the previously generated tokens to the input sequence and feeding it back to the model. The generated tokens are then converted back to time series values by inverting the quantization and scaling transformations.
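
The autoregressive loop can be illustrated with a toy stand-in for the LLM. In the sketch below, `next_token_probs` is a hypothetical function that returns a probability distribution over the vocabulary; in Chronos this role is played by the trained language model, and the vocabulary size used here is arbitrary.

```python
import numpy as np

def next_token_probs(tokens, vocab_size):
    """Hypothetical stand-in for the LLM: favors tokens close to the last one."""
    distances = np.abs(np.arange(vocab_size) - tokens[-1])
    weights = np.exp(-distances.astype(float))
    return weights / weights.sum()

def generate(context_tokens, horizon, vocab_size, rng):
    """Sample `horizon` future tokens one at a time, feeding each one back as input."""
    tokens = list(context_tokens)
    for _ in range(horizon):
        probs = next_token_probs(tokens, vocab_size)
        tokens.append(int(rng.choice(vocab_size, p=probs)))
    # return only the newly generated tokens
    return tokens[len(context_tokens):]

rng = np.random.default_rng(0)
future_tokens = generate(context_tokens=[50, 51, 53], horizon=3, vocab_size=100, rng=rng)
```

Repeating the call with different random draws yields different token paths, which is what makes the resulting forecast probabilistic.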

Chronos was trained using the T5 model architecture [2], although the approach is compatible with any LLM. The training was performed in a self-supervised manner by minimizing the cross-entropy loss between the actual and predicted distributions of the next token, as is standard when training LLMs. The training data included both real time series from publicly available datasets and synthetic time series generated using different methods.
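
As a small numerical illustration of the training objective, the cross-entropy for a single position is the negative log-probability that the model assigns to the token that actually came next. The logits below are made-up scores over a four-token vocabulary, and `cross_entropy` is a helper written for this sketch.

```python
import numpy as np

def cross_entropy(logits, target_token):
    """Cross-entropy between a one-hot target token and the softmax of the logits."""
    shifted = logits - np.max(logits)  # subtract the maximum for numerical stability
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return float(-log_probs[target_token])

# hypothetical model scores over a vocabulary of 4 tokens
logits = np.array([2.0, 0.5, -1.0, 0.0])

loss_likely = cross_entropy(logits, target_token=0)    # the model favors token 0
loss_unlikely = cross_entropy(logits, target_token=2)  # the model disfavors token 2
```

The loss is small when the model puts high probability on the observed token and large otherwise, so minimizing it pushes the predicted distribution toward the observed next tokens.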

In this post, we demonstrate how to use Chronos for one-step-ahead forecasting. We use the monthly time series of the US average electricity price from November 1978 to July 2024, which we download from the FRED database, and generate one-month-ahead forecasts from August 2014 to July 2024. We use expanding context windows: at each month, we provide Chronos with all the data up to that month and generate the forecast for the next month.
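
The expanding-window scheme can be checked on a synthetic series with the same date range. The values below are placeholders rather than electricity prices; only the dates and the slicing logic matter in this sketch.

```python
import numpy as np
import pandas as pd

# synthetic monthly series covering November 1978 to July 2024 (placeholder values)
index = pd.date_range("1978-11-01", "2024-07-01", freq="MS")
toy = pd.Series(np.arange(len(index), dtype=float), index=index)

# positions of the first and last forecast dates
first = toy.index.get_loc(pd.Timestamp("2014-08-01"))
last = toy.index.get_loc(pd.Timestamp("2024-07-01"))

windows = []
for t in range(first, last + 1):
    # all observations strictly before the forecast date
    context = toy.iloc[:t]
    windows.append({
        "last_context_date": context.index[-1],
        "forecast_date": toy.index[t],
        "context_length": len(context),
    })

windows = pd.DataFrame(windows)
```

Each context window ends one month before its forecast date and grows by one observation at every step, which gives 120 windows over the 10-year forecast period.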

We compare Chronos' forecasts to the rolling forecasts of a SARIMA model that is re-trained each month on the same data provided to Chronos as context. We find that Chronos and the SARIMA model have comparable performance.

Code

We start by installing and importing all the dependencies.

pip install git+https://github.com/amazon-science/chronos-forecasting.git fredapi pmdarima

import warnings
import transformers
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from chronos import ChronosPipeline
from pmdarima.arima import auto_arima
from statsmodels.tsa.statespace.sarimax import SARIMAX
from tqdm import tqdm
from fredapi import Fred
from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error, root_mean_squared_error

Next, we download the time series from the FRED database using the fredapi Python package.

Tip

If you don’t have a FRED API key, you can request one for free at this link.

# set up the FRED API
fred = Fred(api_key_file="api_key.txt")

# define the time series ID
series = "APU000072610"

# download the time series
data = fred.get_series(series).rename(series).ffill()

The time series includes 549 monthly observations from November 1978 to July 2024. It has one missing value, in September 1985, which we forward fill with the previous value.
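
The effect of the forward fill can be seen on a three-month toy series. The prices below are invented for illustration and are not the actual FRED values.

```python
import numpy as np
import pandas as pd

# made-up prices around a missing month (not the actual FRED data)
toy = pd.Series(
    [0.078, np.nan, 0.080],
    index=pd.to_datetime(["1985-08-01", "1985-09-01", "1985-10-01"]),
)

# locate the missing observations
missing_dates = toy.index[toy.isna()]

# replace each missing value with the last available observation
filled = toy.ffill()
```

Forward filling uses only past information, so it does not leak future values into the context windows used for forecasting.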

US average electricity price from November 1978 to July 2024.

We generate the forecasts over a 10-year period (120 months) from August 2014 to July 2024.

# date of first forecast
start_date = "2014-08-01"

# date of last forecast
end_date = "2024-07-01"

SARIMA

We use the pmdarima library to find the best order of the SARIMA model using the data up to July 2014.

# find the best order of the SARIMA model
best_sarima_model = auto_arima(
    y=data[data.index < start_date],
    start_p=0,
    start_q=0,
    start_P=0,
    start_Q=0,
    m=12,
    seasonal=True,
)

SARIMA estimation results.

For each month in the considered time window, we train the SARIMA model with the identified best order on all the data up to that month, and generate the forecast for the next month.

# create a list for storing the forecasts
sarima_forecasts = []

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the training data
    context = data.iloc[:t]

    # train the model
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sarima_model = SARIMAX(
            endog=context,
            order=best_sarima_model.order,
            seasonal_order=best_sarima_model.seasonal_order,
            trend="c" if best_sarima_model.with_intercept else None,
        ).fit(disp=0)

    # generate the one-step-ahead forecast
    sarima_forecast = sarima_model.get_forecast(steps=1)

    # save the forecast
    sarima_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": sarima_forecast.predicted_mean.item(),
        "std": sarima_forecast.var_pred_mean.item() ** 0.5,
    })

# cast the forecasts to data frame
sarima_forecasts = pd.DataFrame(sarima_forecasts)
sarima_forecasts.shape
(120, 4)
sarima_forecasts.head()
First 3 rows of SARIMA forecasts
sarima_forecasts.tail()
Last 3 rows of SARIMA forecasts

SARIMA forecasts from August 2014 to July 2024.

We find that the SARIMA model achieves an RMSE of 0.001364 and an MAE of 0.001067.

# calculate the error metrics
sarima_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": root_mean_squared_error(y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
        {"Metric": "MAE", "Value": mean_absolute_error(y_true=sarima_forecasts["actual"], y_pred=sarima_forecasts["mean"])},
    ]
).set_index("Metric")

SARIMA forecast errors from August 2014 to July 2024.
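
For reference, the two error metrics reduce to simple NumPy expressions. The sketch below uses made-up actual and predicted values, and the helper names `rmse` and `mae` are introduced only for this example.

```python
import numpy as np

def rmse(actual, predicted):
    """Root mean squared error: square root of the average squared error."""
    errors = np.asarray(actual) - np.asarray(predicted)
    return float(np.sqrt(np.mean(errors ** 2)))

def mae(actual, predicted):
    """Mean absolute error: average of the absolute errors."""
    errors = np.asarray(actual) - np.asarray(predicted)
    return float(np.mean(np.abs(errors)))

# placeholder values, not the forecasts from this post
actual = np.array([0.150, 0.152, 0.149])
predicted = np.array([0.151, 0.150, 0.149])

toy_rmse = rmse(actual, predicted)
toy_mae = mae(actual, predicted)
```

Both metrics are expressed in the same units as the time series. The RMSE is never smaller than the MAE, and the gap widens when a few errors are much larger than the rest.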

Chronos

We use the t5-large version of Chronos, which includes approximately 710 million parameters.

# instantiate the model
chronos_model = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

For each month in the considered time window, we use all the data up to that month as the context window, and generate 100 samples from the predicted distribution for the next month. We use the mean of the samples as the point forecast, as we did for the SARIMA model.
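
The samples can be summarized in more ways than the mean. The sketch below uses random draws as a stand-in for the model output, with the assumed shape (batch, number of samples, prediction length), and also computes empirical quantiles that could serve as a prediction interval. The location and spread of the draws are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)

# stand-in for the model output: 1 series, 100 samples, 1 step ahead (arbitrary values)
samples = rng.normal(loc=0.17, scale=0.002, size=(1, 100, 1))

# collapse to a flat vector of 100 sampled values for the next month
flat = samples.flatten()

# point forecast and dispersion
point_forecast = flat.mean()
sample_std = flat.std(ddof=1)

# empirical 10%, 50% and 90% quantiles of the predicted distribution
lower, median, upper = np.quantile(flat, [0.1, 0.5, 0.9])
```

Using the median instead of the mean is a common alternative for the point forecast when the sampled distribution is skewed.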

Note

Because Chronos is a generative model, different random seeds and different numbers of samples produce slightly different forecasts.

# create a list for storing the forecasts
chronos_forecasts = []

# loop across the dates
for t in tqdm(range(data.index.get_loc(start_date), data.index.get_loc(end_date) + 1)):

    # extract the context window
    context = data.iloc[:t]

    # generate the one-step-ahead forecast
    transformers.set_seed(42)
    chronos_forecast = chronos_model.predict(
        context=torch.from_numpy(context.values),
        prediction_length=1,
        num_samples=100
    ).detach().cpu().numpy().flatten()

    # save the forecast
    chronos_forecasts.append({
        "date": data.index[t],
        "actual": data.values[t],
        "mean": np.mean(chronos_forecast),
        "std": np.std(chronos_forecast, ddof=1),
    })

# cast the forecasts to data frame
chronos_forecasts = pd.DataFrame(chronos_forecasts)
chronos_forecasts.shape
(120, 4)
chronos_forecasts.head()
First 3 rows of Chronos forecasts
chronos_forecasts.tail()
Last 3 rows of Chronos forecasts

Chronos forecasts from August 2014 to July 2024.

We find that Chronos achieves an RMSE of 0.001443 and an MAE of 0.001105.

# calculate the error metrics
chronos_metrics = pd.DataFrame(
    columns=["Metric", "Value"],
    data=[
        {"Metric": "RMSE", "Value": root_mean_squared_error(y_true=chronos_forecasts["actual"], y_pred=chronos_forecasts["mean"])},
        {"Metric": "MAE", "Value": mean_absolute_error(y_true=chronos_forecasts["actual"], y_pred=chronos_forecasts["mean"])},
    ]
).set_index("Metric")

Chronos forecast errors from August 2014 to July 2024.
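
To put the two sets of results side by side, the error metrics reported above can be collected in a single table. The sketch below hardcodes the reported values rather than recomputing them, and the relative-difference column is an addition made for this comparison.

```python
import pandas as pd

# error metrics reported above for the two models
comparison = pd.DataFrame({
    "SARIMA": {"RMSE": 0.001364, "MAE": 0.001067},
    "Chronos": {"RMSE": 0.001443, "MAE": 0.001105},
})

# percentage by which Chronos' error exceeds SARIMA's error
comparison["Relative difference (%)"] = 100 * (comparison["Chronos"] / comparison["SARIMA"] - 1)
```

With these values, Chronos' RMSE and MAE are about 5.8% and 3.6% higher than those of the SARIMA model, even though Chronos was used zero-shot while the SARIMA model was fitted to this time series.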

Tip

A Python notebook with the full code is available in our GitHub repository.

References

[1] Ansari, A.F., Stella, L., Turkmen, C., Zhang, X., Mercado, P., Shen, H., Shchur, O., Rangapuram, S.S., Arango, S.P., Kapoor, S. and Zschiegner, J. (2024). Chronos: Learning the language of time series. arXiv preprint, doi: 10.48550/arXiv.2403.07815.

[2] Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W. and Liu, P.J. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140), pp. 1-67.