# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU graph-mode tests for ``P.MatMul`` over all transpose_a/transpose_b combinations."""

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class MatMulNet(nn.Cell):
    """Minimal cell wrapping ``P.MatMul`` so it can be executed as a network.

    Args:
        transpose_a: forwarded to ``P.MatMul``; when True the first operand is
            transposed before multiplication (see ``test_matmul_transpose_a``).
        transpose_b: forwarded to ``P.MatMul``; when True the second operand is
            transposed before multiplication (see ``test_matmul_transpose_b``).
    """

    def __init__(self, transpose_a=False, transpose_b=False):
        super().__init__()
        self.matmul = P.MatMul(transpose_a, transpose_b)

    def construct(self, x, y):
        return self.matmul(x, y)


def judge_result_correct(result, expect, *, rtol=1e-05, atol=1e-08):
    """Assert that ``result`` equals ``expect`` in dtype, shape and (approximately) value.

    Args:
        result: numpy array produced by the operator under test.
        expect: numpy array holding the reference values.
        rtol: relative tolerance forwarded to ``np.allclose`` (its default).
        atol: absolute tolerance forwarded to ``np.allclose`` (its default).

    Raises:
        AssertionError: on any dtype, shape or value mismatch, with a message
            showing both sides.
    """
    assert result.dtype == expect.dtype, \
        f"dtype mismatch: got {result.dtype}, expected {expect.dtype}"
    assert result.shape == expect.shape, \
        f"shape mismatch: got {result.shape}, expected {expect.shape}"
    assert np.allclose(result, expect, rtol=rtol, atol=atol), \
        f"value mismatch:\ngot\n{result}\nexpected\n{expect}"


def _run_matmul(x_shape, y_shape, dtype, transpose_a=False, transpose_b=False):
    """Run ``MatMulNet`` on CPU over ``np.arange``-filled inputs and return a numpy array.

    Args:
        x_shape: shape of the first operand, filled with 0, 1, 2, ... in row-major order.
        y_shape: shape of the second operand, filled the same way.
        dtype: MindSpore dtype both input tensors are created with.
        transpose_a: forwarded to ``MatMulNet``.
        transpose_b: forwarded to ``MatMulNet``.
    """
    input_x = Tensor(np.arange(int(np.prod(x_shape))).reshape(x_shape), dtype)
    input_y = Tensor(np.arange(int(np.prod(y_shape))).reshape(y_shape), dtype)

    # Same order as the original per-test code: tensors first, then context.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = MatMulNet(transpose_a=transpose_a, transpose_b=transpose_b)
    return net(input_x, input_y).asnumpy()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_no_transpose_vec():
    """(1, 3) @ (3, 5) with a single-row left operand, float32."""
    output = _run_matmul((1, 3), (3, 5), mstype.float32)
    expect = np.array([[25., 28., 31., 34., 37.]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_no_transpose():
    """(4, 3) @ (3, 5) without transposition, float32."""
    output = _run_matmul((4, 3), (3, 5), mstype.float32)
    expect = np.array([[25., 28., 31., 34., 37.],
                       [70., 82., 94., 106., 118.],
                       [115., 136., 157., 178., 199.],
                       [160., 190., 220., 250., 280.]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_a():
    """(3, 2)^T @ (3, 4), float32."""
    output = _run_matmul((3, 2), (3, 4), mstype.float32, transpose_a=True)
    expect = np.array([[40., 46., 52., 58.],
                       [52., 61., 70., 79.]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_b():
    """(2, 3) @ (5, 3)^T, float32."""
    output = _run_matmul((2, 3), (5, 3), mstype.float32, transpose_b=True)
    expect = np.array([[5., 14., 23., 32., 41.],
                       [14., 50., 86., 122., 158.]], dtype=np.float32)
    judge_result_correct(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_ab():
    """(3, 5)^T @ (4, 3)^T, float16."""
    output = _run_matmul((3, 5), (4, 3), mstype.float16,
                         transpose_a=True, transpose_b=True)
    expect = np.array([[25., 70., 115., 160.],
                       [28., 82., 136., 190.],
                       [31., 94., 157., 220.],
                       [34., 106., 178., 250.],
                       [37., 118., 199., 280.]], dtype=np.float16)
    judge_result_correct(output, expect)