• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore.context as context
19from mindspore import Tensor
20from mindspore.nn import Cell
21from mindspore.ops import composite as C
22from mindspore.ops.operations import Minimum
23
# Every test in this module runs in graph mode on the CPU backend.
context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
# get_all=True: gradients for every network input are returned;
# sens_param=True: the caller passes the output sensitivity as an extra
# trailing argument (see GradWrap.construct).
grad = C.GradOperation(sens_param=True, get_all=True)
26
27
class MinNetMe(Cell):
    """Network whose forward pass is a single element-wise ``Minimum``."""

    def __init__(self):
        super().__init__()
        self.min = Minimum()

    def construct(self, inputA, inputB):
        # Forward pass: apply the Minimum primitive to the two operands.
        return self.min(inputA, inputB)
36
37
class GradWrap(Cell):
    """Wrap a two-input network so that calling it returns input gradients.

    The gradient function comes from the module-level ``grad``
    (``GradOperation(get_all=True, sens_param=True)``), so callers pass the
    sensitivity explicitly as the third argument.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, inputA, inputB, sens):
        # grad(...) builds the gradient function; sens seeds backpropagation.
        return grad(self.network)(inputA, inputB, sens)
46
47
def gen_data(inputA_np, inputB_np, grad_=None):
    """Run ``GradWrap(MinNetMe())`` on two operands and return its gradients.

    Args:
        inputA_np: first operand of ``Minimum``. A numpy array is wrapped in a
            ``Tensor``; any other value is passed through unchanged.
        inputB_np: second operand of ``Minimum``, handled like ``inputA_np``.
        grad_: sensitivity (upstream gradient) passed to the network. Any
            value that is not already a ``Tensor`` is wrapped in one. When
            ``None``, an all-ones float32 array with the broadcast shape of
            the two operands is used.

    Returns:
        The result of calling ``GradWrap`` -- with the module-level
        ``GradOperation(get_all=True, ...)`` this holds one gradient per
        operand.
    """
    def _shape_of(value):
        # ndarrays and Tensors expose ``.shape``; plain Python scalars are rank 0.
        return tuple(getattr(value, "shape", ()))

    inputA_me = Tensor(inputA_np) if isinstance(inputA_np, np.ndarray) else inputA_np
    inputB_me = Tensor(inputB_np) if isinstance(inputB_np, np.ndarray) else inputB_np

    if grad_ is None:
        # This branch used to evaluate Tensor(None), which is not a usable
        # sensitivity. Default to ones over the broadcast output shape, i.e.
        # the gradient of sum(minimum(A, B)).
        out_shape = np.broadcast(np.empty(_shape_of(inputA_np)),
                                 np.empty(_shape_of(inputB_np))).shape
        # NOTE(review): float32 matches the inputs used by the tests in this
        # file; other input dtypes may need a matching sensitivity dtype.
        grad_ = np.ones(out_shape, dtype=np.float32)

    # Wrap exactly once; an existing Tensor is passed through as-is.
    sens = grad_ if isinstance(grad_, Tensor) else Tensor(grad_)

    net_me = GradWrap(MinNetMe())
    net_me.set_train()
    return net_me(inputA_me, inputB_me, sens)
64
65
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_min_tensor_grad_4d():
    """Gradients of Minimum on equal-shape 4-D inputs match a NumPy reference.

    Without broadcasting, each upstream gradient element goes to ``inputA``
    where ``inputA < inputB`` and to ``inputB`` where ``inputA > inputB``.
    """
    # Seeded so that a failure can be reproduced exactly.
    rng = np.random.RandomState(0)
    shape = (1, 3, 2, 2)
    inputA_np = rng.randn(*shape).astype(np.float32)
    inputB_np = rng.randn(*shape).astype(np.float32)
    grad_ = rng.randn(*shape).astype(np.float32)
    output = gen_data(inputA_np, inputB_np, grad_)

    # Ties (inputA == inputB) are practically impossible with continuous
    # random data, so the kernel's tie-breaking rule does not matter here.
    a_is_min = inputA_np < inputB_np
    zeros = np.zeros_like(grad_)
    expect0 = np.where(a_is_min, grad_, zeros)
    expect1 = np.where(a_is_min, zeros, grad_)
    np.testing.assert_allclose(output[0].asnumpy(), expect0, rtol=0, atol=1.0e-5)
    np.testing.assert_allclose(output[1].asnumpy(), expect1, rtol=0, atol=1.0e-5)
76
77
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_min_tensor_grad_result():
    """Minimum gradients with broadcasting match precomputed values.

    ``inputA`` (3, 1, 5, 1) and ``inputB`` (1, 4, 1, 6) broadcast to the
    (3, 4, 5, 6) shape of ``grad_``; each expected gradient is the upstream
    gradient routed to the smaller operand and summed back onto that
    operand's shape.
    """
    inputA = np.array([[[[0.659578], [0.49113268], [0.75909054], [0.71681815], [0.30421826]]],
                       [[[0.30322495], [0.02858258], [0.06398096], [0.09519596], [0.12498625]]],
                       [[[0.7347768], [0.166469], [0.328553], [0.54908437], [0.23673844]]]]).astype(np.float32)
    inputB = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
                        [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
                        [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]],
                        [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32)
    grad_ = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524],
                        [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796],
                        [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618],
                        [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
                        [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],

                       [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
                        [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
                        [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
                        [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658],
                        [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]],

                       [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634],
                        [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
                        [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
                        [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395],
                        [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],

                       [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
                        [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
                        [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
                        [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
                        [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]],

                      [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
                        [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436],
                        [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127],
                        [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
                        [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],

                       [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033],
                        [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369],
                        [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
                        [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872],
                        [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]],

                       [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199],
                        [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
                        [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
                        [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414],
                        [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],

                       [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
                        [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
                        [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
                        [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044],
                        [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],

                      [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683],
                        [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
                        [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658],
                        [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
                        [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]],

                       [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
                        [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
                        [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048],
                        [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874],
                        [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],

                       [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226],
                        [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
                        [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743],
                        [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
                        [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],

                       [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
                        [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
                        [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
                        [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
                        [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
    output = gen_data(inputA, inputB, grad_)
    expect0 = np.array([[[[5.7664223], [6.9810176], [2.6029902], [2.7598205], [6.763105]]],
                        [[[10.065580], [12.077245], [9.3383940], [11.522709], [8.889048]]],
                        [[[3.5789766], [13.424448], [8.7327460], [6.9677467], [9.635764]]]], np.float32)
    expect1 = np.array([[[[0., 4.2504573, 2.5030296, 3.623167, 6.417151, 7.2115746]],
                         [[0., 4.3674493, 2.8031523, 2.5352, 0., 0.]],
                         [[0.7087075, 0., 2.040332, 2.1372325, 0., 2.9222295]],
                         [[1.0278877, 5.247942, 2.6855955, 5.494814, 3.565799, 0.66265094]]]], np.float32)
    # assert_allclose fails on a shape mismatch instead of silently
    # broadcasting (e.g. an unreduced (3, 4, 5, 6) gradient), and reports the
    # offending elements on failure.
    np.testing.assert_allclose(output[0].asnumpy(), expect0, rtol=0, atol=1.0e-5)
    np.testing.assert_allclose(output[1].asnumpy(), expect1, rtol=0, atol=1.0e-5)
172