• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.common.initializer import initializer
23from mindspore.common.parameter import Parameter
24from mindspore.common.parameter import ParameterTuple
25from mindspore.ops import operations as P
26from mindspore.ops import composite as C
27
# Module-wide default: graph mode on CPU. Individual tests below may switch
# modes locally (e.g. test_conv3d) and restore graph mode when done.
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
29
30
class NetConv2d(nn.Cell):
    """Fixed 2-D convolution: 1x1 kernel mapping 3 input channels to 2 outputs.

    Both the weight and the input are stored as Parameters filled with
    ``arange`` values so the forward result is fully deterministic.
    """

    def __init__(self):
        super(NetConv2d, self).__init__()
        # 2 output channels, 1x1 kernel, no padding, unit stride/dilation.
        self.conv = P.Conv2D(2, 1, mode=1, pad_mode="valid", pad=0,
                             stride=1, dilation=1, group=1)
        weight_data = np.arange(2 * 3 * 1 * 1).reshape(2, 3, 1, 1).astype(np.float32)
        input_data = np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)
        self.w = Parameter(initializer(Tensor(weight_data), [2, 3, 1, 1]), name='w')
        self.x = Parameter(initializer(Tensor(input_data), [1, 3, 3, 3]), name='x')

    def construct(self):
        return self.conv(self.x, self.w)
51
52
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conv2d():
    """Forward P.Conv2D on arange-valued input/weight matches the hand-computed result."""
    conv2d = NetConv2d()
    output = conv2d()
    expect = np.array([[[[45, 48, 51],
                         [54, 57, 60],
                         [63, 66, 69]],
                        [[126, 138, 150],
                         [162, 174, 186],
                         [198, 210, 222]]]]).astype(np.float32)
    # Small integer-valued float32 arithmetic is exact, so bitwise equality is safe.
    assert (output.asnumpy() == expect).all()
75
76
class NetConv(nn.Cell):
    """nn.Conv2d layer with a caller-supplied weight and a stored input Parameter.

    ``weight`` initializes the (3, 3, 5, 3) convolution kernel; ``x`` fills the
    (1, 3, 4, 2) input Parameter consumed by ``construct``.
    """

    def __init__(self, weight, x):
        super(NetConv, self).__init__()
        conv_cfg = dict(in_channels=3,
                        out_channels=3,
                        kernel_size=(5, 3),
                        stride=2,
                        pad_mode='same',
                        padding=(0, 0, 0, 0),
                        dilation=(1, 1),
                        group=1,
                        has_bias=False,
                        weight_init=Tensor(weight))
        self.conv = nn.Conv2d(**conv_cfg)
        self.x = Parameter(initializer(Tensor(x), [1, 3, 4, 2]), name="x")

    def construct(self):
        return self.conv(self.x)
