from torch import nn
from torch.nn.utils.rnn import PackedSequence


class LstmFlatteningResult(nn.LSTM):
    """An ``nn.LSTM`` whose ``forward`` returns a flat 3-tuple.

    The parent ``forward`` result is unpacked as ``(output, (hidden, cell))``
    and re-emitted as ``(output, hidden, cell)``.
    """

    def forward(self, input, *fargs, **fkwargs):
        """Run the parent LSTM forward and flatten its nested result.

        Args:
            input: Forwarded unchanged to ``nn.LSTM.forward``.
            *fargs: Extra positional arguments forwarded unchanged.
            **fkwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The tuple ``(output, hidden, cell)``.
        """
        output, (hidden, cell) = super().forward(input, *fargs, **fkwargs)
        return output, hidden, cell


class _LstmFlatteningWrapper(nn.Module):
    """Shared constructor for the wrappers that hold an inner ``nn.LSTM``.

    Subclasses only supply ``forward``; keeping the constructor in one place
    guarantees both wrappers build their ``inner_model`` identically.
    """

    def __init__(self, input_size, hidden_size, layers, bidirect, dropout, batch_first):
        """Build the wrapped LSTM.

        Args:
            input_size: Passed to ``nn.LSTM`` as ``input_size``.
            hidden_size: Passed to ``nn.LSTM`` as ``hidden_size``.
            layers: Passed to ``nn.LSTM`` as ``num_layers``.
            bidirect: Passed to ``nn.LSTM`` as ``bidirectional``.
            dropout: Passed to ``nn.LSTM`` as ``dropout``.
            batch_first: Passed to ``nn.LSTM`` as ``batch_first``; also kept
                on the wrapper as ``self.batch_first``.
        """
        super().__init__()
        self.batch_first = batch_first
        self.inner_model = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=layers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=batch_first,
        )


class LstmFlatteningResultWithSeqLength(_LstmFlatteningWrapper):
    """Wrap an ``nn.LSTM`` taking a ``PackedSequence``; return a flat tuple."""

    def forward(self, input: PackedSequence, hx=None):
        """Run the inner LSTM on a packed sequence.

        Args:
            input: The packed input sequence.
            hx: Optional initial state, forwarded unchanged to the inner LSTM.

        Returns:
            The tuple ``(output, hidden, cell)``.
        """
        # Call the module instance rather than ``.forward`` directly so that
        # hooks registered on ``inner_model`` are not silently skipped.
        output, (hidden, cell) = self.inner_model(input, hx)
        return output, hidden, cell


class LstmFlatteningResultWithoutSeqLength(_LstmFlatteningWrapper):
    """Wrap an ``nn.LSTM`` with an unannotated input; return a flat tuple."""

    def forward(self, input, hx=None):
        """Run the inner LSTM.

        Args:
            input: Forwarded unchanged to the inner LSTM.
            hx: Optional initial state, forwarded unchanged to the inner LSTM.

        Returns:
            The tuple ``(output, hidden, cell)``.
        """
        # Call the module instance rather than ``.forward`` directly so that
        # hooks registered on ``inner_model`` are not silently skipped.
        output, (hidden, cell) = self.inner_model(input, hx)
        return output, hidden, cell