# Source code for torchid.datasets
import torch
from torch.utils.data import Dataset, TensorDataset
class SubsequenceDataset(Dataset):
    r"""A dataset returning sub-sequences extracted from longer sequences.

    Item ``idx`` is the window ``[idx, idx + subseq_len)`` taken along the
    first dimension of every tensor. Consecutive items therefore overlap
    (start indices advance by one). Each item is a list holding one slice
    per input tensor, in the order the tensors were given.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
        subseq_len (int): length of each extracted sub-sequence; must be >= 1.

    Raises:
        ValueError: if no tensors are given, if the tensors differ in the size
            of their first dimension, or if ``subseq_len`` is smaller than 1.

    Examples:
        >>> u = torch.randn(1000, 2)  # 2 inputs
        >>> y = torch.randn(1000, 3)  # 3 outputs
        >>> train_dataset = SubsequenceDataset(u, y, subseq_len=100)
        >>> len(train_dataset)
        901
        >>> u_sub, y_sub = train_dataset[0]
        >>> u_sub.shape, y_sub.shape
        (torch.Size([100, 2]), torch.Size([100, 3]))
    """

    def __init__(self, *tensors, subseq_len):
        if not tensors:
            raise ValueError("SubsequenceDataset requires at least one tensor")
        length = tensors[0].shape[0]
        # Windows are cut at the same indices from every tensor, so they must
        # all share the size of the first dimension.
        if any(tensor.shape[0] != length for tensor in tensors):
            raise ValueError("Size mismatch between tensors")
        if subseq_len < 1:
            raise ValueError(f"subseq_len must be >= 1, got {subseq_len}")
        self.tensors = tensors
        self.subseq_len = subseq_len
        self.length = length  # size of the first dimension shared by all tensors

    def __len__(self):
        """Return the number of full-length windows (0 if sequences are too short)."""
        # Clamp at 0: len() raises ValueError on a negative __len__, which
        # would happen whenever the sequences are shorter than subseq_len.
        return max(0, self.length - self.subseq_len + 1)

    def __getitem__(self, idx):
        """Return the list of windows ``tensor[idx:idx + subseq_len]``, one per tensor.

        Negative ``idx`` counts from the last window. Raises IndexError when
        ``idx`` is out of range.
        """
        n_windows = len(self)
        position = idx + n_windows if idx < 0 else idx
        if not 0 <= position < n_windows:
            # Tensor slicing never raises on its own. Without this check,
            # out-of-range indices would silently return truncated windows,
            # and plain iteration over the dataset (which stops on
            # IndexError) would never terminate.
            raise IndexError(
                f"index {idx} out of range for dataset with {n_windows} subsequences"
            )
        return [tensor[position:position + self.subseq_len] for tensor in self.tensors]