"""Tests for ``nn.MatMul`` covering broadcasting, 1-D operands and transpose flags.

Each test wraps ``nn.MatMul`` in ``Net`` and runs it on seeded random float32
inputs. It checks the output shape and compares the values against a NumPy
reference.
"""
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Maximum allowed max|output - reference|, expressed as a fraction of
# max(1, max|reference|).
# NOTE(review): deliberately loose. It assumes the device kernel may accumulate
# in reduced precision (e.g. float16 on Ascend) -- confirm for the target
# backend. It should still be far below the error a wrong transpose or
# broadcast would cause.
_VALUE_TOLERANCE = 1e-2


class Net(nn.Cell):
    """Minimal cell that applies ``nn.MatMul`` with fixed transpose flags."""

    def __init__(self, transpose_x1, transpose_x2):
        super(Net, self).__init__()
        self.matmul = nn.MatMul(transpose_x1, transpose_x2)

    def construct(self, x1, x2):
        return self.matmul(x1, x2)


def _reference_matmul(x1, x2, transpose_x1, transpose_x2):
    """Return the NumPy result that ``Net`` is expected to match.

    A transpose flag swaps the last two axes of its operand. For 1-D operands
    the flag is skipped: the expected shapes in the 1-D transpose tests below
    only hold if the flag has no effect on a vector.
    """
    if transpose_x1 and x1.ndim > 1:
        x1 = np.swapaxes(x1, -1, -2)
    if transpose_x2 and x2.ndim > 1:
        x2 = np.swapaxes(x2, -1, -2)
    return np.matmul(x1, x2)


def _check_matmul(x1_shape, x2_shape, transpose_x1, transpose_x2, expected_shape, seed=0):
    """Run ``Net`` on seeded random inputs and check its shape and values.

    Args:
        x1_shape: shape of the first random float32 operand.
        x2_shape: shape of the second random float32 operand.
        transpose_x1: flag forwarded to ``nn.MatMul`` for the first operand.
        transpose_x2: flag forwarded to ``nn.MatMul`` for the second operand.
        expected_shape: output shape the test requires.
        seed: seed for the input generator, so a failure can be reproduced.
    """
    rng = np.random.RandomState(seed)
    x1 = rng.randn(*x1_shape).astype(np.float32)
    x2 = rng.randn(*x2_shape).astype(np.float32)

    net = Net(transpose_x1, transpose_x2)
    output = net(Tensor(x1), Tensor(x2))
    assert output.shape == expected_shape

    expected = _reference_matmul(x1, x2, transpose_x1, transpose_x2)
    # Guards the test itself: the hand-written expected shape must agree with NumPy.
    assert expected.shape == expected_shape

    # Scale the tolerance by the result's magnitude so that long reductions,
    # which give large outputs, are not rejected for rounding noise.
    scale = max(1.0, float(np.abs(expected).max()))
    max_err = float(np.abs(output.asnumpy() - expected).max())
    limit = _VALUE_TOLERANCE * scale
    assert max_err <= limit, "max |output - reference| = %g exceeds %g" % (max_err, limit)


def test_x1_2D_x2_3D():
    """A 2-D x1 broadcasts against the batch axis of a 3-D x2."""
    _check_matmul((16, 64), (32, 64, 20),
                  transpose_x1=False, transpose_x2=False,
                  expected_shape=(32, 16, 20))


def test_x1_4D_x2_3D_transpose_x2_True():
    """Batch dims (3, 2) and (1,) broadcast; x2 is transposed to (1, 4, 5)."""
    _check_matmul((3, 2, 3, 4), (1, 5, 4),
                  transpose_x1=False, transpose_x2=True,
                  expected_shape=(3, 2, 3, 5))


def test_x1_3D_transpose_x1_True_x2_2D():
    """x1 is transposed to (2, 4, 3) and multiplied by a 2-D x2."""
    _check_matmul((2, 3, 4), (3, 4),
                  transpose_x1=True, transpose_x2=False,
                  expected_shape=(2, 4, 4))


def test_x1_3D_transpose_x1_True_x2_3D_transpose_x2_True():
    """Both operands are transposed: (2, 6, 5) @ (2, 5, 4)."""
    _check_matmul((2, 5, 6), (2, 4, 5),
                  transpose_x1=True, transpose_x2=True,
                  expected_shape=(2, 6, 4))


def test_x1_1D_x2_1D():
    """Two vectors reduce to a 0-d dot product."""
    _check_matmul((4,), (4,),
                  transpose_x1=False, transpose_x2=False,
                  expected_shape=())


def test_x1_1D_x2_3D():
    """A 1-D x1 contracts with the second-to-last axis of x2."""
    _check_matmul((4,), (2, 4, 5),
                  transpose_x1=False, transpose_x2=False,
                  expected_shape=(2, 5))


def test_x1_3D_x2_1D():
    """A 1-D x2 contracts with the last axis of x1."""
    _check_matmul((2, 4, 5), (5,),
                  transpose_x1=False, transpose_x2=False,
                  expected_shape=(2, 4))


def test_x1_1D_transpose_x1_True_x2_3D():
    """transpose_x1 on a 1-D x1 gives the same shape as no transpose."""
    _check_matmul((4,), (2, 4, 5),
                  transpose_x1=True, transpose_x2=False,
                  expected_shape=(2, 5))


def test_x1_3D_x2_1D_transpose_x2_True():
    """transpose_x2 on a 1-D x2 gives the same shape as no transpose."""
    _check_matmul((2, 4, 5), (5,),
                  transpose_x1=False, transpose_x2=True,
                  expected_shape=(2, 4))