• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.common import dtype as mstype
23from mindspore.ops import operations as P
24
25
class BatchMatMulNet(nn.Cell):
    """Minimal Cell that runs a single ``P.BatchMatMul`` on its two inputs.

    Args:
        transpose_a (bool): Forwarded to ``P.BatchMatMul`` as ``transpose_a``.
            Defaults to False.
        transpose_b (bool): Forwarded to ``P.BatchMatMul`` as ``transpose_b``.
            Defaults to False.
    """

    def __init__(self, transpose_a=False, transpose_b=False):
        super().__init__()
        self.batch_matmul = P.BatchMatMul(transpose_a=transpose_a,
                                          transpose_b=transpose_b)

    def construct(self, x, y):
        # The forward pass is exactly the wrapped operator applied to (x, y).
        product = self.batch_matmul(x, y)
        return product
33
34
def judge_result_correct(result, expect, rtol=1e-5, atol=1e-8):
    """Assert that ``result`` matches ``expect`` in dtype, shape and values.

    Args:
        result (numpy.ndarray): Array produced by the operator under test.
        expect (numpy.ndarray): Reference array.
        rtol (float): Relative tolerance of the element-wise comparison.
            The default matches ``np.allclose``.
        atol (float): Absolute tolerance of the element-wise comparison.
            The default matches ``np.allclose``.

    Raises:
        AssertionError: If the dtypes or shapes differ, or if any element
            differs by more than ``atol + rtol * abs(expect)``.
    """
    assert result.dtype == expect.dtype, (
        f"dtype mismatch: got {result.dtype}, expected {expect.dtype}")
    assert result.shape == expect.shape, (
        f"shape mismatch: got {result.shape}, expected {expect.shape}")
    # Unlike a bare ``assert np.allclose(...)``, assert_allclose reports which
    # elements mismatch and by how much. equal_nan=False keeps np.allclose's
    # behaviour of treating NaN as never close to anything.
    np.testing.assert_allclose(result, expect, rtol=rtol, atol=atol,
                               equal_nan=False)
