• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test lstm """
16import pytest
17
18import mindspore.context as context
19from mindspore import nn
20from ..ut_filter import run_on_gpu
21from ....ops_common import convert
22
23
class LstmTestNet(nn.Cell):
    """Small test cell that wraps a single ``nn.LSTM``.

    ``construct`` takes the initial hidden state and cell state as two
    separate positional inputs and packs them into the ``(h0, c0)`` tuple
    that ``nn.LSTM`` is called with.
    """

    def __init__(self, input_size, hidden_size, num_layers, has_bias, batch_first, bidirectional):
        super().__init__()
        # Every configuration knob is forwarded unchanged; dropout is
        # explicitly fixed at 0.0 for all test cases.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            has_bias=has_bias,
            batch_first=batch_first,
            bidirectional=bidirectional,
            dropout=0.0,
        )

    def construct(self, inp, h0, c0):
        """Apply the wrapped LSTM to ``inp`` starting from ``(h0, c0)``."""
        state = (h0, c0)
        return self.lstm(inp, state)
39
40
def _lstm_case(case_id, *, has_bias, batch_first, bidirectional, input_shape, output_shape):
    """Return one ``(case_id, config)`` entry for ``test_case_cell_ops``.

    All cases share input_size=10, hidden_size=12 and num_layers=2; only the
    three boolean flags and the hand-written shape tables vary per case.
    ``input_shape`` lists the shapes fed to ``(inp, h0, c0)`` and
    ``output_shape`` the expected shapes of ``(output, h_n, c_n)``.
    """
    cell = LstmTestNet(10, 12, 2, has_bias=has_bias, batch_first=batch_first,
                       bidirectional=bidirectional)
    return case_id, dict(cell=cell, input_shape=input_shape, output_shape=output_shape)


test_case_cell_ops = [
    _lstm_case('lstm1_with_bias',
               has_bias=True, batch_first=False, bidirectional=False,
               input_shape=[[5, 3, 10], [2, 3, 12], [2, 3, 12]],
               output_shape=[[5, 3, 12], [2, 3, 12], [2, 3, 12]]),
    _lstm_case('lstm2_without_bias',
               has_bias=False, batch_first=False, bidirectional=False,
               input_shape=[[5, 3, 10], [2, 3, 12], [2, 3, 12]],
               output_shape=[[5, 3, 12], [2, 3, 12], [2, 3, 12]]),
    _lstm_case('lstm3_with_bias_bidirectional',
               has_bias=True, batch_first=False, bidirectional=True,
               input_shape=[[5, 3, 10], [4, 3, 12], [4, 3, 12]],
               output_shape=[[5, 3, 24], [4, 3, 12], [4, 3, 12]]),
    _lstm_case('lstm4_without_bias_bidirectional',
               has_bias=False, batch_first=False, bidirectional=True,
               input_shape=[[5, 3, 10], [4, 3, 12], [4, 3, 12]],
               output_shape=[[5, 3, 24], [4, 3, 12], [4, 3, 12]]),
    _lstm_case('lstm5_with_bias_batch_first',
               has_bias=True, batch_first=True, bidirectional=False,
               input_shape=[[3, 5, 10], [2, 3, 12], [2, 3, 12]],
               output_shape=[[3, 5, 12], [2, 3, 12], [2, 3, 12]]),
    _lstm_case('lstm6_without_bias_batch_first',
               has_bias=False, batch_first=True, bidirectional=False,
               input_shape=[[3, 5, 10], [2, 3, 12], [2, 3, 12]],
               output_shape=[[3, 5, 12], [2, 3, 12], [2, 3, 12]]),
    _lstm_case('lstm7_with_bias_bidirectional_batch_first',
               has_bias=True, batch_first=True, bidirectional=True,
               input_shape=[[3, 5, 10], [4, 3, 12], [4, 3, 12]],
               output_shape=[[3, 5, 24], [4, 3, 12], [4, 3, 12]]),
    _lstm_case('lstm8_without_bias_bidirectional_batch_first',
               has_bias=False, batch_first=True, bidirectional=True,
               input_shape=[[3, 5, 10], [4, 3, 12], [4, 3, 12]],
               output_shape=[[3, 5, 24], [4, 3, 12], [4, 3, 12]]),
]
75
76
# Use -k to select specific test cases by their id, e.g.:
# pytest tests/python/ops/test_lstm.py::test_compile -k lstm1_with_bias
79
@pytest.mark.parametrize('args', test_case_cell_ops, ids=lambda case: case[0])
def test_compile(args):
    """Put each LSTM case in training mode and call it once.

    The inputs are built by ``convert`` from the case's ``input_shape``
    list; the result is only printed, not checked.
    """
    _, config = args
    net = config['cell']
    net.set_train()
    out = net(*map(convert, config['input_shape']))
    print(f"out: {out}")
89
90
@run_on_gpu
@pytest.mark.parametrize('args', test_case_cell_ops, ids=lambda x: x[0])
def test_execute(args):
    """Run each LSTM case in GRAPH_MODE on GPU and verify the result shapes.

    The case's ``output_shape`` entry lists the expected shapes of the
    three values the cell returns, ``(output, (h_n, c_n))``. Each one is
    asserted, so a shape mismatch fails the test instead of only being
    printed.
    """
    config = args[1]
    shapes = config['input_shape']
    expected_out, expected_hn, expected_cn = (tuple(shape) for shape in config['output_shape'])
    net = config['cell']
    net.set_train()
    inputs = [convert(shp) for shp in shapes]
    # NOTE(review): this mutates the process-wide MindSpore context and does
    # not restore it afterwards; later tests in the same process inherit it.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ret, (hn, cn) = net(*inputs)
    out_shape = ret.asnumpy().shape
    print(f'result: {shapes[0]} --> {out_shape}, expected: {config["output_shape"][0]}')
    # ndarray.shape is a tuple, so compare against tuple-converted expectations.
    assert out_shape == expected_out
    assert hn.asnumpy().shape == expected_hn
    assert cn.asnumpy().shape == expected_cn
104