# Source code for torchid.dynonet.module.static

import torch
import torch.nn as nn


class MimoStaticNonLinearity(nn.Module):
    r"""Applies a Static MIMO non-linearity.

    The non-linearity is implemented as a feed-forward neural network with a
    single hidden layer: ``Linear -> activation -> Linear``.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
        activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'

    Raises:
        KeyError: If ``activation`` is not one of the supported names.

    Shape:
        - Input: (..., in_channels)
        - Output: (..., out_channels)

    Examples::

        >>> in_channels, out_channels = 2, 4
        >>> F = MimoStaticNonLinearity(in_channels, out_channels)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, in_channels))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, out_channels)
    """

    def __init__(self, in_channels, out_channels, n_hidden=20, activation='tanh'):
        super(MimoStaticNonLinearity, self).__init__()
        activation_dict = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid}
        try:
            activation_cls = activation_dict[activation]
        except KeyError:
            # Same exception type as a plain dict lookup, but tell the caller
            # which names are actually accepted.
            raise KeyError(
                "Unsupported activation {!r}; expected one of {}".format(
                    activation, sorted(activation_dict)
                )
            ) from None
        # Layers are created in the same order as before so that a fixed RNG
        # seed yields the same initial weights.
        self.net = nn.Sequential(
            nn.Linear(in_channels, n_hidden),
            activation_cls(),
            nn.Linear(n_hidden, out_channels),
        )

    def forward(self, u_lin):
        """Apply the network to ``u_lin`` along its last dimension.

        Args:
            u_lin (torch.Tensor): Input whose last dimension has size
                ``in_channels``.

        Returns:
            torch.Tensor: Output of ``self.net``; the last dimension has size
            ``out_channels``.
        """
        y_nl = self.net(u_lin)
        return y_nl
class SisoStaticNonLinearity(MimoStaticNonLinearity):
    r"""Applies a Static SISO non-linearity.

    The non-linearity is implemented as a feed-forward neural network with
    exactly one input channel and one output channel.

    Args:
        n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
        activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'

    Shape:
        - Input: (..., 1)
        - Output: (..., 1)

    Examples::

        >>> F = SisoStaticNonLinearity(n_hidden=20)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, 1))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, 1)
    """

    def __init__(self, n_hidden=20, activation='tanh'):
        # SISO is the MIMO network with both channel counts fixed to 1.
        super(SisoStaticNonLinearity, self).__init__(
            in_channels=1,
            out_channels=1,
            n_hidden=n_hidden,
            activation=activation,
        )
class MimoChannelWiseNonLinearity(nn.Module):
    r"""Applies a Channel-wise non-linearity.

    The non-linearity is implemented as a set of feed-forward neural networks
    (each one operating on a different channel). Every per-channel network is
    ``Linear(1, n_hidden) -> ReLU -> Linear(n_hidden, 1)``.

    Args:
        channels (int): Number of both input and output channels
        n_hidden (int, optional): Number of nodes in the hidden layer of each network. Default: 10

    Shape:
        - Input: (..., channels)
        - Output: (..., channels)

    Examples::

        >>> channels = 4
        >>> F = MimoChannelWiseNonLinearity(channels)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, channels))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, channels)
    """

    def __init__(self, channels, n_hidden=10):
        super(MimoChannelWiseNonLinearity, self).__init__()
        # One independent scalar-in / scalar-out network per channel, built in
        # channel order so parameter initialisation order is unchanged.
        self.net = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 1),
            )
            for _ in range(channels)
        ])

    def forward(self, u_lin):
        """Run each channel of ``u_lin`` through its own network.

        Args:
            u_lin (torch.Tensor): Input tensor; it is split into size-1 slices
                along the last dimension, and slice ``i`` is fed to
                ``self.net[i]``.

        Returns:
            torch.Tensor: The per-channel outputs concatenated along the last
            dimension.
        """
        y_nl = [
            self.net[channel_idx](u_channel)  # process channels individually
            for channel_idx, u_channel in enumerate(u_lin.split(1, dim=-1))
        ]
        return torch.cat(y_nl, -1)  # concatenate all output channels
if __name__ == '__main__':
    # Smoke test: push a random batch through a channel-wise non-linearity.
    channels = 4
    nn1 = MimoChannelWiseNonLinearity(channels)
    in_data = torch.randn(100, 10, channels)
    net_out = nn1(in_data)