39
40
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_no_transpose_vec():
    """BatchMatMul of a (2, 4, 1, 3) operand by a (2, 4, 3, 4) operand.

    Each of the 2 x 4 batch entries multiplies a 1x3 row by a 3x4 matrix,
    with no transposition, on the CPU backend in graph mode.
    """
    reference = np.array(
        [[[[20, 23, 26, 29]],
          [[200, 212, 224, 236]],
          [[596, 617, 638, 659]],
          [[1208, 1238, 1268, 1298]]],
         [[[2036, 2075, 2114, 2153]],
          [[3080, 3128, 3176, 3224]],
          [[4340, 4397, 4454, 4511]],
          [[5816, 5882, 5948, 6014]]]],
        dtype=np.float32,
    )
    lhs = Tensor(np.arange(24).reshape(2, 4, 1, 3), mstype.float32)
    rhs = Tensor(np.arange(96).reshape(2, 4, 3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    actual = BatchMatMulNet()(lhs, rhs).asnumpy()
    judge_result_correct(actual, reference)
60
61
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_no_transpose():
    """BatchMatMul of a (2, 3, 2, 3) operand by a (2, 3, 3, 4) operand.

    No transposition; every batch entry is a 2x3 by 3x4 product, run on the
    CPU backend in graph mode.
    """
    reference = np.array(
        [[[[20, 23, 26, 29],
           [56, 68, 80, 92]],
          [[344, 365, 386, 407],
           [488, 518, 548, 578]],
          [[1100, 1139, 1178, 1217],
           [1352, 1400, 1448, 1496]]],
         [[[2288, 2345, 2402, 2459],
           [2648, 2714, 2780, 2846]],
          [[3908, 3983, 4058, 4133],
           [4376, 4460, 4544, 4628]],
          [[5960, 6053, 6146, 6239],
           [6536, 6638, 6740, 6842]]]],
        dtype=np.float32,
    )
    lhs = Tensor(np.arange(36).reshape(2, 3, 2, 3), mstype.float32)
    rhs = Tensor(np.arange(72).reshape(2, 3, 3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    actual = BatchMatMulNet()(lhs, rhs).asnumpy()
    judge_result_correct(actual, reference)
85
86
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_a():
    """BatchMatMul with ``transpose_a=True``.

    The left operand is (2, 3, 3, 2) and the right operand is (2, 3, 3, 4).
    The expected result is (2, 3, 2, 4). Runs on the CPU backend in graph mode.
    """
    reference = np.array(
        [[[[40, 46, 52, 58],
           [52, 61, 70, 79]],
          [[400, 424, 448, 472],
           [448, 475, 502, 529]],
          [[1192, 1234, 1276, 1318],
           [1276, 1321, 1366, 1411]]],
         [[[2416, 2476, 2536, 2596],
           [2536, 2599, 2662, 2725]],
          [[4072, 4150, 4228, 4306],
           [4228, 4309, 4390, 4471]],
          [[6160, 6256, 6352, 6448],
           [6352, 6451, 6550, 6649]]]],
        dtype=np.float32,
    )
    lhs = Tensor(np.arange(36).reshape(2, 3, 3, 2), mstype.float32)
    rhs = Tensor(np.arange(72).reshape(2, 3, 3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    actual = BatchMatMulNet(transpose_a=True)(lhs, rhs).asnumpy()
    judge_result_correct(actual, reference)
110
111
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_b():
    """BatchMatMul with ``transpose_b=True``.

    The left operand is (2, 3, 2, 3) and the right operand is (2, 3, 4, 3).
    The expected result is (2, 3, 2, 4). Runs on the CPU backend in graph mode.
    """
    reference = np.array(
        [[[[5, 14, 23, 32],
           [14, 50, 86, 122]],
          [[275, 338, 401, 464],
           [392, 482, 572, 662]],
          [[977, 1094, 1211, 1328],
           [1202, 1346, 1490, 1634]]],
         [[[2111, 2282, 2453, 2624],
           [2444, 2642, 2840, 3038]],
          [[3677, 3902, 4127, 4352],
           [4118, 4370, 4622, 4874]],
          [[5675, 5954, 6233, 6512],
           [6224, 6530, 6836, 7142]]]],
        dtype=np.float32,
    )
    lhs = Tensor(np.arange(36).reshape(2, 3, 2, 3), mstype.float32)
    rhs = Tensor(np.arange(72).reshape(2, 3, 4, 3), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    actual = BatchMatMulNet(transpose_b=True)(lhs, rhs).asnumpy()
    judge_result_correct(actual, reference)
135
136
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_ab():
    """BatchMatMul with both ``transpose_a`` and ``transpose_b`` set, in float16.

    The left operand is (2, 3, 3, 2) and the right operand is (2, 3, 4, 3).
    The expected result is (2, 3, 2, 4). The reference is built directly in
    float16, so values such as 6934 take their float16-rounded form.
    """
    reference = np.array(
        [[[[10., 28., 46., 64.],
           [13., 40., 67., 94.]],
          [[316., 388., 460., 532.],
           [355., 436., 517., 598.]],
          [[1054., 1180., 1306., 1432.],
           [1129., 1264., 1399., 1534.]]],
         [[[2224., 2404., 2584., 2764.],
           [2335., 2524., 2713., 2902.]],
          [[3826., 4060., 4294., 4528.],
           [3973., 4216., 4459., 4702.]],
          [[5860., 6148., 6436., 6724.],
           [6043., 6340., 6637., 6934.]]]],
        dtype=np.float16,
    )
    lhs = Tensor(np.arange(36).reshape(2, 3, 3, 2), mstype.float16)
    rhs = Tensor(np.arange(72).reshape(2, 3, 4, 3), mstype.float16)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    actual = BatchMatMulNet(transpose_a=True, transpose_b=True)(lhs, rhs).asnumpy()
    judge_result_correct(actual, reference)
160