State Space Model
The base-class for time-series modeling with state-space models. Generates forecasts in the form of torchcast.state_space.Predictions, which can be used for training (log_prob()), evaluation (to_dataframe()) or visualization (plot()).
This class is abstract; see torchcast.kalman_filter.KalmanFilter for the go-to forecasting model.
- class torchcast.state_space.StateSpaceModel(processes: Sequence[torchcast.process.base.Process], measures: Optional[Sequence[str]] = None, measure_covariance: Optional[torchcast.covariance.base.Covariance] = None)
Bases: torch.nn.modules.module.Module
Base-class for any torch.nn.Module which generates predictions/forecasts using a state-space model.
- Parameters
processes – A list of Process modules.
measures – A list of strings specifying the names of the dimensions of the time-series being measured.
measure_covariance – A module created with Covariance.from_measures(measures).
- fit(*args, tol: float = 0.0001, patience: int = 1, max_iter: int = 200, optimizer: Optional[torch.optim.optimizer.Optimizer] = None, verbose: int = 2, callbacks: Sequence[Callable] = (), loss_callback: Optional[Callable] = None, callable_kwargs: Optional[Dict[str, Callable]] = None, set_initial_values: bool = True, **kwargs)
A high-level interface for invoking the standard model-training boilerplate. This is helpful in the common case in which the number of parameters is moderate and the data fit in memory. For other cases, you are encouraged to roll your own training loop.
- Parameters
args – A tensor containing the batch of time-series, see StateSpaceModel.forward().
tol – Stopping tolerance.
patience – If the loss changes by less than tol for this many epochs, training will be stopped.
max_iter – The maximum number of iterations after which training will stop regardless of loss.
optimizer – The optimizer to use. The default is to create an instance of torch.optim.LBFGS with args (max_iter=10, line_search_fn='strong_wolfe', lr=.5).
verbose – If truthy (the default), prints the loss and epoch as a progress bar; for the torch.optim.LBFGS optimizer (the default), this progress bar ticks within each epoch to track the calls to forward.
callbacks – A list of functions that will be called at the end of each epoch, each taking the current epoch’s loss value.
loss_callback – A callback that takes the loss and returns a modified loss, called before each call to backward(). This can be used, for example, to add regularization.
callable_kwargs – A dictionary where the keys are keyword-names and the values are no-argument functions that will be called each iteration to recompute the corresponding arguments.
set_initial_values – The default is to set initial_mean to a sensible value given y. This helps speed up training if the data are not centered. Set to False if you’re resuming training from a previous fit() call.
kwargs – Further keyword-arguments passed to StateSpaceModel.forward().
- Returns
This StateSpaceModel instance.
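The stopping rule described by tol, patience and max_iter can be sketched in plain Python. This is a minimal sketch under our own reading of the parameter descriptions (stop once the absolute change in loss stays below tol for patience consecutive epochs, or once max_iter epochs have run); fit_loop and toy_step are hypothetical names, not part of torchcast, and the library's exact bookkeeping may differ.

```python
from typing import Callable, List


def fit_loop(step: Callable[[], float], tol: float = 1e-4, patience: int = 1,
             max_iter: int = 200) -> List[float]:
    """Hypothetical loop illustrating tol/patience/max_iter stopping.

    `step` runs one epoch and returns that epoch's loss.
    """
    losses: List[float] = []
    num_small_changes = 0  # consecutive epochs whose loss change was < tol
    for _ in range(max_iter):
        loss = step()
        if losses and abs(losses[-1] - loss) < tol:
            num_small_changes += 1
        else:
            num_small_changes = 0
        losses.append(loss)
        if num_small_changes >= patience:
            break  # loss has plateaued for `patience` epochs in a row
    return losses


# Hypothetical toy "epoch": the loss halves its distance to 1.0 on each call.
state = {"loss": 2.0}


def toy_step() -> float:
    state["loss"] = 1.0 + (state["loss"] - 1.0) / 2
    return state["loss"]


history = fit_loop(toy_step, tol=1e-3, patience=2)
print(len(history), history[-1])
```

Here the loss change first drops below tol at epoch 10 and stays there at epoch 11, so with patience=2 the loop stops after 11 epochs, well before max_iter.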
- forward(*args, n_step: Union[int, float] = 1, start_offsets: Optional[Sequence] = None, out_timesteps: Optional[Union[int, float]] = None, initial_state: Union[torch.Tensor, Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] = (None, None), every_step: bool = True, include_updates_in_output: bool = False, **kwargs) torchcast.state_space.predictions.Predictions
Generate n-step-ahead predictions from the model.
- Parameters
args – A (group X time X measures) tensor. Optional if initial_state is specified.
n_step – The forecast horizon for the predictions output at each timepoint. Defaults to one-step-ahead predictions (i.e. n_step=1).
start_offsets – If your model includes seasonal processes, then these need to know the start-time for each group in input. If you passed dt_unit when constructing those processes, then you should pass an array of datetimes here. Otherwise you can pass an array of integers, or leave None if there are no seasonal processes.
out_timesteps – The number of timesteps to produce in the output. This is useful when passing a tensor of predictors that extends later in time than the input tensor: you can specify out_timesteps=X.shape[1] to get forecasts into this later time horizon.
initial_state – The initial prediction for the state of the system. A single tensor for the mean is always supported. For the KalmanFilter child class, a (mean, cov) tuple can also be passed; this is typically used when feeding in the state from a previous output. The exponential-smoothing child class does not support the tuple form.
every_step – By default, n_step-ahead predictions will be generated at every timestep. If every_step=False, then these predictions will only be generated every n_step timesteps. For example, with hourly data and n_step=24, every_step=True means each timepoint is a forecast generated with data from 24 hours in the past. But with every_step=False, the first timestep (index 0) would be 1-step-ahead, the second 2-step-ahead, …, the 24th (index 23) 24-step-ahead, the 25th (index 24) 1-step-ahead again, etc. The advantage of every_step=False is speed: training data for long-range forecasts can be generated without requiring the model to produce and discard intermediate predictions at every timestep.
include_updates_in_output – If False, only the n_step-ahead predictions are included in the output, which means this output cannot be used to generate the initial_state for subsequent forward-passes. Set to True to allow this; it is False by default to reduce memory.
kwargs – Further arguments passed to the processes. For example, the LinearModel expects an X argument for predictors.
- Returns
A Predictions object with Predictions.log_prob() and Predictions.to_dataframe() methods.
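The two schedules described under every_step can be made concrete by listing the forecast horizon at each output timestep. This is a sketch of the schedule as described above, not torchcast code; the function name horizons is hypothetical, and for every_step=True it ignores how the library handles the first few timesteps, before n_step observations are available.

```python
from typing import List


def horizons(num_timesteps: int, n_step: int, every_step: bool) -> List[int]:
    """Hypothetical: steps-ahead of the prediction at each (0-based) output timestep."""
    if every_step:
        # Every timestep gets an n_step-ahead forecast.
        return [n_step] * num_timesteps
    # Otherwise the horizon restarts every n_step timesteps: 1, 2, ..., n_step, 1, 2, ...
    return [t % n_step + 1 for t in range(num_timesteps)]


hourly = horizons(num_timesteps=48, n_step=24, every_step=False)
print(hourly[0], hourly[23], hourly[24])  # 1 24 1
```

With every_step=False, one forward pass yields training targets at every horizon from 1 to n_step, which is where the speed advantage comes from.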
- simulate(out_timesteps: int, initial_state: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]] = (None, None), start_offsets: Optional[Sequence] = None, num_sims: Optional[int] = None, progress: bool = False, **kwargs)
Generate simulated state-trajectories from your model.
- Parameters
out_timesteps – The number of timesteps to generate in the output.
initial_state – The initial state of the system: a tuple of mean, cov.
start_offsets – If your model includes seasonal processes, then these need to know the start-time for each group. If you passed dt_unit when constructing those processes, then you should pass an array of datetimes here. Otherwise you can pass an array of integers (or leave None if there are no seasonal processes).
num_sims – The number of state-trajectories to simulate.
progress – Should a progress-bar be displayed? Requires tqdm.
kwargs – Further arguments passed to the processes.
- Returns
A Simulations object with a Simulations.sample() method.
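To illustrate what a simulated state-trajectory is, here is a minimal sketch for a single random-walk (local-level) state with Gaussian process noise. The function name simulate_random_walk, the noise scale process_std, and the restriction to one state element are assumptions for illustration only; this is not how simulate() is implemented.

```python
import random
from typing import List, Optional


def simulate_random_walk(out_timesteps: int, initial_mean: float, initial_std: float,
                         process_std: float, num_sims: int,
                         seed: Optional[int] = None) -> List[List[float]]:
    """Hypothetical: simulate `num_sims` trajectories of a 1-D random-walk state."""
    rng = random.Random(seed)
    sims = []
    for _ in range(num_sims):
        # Draw the starting state from the initial (mean, std)...
        state = rng.gauss(initial_mean, initial_std)
        path = []
        for _ in range(out_timesteps):
            path.append(state)
            # ...then propagate it with the transition x[t+1] = x[t] + noise.
            state = state + rng.gauss(0.0, process_std)
        sims.append(path)
    return sims


paths = simulate_random_walk(out_timesteps=10, initial_mean=0.0, initial_std=1.0,
                             process_std=0.1, num_sims=3, seed=0)
print(len(paths), len(paths[0]))  # 3 10
```

Each inner list is one trajectory of length out_timesteps, so the spread across the num_sims paths reflects both the initial-state uncertainty and the accumulated process noise.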
- class torchcast.state_space.Predictions(state_means: Sequence[torch.Tensor], state_covs: Sequence[torch.Tensor], R: Sequence[torch.Tensor], H: Sequence[torch.Tensor], model: Union[StateSpaceModel, dict], update_means: Optional[Sequence[torch.Tensor]] = None, update_covs: Optional[Sequence[torch.Tensor]] = None)
Bases: torch.nn.modules.module.Module
The output of the StateSpaceModel forward pass, containing the underlying state means and covariances, as well as the predicted observations and covariances.
- get_state_at_times(times: Union[numpy.ndarray, numpy.datetime64], start_times: Optional[numpy.ndarray] = None, dt_unit: Optional[str] = None, type_: str = 'update') Tuple[torch.Tensor, torch.Tensor]
For each group, get the state (a tuple of (mean, cov)) at a given timepoint. This is often useful because predictions are right-aligned and padded, so the final prediction for each group can be padding that does not correspond to a timepoint of interest, e.g. when forecasting by calling StateSpaceModel.forward(initial_state=get_state_at_times(...)).
- Parameters
times – Either (a) indices corresponding to each group (e.g. times[0] is the timestep to take for the 0th group, times[1] the timestep to take for the 1st group, etc.) or (b) if start_times is passed, an array of datetimes. A single datetime is also supported.
start_times – If times is an array of datetimes, you must also pass start_times, i.e. the datetimes at which each group started.
dt_unit – If times is an array of datetimes, you must also pass dt_unit, i.e. a numpy.timedelta64 that indicates how much time passes at each timestep. (times - start_times) / dt_unit should be an array of integers.
type_ – What type of state? Since this method is typically used to get an initial_state for another call to StateSpaceModel.forward(), this should generally be ‘update’ (the default); the other option is ‘prediction’.
- Returns
A tuple of state-means and state-covs, appropriate for forecasting by passing as initial_state to StateSpaceModel.forward().
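The requirement that (times - start_times) / dt_unit be integers amounts to converting each group's datetime into an index within that group's own series. The sketch below shows that conversion using the standard-library datetime and timedelta as stand-ins for numpy.datetime64 and numpy.timedelta64; times_to_indices is a hypothetical helper, not a torchcast function.

```python
from datetime import datetime, timedelta
from typing import List, Sequence


def times_to_indices(times: Sequence[datetime], start_times: Sequence[datetime],
                     dt_unit: timedelta) -> List[int]:
    """Hypothetical helper: per-group timestep index = (time - start_time) / dt_unit."""
    indices = []
    for time, start in zip(times, start_times):
        steps = (time - start) / dt_unit  # timedelta / timedelta -> float
        if steps != int(steps):
            raise ValueError(f"{time} is not a whole number of dt_unit steps after {start}")
        indices.append(int(steps))
    return indices


starts = [datetime(2024, 1, 1), datetime(2024, 1, 3)]
# Same calendar date for both groups, but a different index within each group's series.
print(times_to_indices([datetime(2024, 1, 10)] * 2, starts, timedelta(days=1)))  # [9, 7]
```

This is why start_times is needed alongside datetime inputs: the same calendar date maps to different positions for groups that started at different times.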
- classmethod observe(state_means: torch.Tensor, state_covs: torch.Tensor, R: torch.Tensor, H: torch.Tensor) Tuple[torch.Tensor, torch.Tensor]
Convert latent states into observed predictions (and their uncertainty).
- Parameters
state_means – The latent state means.
state_covs – The latent state covs.
R – The measure-covariance matrices.
H – The measurement matrix.
- Returns
A tuple of means, covs.
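In standard Kalman-filter notation, converting a latent state into a predicted observation uses mean = H m and cov = H P H^T + R. The sketch below assumes observe follows these textbook equations, written for a single group and timestep with plain Python lists instead of batched tensors; matmul, transpose and observe_one are hypothetical helper names.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def observe_one(state_mean: List[float], state_cov: Matrix, R: Matrix,
                H: Matrix) -> Tuple[List[float], Matrix]:
    """Project a latent state (mean, cov) into measurement space."""
    # Predicted measurement mean: H @ m
    mean = [sum(h * m for h, m in zip(row, state_mean)) for row in H]
    # Predicted measurement covariance: H @ P @ H^T + R
    HPHt = matmul(matmul(H, state_cov), transpose(H))
    cov = [[HPHt[i][j] + R[i][j] for j in range(len(R))] for i in range(len(R))]
    return mean, cov


# Two latent states (e.g. a level and a seasonal term) summed into one measure.
H = [[1.0, 1.0]]
mean, cov = observe_one([10.0, 2.0], [[1.0, 0.5], [0.5, 2.0]], [[0.25]], H)
print(mean, cov)  # [12.0] [[4.25]]
```

The observed variance (4.25) combines the latent-state uncertainty projected through H, including the covariance between the two states, with the measurement noise R.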
- log_prob(obs: torch.Tensor) torch.Tensor
Compute the log-probability of data (e.g. data that was originally fed into the KalmanFilter).
- Parameters
obs – A Tensor that could be used in the KalmanFilter.forward pass.
- Returns
A tensor with one element for each group X timestep indicating the log-probability.
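For a single measure with Gaussian predictions, the per-element log-probability is the normal log-density of each observation under its predicted mean and variance. The sketch below computes one value per group X timestep with plain Python; the Gaussian assumption, the univariate restriction and the name gaussian_log_prob are ours, not taken from the library.

```python
import math
from typing import List


def gaussian_log_prob(obs: float, mean: float, var: float) -> float:
    """Log-density of a univariate normal N(mean, var) evaluated at obs."""
    return -0.5 * (math.log(2 * math.pi * var) + (obs - mean) ** 2 / var)


# One predicted (mean, variance) and one observation per group x timestep.
pred_means = [[12.0, 12.5], [3.0, 3.1]]
pred_vars = [[4.25, 4.5], [1.0, 1.0]]
observed = [[11.0, 13.0], [3.0, 2.9]]

log_probs: List[List[float]] = [
    [gaussian_log_prob(y, m, v) for y, m, v in zip(ys, ms, vs)]
    for ys, ms, vs in zip(observed, pred_means, pred_vars)
]
print(log_probs)
```

The result has the same group X timestep shape as the input, and observations closer to their predicted mean receive a higher log-probability.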
- to_dataframe(dataset: Union[torchcast.utils.data.TimeSeriesDataset, dict], type: str = 'predictions', group_colname: str = 'group', time_colname: str = 'time', multi: Optional[float] = 1.96) DataFrame
Convert the predictions (or their components) into a pandas DataFrame.
- Parameters
dataset – Either a TimeSeriesDataset, or a dictionary with ‘start_times’, ‘group_names’, & ‘dt_unit’.
type – Either ‘predictions’ or ‘components’.
group_colname – Column-name for ‘group’.
time_colname – Column-name for ‘time’.
multi – Multiplier on std-dev for lower/upper CIs. Default 1.96.
- Returns
A pandas DataFrame with group, ‘time’, ‘measure’, ‘mean’, ‘lower’, ‘upper’. For type='components', it additionally includes ‘process’ and ‘state_element’.
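Per the description of multi, the ‘lower’ and ‘upper’ columns are the predicted mean minus/plus multi standard deviations. The sketch below shows that calculation in plain Python, with rows as dicts rather than a pandas DataFrame; add_intervals is a hypothetical name. With the default of 1.96, the interval corresponds to roughly a 95% interval if the predictions are Gaussian.

```python
import math
from typing import Dict, List


def add_intervals(means: List[float], variances: List[float],
                  multi: float = 1.96) -> List[Dict[str, float]]:
    """Hypothetical: rows with 'mean', 'lower', 'upper' = mean -/+ multi * std-dev."""
    rows = []
    for mean, var in zip(means, variances):
        std = math.sqrt(var)
        rows.append({"mean": mean, "lower": mean - multi * std, "upper": mean + multi * std})
    return rows


print(add_intervals([12.0, 3.0], [4.0, 1.0]))
```

Passing a smaller multi (e.g. 1.0) narrows the bands to plus/minus one standard deviation.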
- static plot(df: DataFrame, group_colname: str = None, time_colname: str = None, max_num_groups: int = 1, split_dt: Optional[numpy.datetime64] = None, **kwargs) DataFrame
- Parameters
df – The output of Predictions.to_dataframe().
group_colname – The name of the group-column.
time_colname – The name of the time-column.
max_num_groups – Max. number of groups to plot; if the number of groups in the dataframe is greater than this, a random subset will be taken.
split_dt – If supplied, will draw a vertical line at this date (useful for showing pre/post validation).
kwargs – Further keyword arguments to pass to plotnine.theme (e.g. figure_size=(x,y)).
- Returns
A plot of the predicted and actual values.