95
96
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conv():
    """nn.Conv2d forward with a fixed weight matches precomputed values within 1e-4.

    Marked level0/x86_cpu/onecard for consistency with the other tests in this
    file; previously unmarked, so mark-based CI selection skipped it.
    """
    weight = np.array([[[[0.38968208, 0.14398979, 0.7962463],
                         [-2.1836321, -0.63823014, -0.50588065],
                         [0.6660469, 0.64673275, -0.13160042],
                         [1.3683757, 1.4005762, -0.37235805],
                         [-0.22638111, 0.45427424, -0.10293389]],
                        [[1.4985064, -0.29318333, -0.92694616],
                         [1.539068, 0.8937254, -1.2598171],
                         [0.9658142, -0.63945454, -0.23185322],
                         [1.363089, -0.41694695, -2.2750475],
                         [-0.4865508, -1.6938025, 0.609849]],
                        [[1.1844803, 0.99874926, -1.9475793],
                         [0.4987858, 0.5307887, -0.04226681],
                         [0.4529779, -1.1960793, 0.9456575],
                         [3.133675, 0.2309789, -0.29201075],
                         [-0.59632736, -0.0789804, -0.69486314]]],
                       [[[-0.5606142, 0.6420862, 0.2478745],
                         [0.02717604, 1.5483379, -0.9373383],
                         [-1.1017276, -0.259478, 1.0311872],
                         [1.8387799, 0.16468556, 0.33392152],
                         [-1.8781787, 1.0158662, 1.6527579]],

                        [[0.45696944, -0.5652523, -1.5618048],
                         [-0.30304828, 0.1331878, -0.36955845],
                         [0.91655576, 0.66612357, 0.3068175],
                         [-0.45732066, 0.8923335, 1.0542952],
                         [-0.73519516, 1.0518405, -1.0273266]],

                        [[-0.79712886, -0.26814285, 0.12779616],
                         [1.0367643, -1.6180774, 0.42999932],
                         [-0.81818223, -0.81502074, 0.882194],
                         [0.53640485, 0.4178927, 1.6037121],
                         [0.9256354, -1.1006796, 0.16614541]]],

                       [[[-1.5216796, -1.2473261, 0.6549515],
                         [0.63627815, 0.7221449, 0.02977821],
                         [-0.61331123, -0.49451825, 0.33852202],
                         [1.4510741, -1.3818305, -0.791747],
                         [0.6989747, 0.49558765, 1.0813237]],

                        [[-0.03969796, 0.71586496, 0.8326594],
                         [-0.15443641, 1.0389746, -0.59301984],
                         [0.7197836, 0.03257621, 1.8398637],
                         [0.6111736, -0.16166899, -2.4869773],
                         [1.3066711, -1.8003578, 0.17412892]],

                        [[-0.31470737, -0.5938182, -1.1311078],
                         [-0.99081016, 0.4005125, 0.44154453],
                         [1.0876914, -2.5958562, -0.5914863],
                         [1.3759689, -0.7741513, 0.19928917],
                         [1.6792973, 2.2744863, -0.04308867]]]]).astype(np.float32)
    x = np.array([[[[-1.4311737, 1.015344],
                    [0.04431088, -2.2886624],
                    [1.4832113, 1.240908],
                    [0.67040104, 0.15266363]],

                   [[0.44226435, 1.1461105],
                    [1.194218, 1.5547837],
                    [0.23152256, 1.5911953],
                    [0.11206784, 0.17978816]],

                   [[-0.57803905, 0.8039611],
                    [0.0823025, -0.6134477],
                    [-1.4171146, 1.6269946],
                    [0.48878875, 0.9117505]]]]).astype(np.float32)
    conv2d = NetConv(weight, x)
    output = conv2d()
    expected = np.array([[[[2.3498724],
                           [-1.9199573]],
                          [[5.376562],
                           [-5.425745]],
                          [[5.9105043],
                           [7.469034]]]]).astype(np.float32)
    # Float convolution is not bitwise reproducible across kernels; use an
    # absolute tolerance instead of exact equality.
    loss = np.abs(expected - output.asnumpy())
    error = 1e-4 * np.ones(loss.shape)
    assert (loss < error).all()
173
174
class NetConv3d(nn.Cell):
    """P.Conv3D wrapper: 4 output channels, 2x2x2 kernel, configurable padding.

    ``mode``, ``pad_mode`` and ``pad`` are forwarded verbatim to P.Conv3D;
    stride, dilation and group are fixed at 1.
    """

    def __init__(self, mode, pad_mode, pad):
        super(NetConv3d, self).__init__()
        self.conv = P.Conv3D(4, 2,
                             mode=mode,
                             pad_mode=pad_mode,
                             pad=pad,
                             stride=1,
                             dilation=1,
                             group=1)

    def construct(self, x, w):
        # Weight is supplied at call time rather than stored as a Parameter.
        return self.conv(x, w)
191
192
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conv3d():
    """Valid-padding Conv3D forward is exact in both PyNative and Graph mode."""
    x = Tensor(np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32))
    w = Tensor(np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32))
    expect = np.array([[[[[12960., 13236.],
                          [13788., 14064.]],
                         [[15444., 15720.],
                          [16272., 16548.]]],
                        [[[32256., 33108.],
                          [34812., 35664.]],
                         [[39924., 40776.],
                          [42480., 43332.]]],
                        [[[51552., 52980.],
                          [55836., 57264.]],
                         [[64404., 65832.],
                          [68688., 70116.]]],
                        [[[70848., 72852.],
                          [76860., 78864.]],
                         [[88884., 90888.],
                          [94896., 96900.]]]]]).astype(np.float32)
    # Run PyNative first, Graph second, so the module-wide graph-mode default
    # is restored when the test finishes.
    for run_mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=run_mode, device_target="CPU")
        net = NetConv3d(1, "valid", 0)
        output = net(x, w)
        assert (output.asnumpy() == expect).all()
