• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test nn ops """
16import numpy as np
17from numpy.random import normal
18import pytest
19
20import mindspore.nn as nn
21import mindspore.context as context
22from mindspore.ops.composite import core
23from mindspore.common.api import ms_function
24
25from mindspore import Tensor
26from mindspore.ops import functional as F
27from mindspore.ops import prim_attr_register, PrimitiveWithInfer
28
# Every test in this module runs under MindSpore's GRAPH_MODE; this is set once
# at import time, before any of the test networks below are built or called.
context.set_context(mode=context.GRAPH_MODE)
30
31
class FakeOp(PrimitiveWithInfer):
    """Test-only primitive whose generated attributes depend on its inputs.

    Shape and dtype inference both forward the first input unchanged. Shape
    inference also records whatever it receives for the second input, both
    as the instance attribute ``second_shape`` and as a primitive attribute
    of the same name. Two call sites with differently shaped second inputs
    therefore produce different attributes after inference.
    """

    @prim_attr_register
    def __init__(self):
        """Build the primitive; no constructor arguments are accepted."""

    def infer_shape(self, x, y):
        """Forward ``x`` as the inferred shape and remember ``y``.

        Args:
            x: Value received for the first input during shape inference;
                returned unchanged.
            y: Value received for the second input during shape inference;
                stored on the instance and registered via ``add_prim_attr``
                under the name ``second_shape``.

        Returns:
            ``x``, unchanged.
        """
        # Keep a plain attribute first, then register the primitive attribute,
        # matching the original statement order.
        self.second_shape = y
        self.add_prim_attr("second_shape", y)
        return x

    def infer_dtype(self, x, y):
        """Return the first input's dtype; ``y`` does not affect the result."""
        return x
44
45
def test_conv2d_same_primitive():
    """Run two identically configured Conv2d cells on inputs of different widths.

    This is the normal case: each convolution is expected to get its own
    independent primitive, because inference generates different attributes
    for each input.
    """

    def make_conv():
        # Both convolutions are built from exactly the same constructor arguments.
        return nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

    class Conv2DSameNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.conv1 = make_conv()
            self.conv2 = make_conv()

        def construct(self, x, y):
            return self.conv1(x), self.conv2(y)

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    Conv2DSameNet()(narrow, wide)
64
65
def test_remove_and_fv_2():
    """Pass a list of free-variable closures as an argument to another graph.

    ``out_loop`` defines two closures that capture its ``input1`` argument.
    It hands the same list of them to ``inner_loop`` twice, once with a
    (3, 3) input and once with a (3, 1) input.
    """

    # NOTE(review): the flag is spelled `loop_can_uroll`; confirm that this is
    # the name the compiler reads before "fixing" the spelling.
    @core(loop_can_uroll=True)
    def inner_loop(x, input_data, fv_func_list):
        # Apply every closure to input_data and collect the results in order.
        # NOTE(review): `x` is never read here; it is kept so the call sites
        # in out_loop stay unchanged.
        ret = ()
        for fv_fn in fv_func_list:
            ele = fv_fn(input_data)
            ret += (ele,)
        return ret

    @ms_function
    def out_loop(input1, input_data0, input_data1):
        # Both closures capture input1 as a free variable.
        def fv_func1(y):
            return input1 * y
        def fv_func2(y):
            return input1 - y
        fv_func_list = [fv_func1, fv_func2]
        ele0 = inner_loop(input1, input_data0, fv_func_list)
        ele1 = inner_loop(input1, input_data1, fv_func_list)
        # The previous dead `ret = ()` initialisation was always overwritten.
        return (ele0, ele1)

    input_data0 = Tensor(normal(0, 0.1, (3, 3)))
    input_data1 = Tensor(normal(0, 0.1, (3, 1)))
    input1 = Tensor(normal(0, 0.1, (3, 3)))
    out_loop(input1, input_data0, input_data1)
94
95
def test_conv2d_op_with_argi_1():
    """Pass a Conv2d cell as a higher-order argument to another cell.

    The note originally attached to this test says that a graph with free
    variables used as an argument is not supported yet, because of a limit
    of the inference specialize system.
    NOTE(review): the test is not marked skip, so that note may be stale;
    confirm.
    """

    class Conv2dNet(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, op, x):
            # Apply whatever callable was passed in.
            return op(x)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            # The same conv cell is handed to opnet for both inputs.
            conv_op = self.conv2
            first = self.opnet(conv_op, x)
            second = self.opnet(conv_op, y)
            return (first, second)

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    OpsNet(Conv2dNet())(narrow, wide)
123
124
def test_conv2d_op_with_arg():
    """Pass a cell that wraps FakeOp as a higher-order argument.

    The same wrapped op is applied twice through OpNet. The second call
    swaps the two inputs, so the op's shape inference sees a different
    second input on each call.
    """

    class FakeOpNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class OpNet(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, op, x, y):
            # Apply the callable received as an argument.
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            forward = self.opnet(op, x, y)
            swapped = self.opnet(op, y, x)
            return (forward, swapped)

    lhs = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    rhs = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    OpsNet(OpNet())(lhs, rhs)
