# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import composite as C


class NetBatchDot(nn.Cell):
    """Cell that applies ``C.batch_dot`` to its two inputs with fixed ``axes``."""

    def __init__(self, axes):
        super(NetBatchDot, self).__init__()
        # Stored as-is and forwarded to C.batch_dot on every call; may be an
        # int, a pair of ints, or None (all three are exercised below).
        self.axes = axes

    def construct(self, x, y):
        return C.batch_dot(x, y, self.axes)


def _reference_batch_dot(x, y, axes):
    """NumPy reference for batch_dot (implementation with numpy in tensorflow).

    Args:
        x: numpy array whose first dimension is the batch dimension.
        y: numpy array with the same batch size as ``x``.
        axes: ``None``, a single int (used for both operands), or a pair of
            ints ``(axis_of_x, axis_of_y)``. Negative values count from the
            end. ``None`` selects the last axis of ``x`` and the last axis of
            ``y`` when ``y`` is 2-D, otherwise the second-to-last axis of ``y``.

    Returns:
        numpy array stacking the per-sample ``np.tensordot`` results along the
        batch dimension. A 1-D result is expanded to shape ``(batch, 1)``.
    """
    if axes is None:
        y_axis = y.ndim - 1 if y.ndim == 2 else y.ndim - 2
        axes = [x.ndim - 1, y_axis]
    elif isinstance(axes, int):
        axes = [axes, axes]
    else:
        # Work on a copy so a list passed by the caller is never mutated.
        axes = list(axes)

    # Normalise negative axes to absolute positions.
    if axes[0] < 0:
        axes[0] += x.ndim
    if axes[1] < 0:
        axes[1] += y.ndim

    # Iterating over x / y strips the batch dimension, so every axis index
    # shifts down by one for the per-sample contraction.
    sample_axes = [axes[0] - 1, axes[1] - 1]
    result = np.array(
        [np.tensordot(xi, yi, sample_axes) for xi, yi in zip(x, y)]
    )
    if result.ndim == 1:
        result = np.expand_dims(result, -1)
    return result


def _assert_batch_dot_matches_reference(x1, x2, axes, case_id):
    """Run NetBatchDot on (x1, x2) and compare against the NumPy reference.

    Args:
        x1: float32 numpy array, first operand.
        x2: float32 numpy array, second operand.
        axes: axes argument forwarded to both implementations.
        case_id: 1-based case number, used only in failure messages.
    """
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    # np.allclose broadcasts its operands, so constant-valued results (the
    # all-ones cases) would hide a wrongly shaped output; compare shapes first.
    assert ms_result_np.shape == tf_result.shape, (
        "case {}: output shape {} != reference shape {}".format(
            case_id, ms_result_np.shape, tf_result.shape))
    assert np.allclose(ms_result_np, tf_result), (
        "case {}: values differ from reference".format(case_id))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_dot_fp32():
    """Check C.batch_dot on CPU (graph mode, float32) against the reference."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    np.random.seed(12876)

    # (shape_x1, shape_x2, axes, use_ones)
    # use_ones=True fills the inputs with ones, otherwise with seeded random
    # values. Order matters: random inputs are drawn sequentially from the
    # seeded generator, x1 before x2, case after case.
    cases = (
        ((3, 12, 5, 2, 3), (3, 1, 7, 3, 2), (-1, -2), True),   # case 1
        ((4, 3, 7, 5), (4, 1, 7, 1), 2, False),                # case 2
        ((18, 3, 5, 7), (18, 1, 3, 7), -1, False),             # case 3
        ((2, 11, 3, 9), (2, 7, 9, 3), None, False),            # case 4
        ((7, 5), (7, 5), None, False),                         # case 5
        ((7, 3, 5), (7, 5), None, False),                      # case 6
        ((7, 5), (7, 5, 3), None, False),                      # case 7
        ((39, 6), (39, 6), -1, False),                         # case 8
        ((21, 2, 3), (21, 3, 2), (-1, -2), True),              # case 9
        ((4, 3, 2, 1, 7, 5), (4, 5, 7, 1), -2, True),          # case 10
    )

    for case_id, (shape_x1, shape_x2, axes, use_ones) in enumerate(cases, 1):
        if use_ones:
            x1 = np.ones(shape=shape_x1).astype(np.float32)
            x2 = np.ones(shape=shape_x2).astype(np.float32)
        else:
            x1 = np.random.random(shape_x1).astype(np.float32)
            x2 = np.random.random(shape_x2).astype(np.float32)
        _assert_batch_dot_matches_reference(x1, x2, axes, case_id)