Source code for torchid.ss.dt.simulator
import torch
import torch.nn as nn
from typing import List
class StateSpaceSimulator(nn.Module):
    r"""Discrete-time state-space simulator.

    Iterates the update :math:`x_{k+1} = x_k + f_{xu}(x_k, u_k)` over the
    input sequence starting from ``x_0``. It then optionally maps the
    simulated state sequence through an output model ``g_x``.

    Args:
        f_xu (nn.Module): The neural state-space model. Called as
            ``f_xu(x_k, u_k)``; its output is added to ``x_k`` to obtain
            the next state.
        g_x (nn.Module, optional): Output model applied to the whole
            simulated state sequence. If ``None`` (default), the output is
            the state sequence itself.
        batch_first (bool): If True, the time axis of ``u`` and of the
            returned sequences is dimension 1 instead of 0. Default: False.

    Inputs: x_0, u, return_x
        * **x_0**: tensor of shape :math:`(N, n_{x})` containing the
          initial state for each element in the batch (required).
        * **u**: tensor of shape :math:`(L, N, n_{u})` when
          ``batch_first=False`` or :math:`(N, L, n_{u})` when
          ``batch_first=True`` containing the input sequence.
        * **return_x** (bool): if True, also return the state sequence.

    Outputs: y, or (y, x) when ``return_x=True``
        * **x**: tensor of shape :math:`(L, N, n_{x})` when
          ``batch_first=False`` or :math:`(N, L, n_{x})` when
          ``batch_first=True``. The k-th entry along the time axis is the
          state *before* input ``u_k`` is applied, so the first entry is
          ``x_0`` and the state reached after the last input is not
          included. If ``L == 0``, a zero-length sequence is returned.
        * **y**: ``g_x(x)`` if ``g_x`` was given, otherwise ``x`` itself.

    Examples::

        >>> ss_model = NeuralStateSpaceModel(n_x=3, n_u=2)
        >>> nn_solution = StateSpaceSimulator(ss_model)
        >>> x0 = torch.randn(64, 3)
        >>> u = torch.randn(100, 64, 2)
        >>> x = nn_solution(x0, u)
        >>> print(x.size())
        torch.Size([100, 64, 3])
    """

    def __init__(self, f_xu, g_x=None, batch_first=False):
        super().__init__()
        self.f_xu = f_xu
        self.g_x = g_x
        self.batch_first = batch_first

    def simulate_state(self, x_0, u):
        """Roll the state forward through ``u`` and return the state sequence.

        Args:
            x_0: initial state tensor.
            u: input sequence; time runs along dim 1 if ``batch_first``
                else dim 0.

        Returns:
            The states ``x_0, x_1, ..., x_{L-1}`` stacked along the time
            axis (the final state ``x_L`` is computed but not returned).
        """
        dim_time = 1 if self.batch_first else 0
        # Keep the list and the stacked tensor in separate names so the
        # List[Tensor] annotation stays truthful.
        x_seq: List[torch.Tensor] = []
        x_step = x_0
        # unbind == split(1) followed by squeeze along the time axis.
        for u_step in u.unbind(dim_time):
            x_seq.append(x_step)  # record the state *before* the update
            x_step = x_step + self.f_xu(x_step, u_step)
        if not x_seq:
            # torch.stack rejects an empty list: build a zero-length
            # sequence with the same layout (and dtype/device) as x_0.
            return x_0.unsqueeze(dim_time).narrow(dim_time, 0, 0)
        return torch.stack(x_seq, dim_time)

    def forward(self, x_0, u, return_x=False):
        """Simulate the system and return the output (and optionally states).

        Returns ``y`` or, when ``return_x`` is True, the tuple ``(y, x)``,
        where ``x`` is the simulated state sequence and ``y = g_x(x)``
        (or ``y = x`` when no output model was given).
        """
        x = self.simulate_state(x_0, u)
        y = self.g_x(x) if self.g_x is not None else x
        if return_x:
            return y, x
        return y