226
227
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conv3d_2():
    """Conv3D with explicit pad=1 on every spatial face matches the precomputed tensor.

    Same arange-valued input/weight as test_conv3d, but pad_mode="pad" with a
    symmetric padding of 1 in depth/height/width, so each output plane grows
    from 2x2 to 4x4.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    x = Tensor(np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32))
    w = Tensor(np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32))
    # Expected (1, 4, 4, 4, 4) output; interior values coincide with the
    # valid-padding results in test_conv3d, border values include zero padding.
    expect = np.array([[[[[1647, 3258, 3345, 1650],
                          [3267, 6447, 6609, 3252],
                          [3519, 6933, 7095, 3486],
                          [1719, 3378, 3453, 1692]],
                         [[3375, 6639, 6789, 3330],
                          [6606, 12960, 13236, 6474],
                          [7038, 13788, 14064, 6870],
                          [3393, 6627, 6753, 3288]],
                         [[4077, 7989, 8139, 3978],
                          [7902, 15444, 15720, 7662],
                          [8334, 16272, 16548, 8058],
                          [3987, 7761, 7887, 3828]],
                         [[1917, 3732, 3795, 1842],
                          [3663, 7107, 7221, 3492],
                          [3843, 7449, 7563, 3654],
                          [1809, 3492, 3543, 1704]]],
                        [[[3591, 7218, 7449, 3738],
                          [7371, 14799, 15249, 7644],
                          [8055, 16149, 16599, 8310],
                          [4095, 8202, 8421, 4212]],
                         [[7911, 15855, 16293, 8154],
                          [16110, 32256, 33108, 16554],
                          [17406, 34812, 35664, 17814],
                          [8793, 17571, 17985, 8976]],
                         [[9909, 19797, 20235, 10098],
                          [19998, 39924, 40776, 20334],
                          [21294, 42480, 43332, 21594],
                          [10683, 21297, 21711, 10812]],
                         [[5157, 10284, 10491, 5226],
                          [10359, 20643, 21045, 10476],
                          [10971, 21849, 22251, 11070],
                          [5481, 10908, 11103, 5520]]],
                        [[[5535, 11178, 11553, 5826],
                          [11475, 23151, 23889, 12036],
                          [12591, 25365, 26103, 13134],
                          [6471, 13026, 13389, 6732]],
                         [[12447, 25071, 25797, 12978],
                          [25614, 51552, 52980, 26634],
                          [27774, 55836, 57264, 28758],
                          [14193, 28515, 29217, 14664]],
                         [[15741, 31605, 32331, 16218],
                          [32094, 64404, 65832, 33006],
                          [34254, 68688, 70116, 35130],
                          [17379, 34833, 35535, 17796]],
                         [[8397, 16836, 17187, 8610],
                          [17055, 34179, 34869, 17460],
                          [18099, 36249, 36939, 18486],
                          [9153, 18324, 18663, 9336]]],
                        [[[7479, 15138, 15657, 7914],
                          [15579, 31503, 32529, 16428],
                          [17127, 34581, 35607, 17958],
                          [8847, 17850, 18357, 9252]],
                         [[16983, 34287, 35301, 17802],
                          [35118, 70848, 72852, 36714],
                          [38142, 76860, 78864, 39702],
                          [19593, 39459, 40449, 20352]],
                         [[21573, 43413, 44427, 22338],
                          [44190, 88884, 90888, 45678],
                          [47214, 94896, 96900, 48666],
                          [24075, 48369, 49359, 24780]],
                         [[11637, 23388, 23883, 11994],
                          [23751, 47715, 48693, 24444],
                          [25227, 50649, 51627, 25902],
                          [12825, 25740, 26223, 13152]]]]]).astype(np.float32)
    mode = 1
    pad_mode = "pad"
    # (head, tail) padding for each of D, H, W.
    pad = (1, 1, 1, 1, 1, 1)
    net = NetConv3d(mode, pad_mode, pad)
    output = net(x, w)
    assert (output.asnumpy() == expect).all()
305
306
class MSConv3dNet(nn.Cell):
    """Single-layer 3-D convolution network built on nn.Conv3d (NCDHW layout).

    All constructor arguments are forwarded to nn.Conv3d; group is fixed at 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, pad_mode='pad', padding=0, stride=1, dilation=1,
                 has_bias=False, weight_init='normal'):
        super(MSConv3dNet, self).__init__()
        self.cv1 = nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, pad_mode=pad_mode,
                             padding=padding, stride=stride, dilation=dilation,
                             group=1, has_bias=has_bias, weight_init=weight_init,
                             data_format='NCDHW')

    def construct(self, x):
        return self.cv1(x)
326
327
class MSGradNet(nn.Cell):
    """Wraps ``network`` to return its gradients given an input and a sensitivity.

    ``construct(x, dy)`` yields gradients w.r.t. both the input (get_all) and
    the trainable parameters (get_by_list), with ``dy`` as the output
    sensitivity (sens_param).
    """

    def __init__(self, network):
        super(MSGradNet, self).__init__()
        self.network = network
        self.params = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)

    def construct(self, x, dy):
        return self.grad(self.network, self.params)(x, dy)
339