# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU tests for the BatchMatMul operator on 4-D inputs."""

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class BatchMatMulNet(nn.Cell):
    """Minimal cell that applies ``P.BatchMatMul`` to its two inputs."""

    def __init__(self, transpose_a=False, transpose_b=False):
        """Create the operator with the given transpose flags."""
        super().__init__()
        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)

    def construct(self, x, y):
        """Return the batched matrix product of ``x`` and ``y``."""
        return self.batch_matmul(x, y)


def judge_result_correct(result, expect):
    """Assert that ``result`` matches ``expect`` in dtype, shape and values.

    Values are compared with ``np.allclose`` using its default tolerances.
    """
    assert result.dtype == expect.dtype
    assert result.shape == expect.shape
    assert np.allclose(result, expect)


def _arange_tensor(shape, dtype):
    """Return a Tensor of ``shape`` holding 0, 1, 2, ... in row-major order."""
    values = np.arange(int(np.prod(shape))).reshape(shape)
    return Tensor(values, dtype)


def _run_batch_matmul(input_x, input_y, transpose_a=False, transpose_b=False):
    """Run ``BatchMatMulNet`` in graph mode on CPU and return a numpy array."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = BatchMatMulNet(transpose_a, transpose_b)
    return net(input_x, input_y).asnumpy()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_no_transpose_vec():
    """Multiply (2, 4, 1, 3) by (2, 4, 3, 4) float32 inputs, no transpose."""
    input_x = _arange_tensor((2, 4, 1, 3), mstype.float32)
    input_y = _arange_tensor((2, 4, 3, 4), mstype.float32)

    output = _run_batch_matmul(input_x, input_y)
    expect = np.array([[[[20, 23, 26, 29]],
                        [[200, 212, 224, 236]],
                        [[596, 617, 638, 659]],
                        [[1208, 1238, 1268, 1298]]],
                       [[[2036, 2075, 2114, 2153]],
                        [[3080, 3128, 3176, 3224]],
                        [[4340, 4397, 4454, 4511]],
                        [[5816, 5882, 5948, 6014]]]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_no_transpose():
    """Multiply (2, 3, 2, 3) by (2, 3, 3, 4) float32 inputs, no transpose."""
    input_x = _arange_tensor((2, 3, 2, 3), mstype.float32)
    input_y = _arange_tensor((2, 3, 3, 4), mstype.float32)

    output = _run_batch_matmul(input_x, input_y)
    expect = np.array([[[[20., 23., 26., 29.],
                         [56., 68., 80., 92.]],
                        [[344., 365., 386., 407.],
                         [488., 518., 548., 578.]],
                        [[1100., 1139., 1178., 1217.],
                         [1352., 1400., 1448., 1496.]]],
                       [[[2288., 2345., 2402., 2459.],
                         [2648., 2714., 2780., 2846.]],
                        [[3908., 3983., 4058., 4133.],
                         [4376., 4460., 4544., 4628.]],
                        [[5960., 6053., 6146., 6239.],
                         [6536., 6638., 6740., 6842.]]]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_a():
    """Multiply (2, 3, 3, 2) by (2, 3, 3, 4) float32 inputs, transposing x."""
    input_x = _arange_tensor((2, 3, 3, 2), mstype.float32)
    input_y = _arange_tensor((2, 3, 3, 4), mstype.float32)

    output = _run_batch_matmul(input_x, input_y, transpose_a=True)
    expect = np.array([[[[40., 46., 52., 58.],
                         [52., 61., 70., 79.]],
                        [[400., 424., 448., 472.],
                         [448., 475., 502., 529.]],
                        [[1192., 1234., 1276., 1318.],
                         [1276., 1321., 1366., 1411.]]],
                       [[[2416., 2476., 2536., 2596.],
                         [2536., 2599., 2662., 2725.]],
                        [[4072., 4150., 4228., 4306.],
                         [4228., 4309., 4390., 4471.]],
                        [[6160., 6256., 6352., 6448.],
                         [6352., 6451., 6550., 6649.]]]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_b():
    """Multiply (2, 3, 2, 3) by (2, 3, 4, 3) float32 inputs, transposing y."""
    input_x = _arange_tensor((2, 3, 2, 3), mstype.float32)
    input_y = _arange_tensor((2, 3, 4, 3), mstype.float32)

    output = _run_batch_matmul(input_x, input_y, transpose_b=True)
    expect = np.array([[[[5.000e+00, 1.400e+01, 2.300e+01, 3.200e+01],
                         [1.400e+01, 5.000e+01, 8.600e+01, 1.220e+02]],
                        [[2.750e+02, 3.380e+02, 4.010e+02, 4.640e+02],
                         [3.920e+02, 4.820e+02, 5.720e+02, 6.620e+02]],
                        [[9.770e+02, 1.094e+03, 1.211e+03, 1.328e+03],
                         [1.202e+03, 1.346e+03, 1.490e+03, 1.634e+03]]],
                       [[[2.111e+03, 2.282e+03, 2.453e+03, 2.624e+03],
                         [2.444e+03, 2.642e+03, 2.840e+03, 3.038e+03]],
                        [[3.677e+03, 3.902e+03, 4.127e+03, 4.352e+03],
                         [4.118e+03, 4.370e+03, 4.622e+03, 4.874e+03]],
                        [[5.675e+03, 5.954e+03, 6.233e+03, 6.512e+03],
                         [6.224e+03, 6.530e+03, 6.836e+03, 7.142e+03]]]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_ab():
    """Multiply (2, 3, 3, 2) by (2, 3, 4, 3) float16 inputs, transposing both."""
    input_x = _arange_tensor((2, 3, 3, 2), mstype.float16)
    input_y = _arange_tensor((2, 3, 4, 3), mstype.float16)

    output = _run_batch_matmul(input_x, input_y, transpose_a=True, transpose_b=True)
    # The literals are the exact products; converting to float16 rounds the
    # entries that half precision cannot hold exactly (integers above 2048),
    # so the comparison is made against those rounded values.
    expect = np.array([[[[10., 28., 46., 64.],
                         [13., 40., 67., 94.]],
                        [[316., 388., 460., 532.],
                         [355., 436., 517., 598.]],
                        [[1054., 1180., 1306., 1432.],
                         [1129., 1264., 1399., 1534.]]],
                       [[[2224., 2404., 2584., 2764.],
                         [2335., 2524., 2713., 2902.]],
                        [[3826., 4060., 4294., 4528.],
                         [3973., 4216., 4459., 4702.]],
                        [[5860., 6148., 6436., 6724.],
                         [6043., 6340., 6637., 6934.]]]], dtype=np.float16)
    judge_result_correct(output, expect)