157
158
def test_conv2d_op_with_arg_same_input():
    """Pass a FakeOp-wrapping cell as a higher-order argument, reusing one input.

    The first call feeds ``x`` as both inputs. The second call feeds
    ``(y, x)``.
    """

    class FakeOpNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class OpNet(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, op, x, y):
            # Apply the callable received as an argument.
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            doubled = self.opnet(op, x, x)
            mixed = self.opnet(op, y, x)
            return (doubled, mixed)

    lhs = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    rhs = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    OpsNet(OpNet())(lhs, rhs)
191
192
def test_op_as_partial():
    """Bind FakeOp's first input with F.partial and call the result twice."""

    class OpAsPartial(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            # A single partial application (x bound) is reused for both calls.
            partial_op = F.partial(self.op, x)
            return partial_op(y), partial_op(z)

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    OpAsPartial()(t1, t2, t3)
211
212
def test_op_as_partial_inside():
    """Use the partial-application cell as a sub-cell of an outer cell.

    The inner cell binds FakeOp's first input with F.partial and calls the
    result twice. The outer cell unpacks the returned pair and re-returns it.
    """

    class OpAsPartial(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            first = partial_op(y)
            second = partial_op(z)
            return first, second

    class OuterNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.net = OpAsPartial()

        def construct(self, x, y, z):
            # Unpack and rebuild the tuple rather than forwarding it directly.
            first, second = self.net(x, y, z)
            return first, second

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    OuterNet()(t1, t2, t3)
240
241
def test_op_as_partial_independent():
    """Build two separate F.partial applications of the same op.

    Each partial is used exactly once. This differs from test_op_as_partial,
    which reuses a single partial for both calls.
    """

    class OpAsPartial(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            bound_first = F.partial(self.op, x)
            out_y = bound_first(y)
            bound_second = F.partial(self.op, x)
            out_z = bound_second(z)
            return out_y, out_z

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    OpAsPartial()(t1, t2, t3)
261
262
def test_nest_partial():
    """Apply F.partial on top of another F.partial, in two independent chains.

    Each chain first takes a partial of the op with no bound arguments, then
    binds ``x`` through a second partial, and finally calls it once.
    """

    class NestPartial(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            # First chain, applied to y.
            unbound_a = F.partial(self.op)
            bound_a = F.partial(unbound_a, x)
            out_y = bound_a(y)
            # Second, independently built chain, applied to z.
            unbound_b = F.partial(self.op)
            bound_b = F.partial(unbound_b, x)
            out_z = bound_b(z)
            return out_y, out_z

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    NestPartial()(t1, t2, t3)
283
284
def test_op_with_arg_as_input():
    """Pass a bare FakeOp primitive and its inputs as arguments of another cell.

    This is a higher-order-argument case: the same op is applied through
    WithOpArgNet first to ``(x, z)`` and then to ``(x, y)``.
    """

    class WithOpArgNet(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, op, x, y):
            # Apply the primitive received as an argument.
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            op = self.op
            with_z = self.opnet(op, x, z)
            with_y = self.opnet(op, x, y)
            return (with_z, with_y)

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    OpsNet(WithOpArgNet())(t1, t2, t3)
312
313
# A partial application passed as an argument is not supported yet, because
# of a limit in the inference specialize system, so this test is skipped.
@pytest.mark.skip(reason="poly in infer")
def test_partial_as_arg():
    """Pass an F.partial application of FakeOp as an argument to another cell.

    The same partial (with ``x`` bound) is applied through PartialArgNet,
    first to ``z`` and then to ``y``.
    """

    class PartialArgNet(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, partial_op, y):
            # Call the partial application received as an argument.
            return partial_op(y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.partial_net = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            with_z = self.partial_net(partial_op, z)
            with_y = self.partial_net(partial_op, y)
            return (with_z, with_y)

    t1, t2, t3 = (Tensor(np.ones((1, 16, 1, width), dtype=np.float32))
                  for width in (1918, 3840, 1234))
    OpsNet(PartialArgNet())(t1, t2, t3)
342