# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import composite as C


class NetBatchDot(nn.Cell):
    """Cell wrapper around C.batch_dot with fixed contraction axes."""

    def __init__(self, axes):
        super(NetBatchDot, self).__init__()
        self.axes = axes

    def construct(self, x, y):
        return C.batch_dot(x, y, self.axes)


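# Illustrative usage of the cell above (not part of the original tests; the shapes here
# are hypothetical). With axes=(-1, -2), batch_dot contracts the last axis of x with the
# second-to-last axis of y for every batch element:
#
#   net = NetBatchDot(axes=(-1, -2))
#   out = net(Tensor(np.ones((2, 3, 4), np.float32)),
#             Tensor(np.ones((2, 4, 5), np.float32)))
#   # out.shape == (2, 3, 5), and every entry equals 4.0
#
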
# Reference implementation of batch_dot in NumPy, following TensorFlow/Keras semantics
def _reference_batch_dot(x, y, axes):
    if isinstance(axes, int):
        axes = [axes, axes]
    elif isinstance(axes, tuple):
        axes = list(axes)
    if axes is None:
        if y.ndim == 2:
            axes = [x.ndim - 1, y.ndim - 1]
        else:
            axes = [x.ndim - 1, y.ndim - 2]
    if axes[0] < 0:
        axes[0] += x.ndim
    if axes[1] < 0:
        axes[1] += y.ndim
    result = []
    # Shift the axes down by one because the batch dimension is consumed by the loop below.
    axes = [axes[0] - 1, axes[1] - 1]
    for xi, yi in zip(x, y):
        result.append(np.tensordot(xi, yi, axes))
    result = np.array(result)
    if result.ndim == 1:
        result = np.expand_dims(result, -1)
    return result

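# Quick sanity sketch of the reference above (illustrative only, not one of the test
# cases below). With axes=None and a 2-D second operand, the last axes of both inputs
# are contracted, giving one dot product per batch row:
#
#   a = np.arange(6, dtype=np.float32).reshape(2, 3)
#   b = np.ones((2, 3), dtype=np.float32)
#   _reference_batch_dot(a, b, None)   # shape (2, 1): [[3.], [12.]]
#
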

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_dot_fp32():
    """Compare C.batch_dot against the NumPy reference over a range of shapes and axes."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    np.random.seed(12876)

    # case 1: 5-D inputs, contraction axes given as a negative tuple
    shape_x1 = (3, 12, 5, 2, 3)
    shape_x2 = (3, 1, 7, 3, 2)
    axes = (-1, -2)
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 2: 4-D inputs, single positive int axis
    shape_x1 = (4, 3, 7, 5)
    shape_x2 = (4, 1, 7, 1)
    axes = 2
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 3: 4-D inputs, single negative int axis
    shape_x1 = (18, 3, 5, 7)
    shape_x2 = (18, 1, 3, 7)
    axes = -1
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 4: 4-D inputs, axes=None (default axes inferred from the input ranks)
    shape_x1 = (2, 11, 3, 9)
    shape_x2 = (2, 7, 9, 3)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 5: 2-D inputs, axes=None
    shape_x1 = (7, 5)
    shape_x2 = (7, 5)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 6: 3-D and 2-D inputs, axes=None
    shape_x1 = (7, 3, 5)
    shape_x2 = (7, 5)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 7: 2-D and 3-D inputs, axes=None
    shape_x1 = (7, 5)
    shape_x2 = (7, 5, 3)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 8: 2-D inputs, single negative int axis
    shape_x1 = (39, 6)
    shape_x2 = (39, 6)
    axes = -1
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 9: 3-D inputs, contraction axes given as a negative tuple
    shape_x1 = (21, 2, 3)
    shape_x2 = (21, 3, 2)
    axes = (-1, -2)
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)

    # case 10: inputs of different rank (6-D and 4-D), single negative int axis
    shape_x1 = (4, 3, 2, 1, 7, 5)
    shape_x2 = (4, 5, 7, 1)
    axes = -2
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)

    assert np.allclose(ms_result_np, tf_result)