Home
last modified time | relevance | path

Searched refs:rnn_type (Results 1 – 5 of 5) sorted by relevance

/external/pytorch/test/onnx/model_defs/
Dword_language_model.py16 rnn_type, argument
28 if rnn_type in ["LSTM", "GRU"]:
29 self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
32 nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
58 self.rnn_type = rnn_type
89 if self.rnn_type == "LSTM":
/external/pytorch/benchmarks/functional_autograd_benchmark/
Dtorchaudio_models.py185 rnn_type=nn.LSTM, argument
196 self.rnn = rnn_type(
264 rnn_type, argument
276 self.rnn_type = rnn_type
305 rnn_type=rnn_type,
314 rnn_type=rnn_type,
Daudio_text_models.py57 rnn_type=nn.LSTM,
/external/pytorch/test/quantization/core/
Dtest_quantized_module.py1843 for rnn_type in cell_dict.keys():
1847 if rnn_type == 'RNNReLU':
1849 elif rnn_type == 'RNNTanh':
1852 cell_dq = cell_dict[rnn_type](**kwargs)
1853 result = qfn_dict[rnn_type](x, state[rnn_type],
1856 result_module = cell_dq(x, state[rnn_type])
1861 self.check_eager_serialization(cell_dq, cell_dict[rnn_type](**kwargs), [x])
1906 for rnn_type in cell_dict.keys():
1908 if rnn_type == 'RNNReLU':
1910 elif rnn_type == 'RNNTanh':
[all …]
Dtest_quantized_op.py3455 …t_rnn_weights_and_bias(self, input_size, hidden_size, num_directions, per_channel_quant, rnn_type): argument
3457 hidden_mult = hidden_mult_map[rnn_type]
3488 for rnn_type in ['LSTM', 'GRU']:
3504 rnn_type)
3520 if rnn_type == 'LSTM':
3566 if rnn_type == 'GRU':
3622 for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
3635 input_size, hidden_size, 1, per_channel_quant, rnn_type)
3663 … result_ref = fn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], W_ref1, W_ref2, b1, b2)
3664 …result_dynamic = qfn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], packed_ih, packed_hh, b1,…