# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P


class Net(Cell):
    """Cell that applies `P.MatMul` to its two inputs without transposing either."""

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul(transpose_a=False, transpose_b=False)

    def construct(self, x, y):
        """Return `MatMul(x, y)`."""
        return self.matmul(x, y)


class Net1(Cell):
    """Cell that applies `P.BatchMatMul` to its two inputs without transposing either."""

    def __init__(self):
        super(Net1, self).__init__()
        self.bmm = P.BatchMatMul(transpose_a=False, transpose_b=False)

    def construct(self, x, y):
        """Return `BatchMatMul(x, y)`."""
        return self.bmm(x, y)


def get_output(i0, i1, net_cls, enable_graph_kernel=False):
    """Run a freshly constructed `net_cls` on `(i0, i1)` and return its output.

    Args:
        i0: First input passed to the network.
        i1: Second input passed to the network.
        net_cls: Zero-argument callable that builds the network (e.g. `Net`).
        enable_graph_kernel: Value written to the `enable_graph_kernel`
            context option before the network is built and executed.

    Returns:
        Whatever the network returns for `(i0, i1)`.

    Note:
        The context option is not restored afterwards, so the value set here
        stays in effect until something sets it again.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = net_cls()
    output = net(i0, i1)
    return output


def test_matmul():
    """Check that MatMul gives matching results with graph kernel off and on.

    NOTE(review): the name starts with `test_`, so pytest's default collection
    may also run this function directly, without any of the device-specific
    wrappers below having set the mode/device target -- confirm that is intended.
    """
    i0 = Tensor(np.random.normal(1, 0.01, [96, 1]).astype(np.float32))
    i1 = Tensor(np.random.normal(1, 0.01, [1, 128]).astype(np.float32))
    # Reference run first (graph kernel disabled), then the run under test.
    expect = get_output(i0, i1, Net, False)
    output = get_output(i0, i1, Net, True)
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, rtol=1.e-4, atol=1.e-7)


def test_batchmatmul():
    """Check that BatchMatMul gives matching results with graph kernel off and on.

    NOTE(review): the name starts with `test_`, so pytest's default collection
    may also run this function directly, without any of the device-specific
    wrappers below having set the mode/device target -- confirm that is intended.
    """
    i0 = Tensor(np.random.normal(1, 0.01, [16, 96, 1]).astype(np.float32))
    i1 = Tensor(np.random.normal(1, 0.01, [16, 1, 128]).astype(np.float32))
    # Reference run first (graph kernel disabled), then the run under test.
    expect = get_output(i0, i1, Net1, False)
    output = get_output(i0, i1, Net1, True)
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, rtol=6.e-4, atol=6.e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_matmul_ascend():
    """Run the MatMul comparison in graph mode with the Ascend device target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_matmul()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_batchmatmul_ascend():
    """Run the BatchMatMul comparison in graph mode with the Ascend device target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_batchmatmul()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_matmul_gpu():
    """Run the MatMul comparison in graph mode with the GPU device target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_matmul()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchmatmul_gpu():
    """Run the BatchMatMul comparison in graph mode with the GPU device target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_batchmatmul()