• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test nn ops """
16import numpy as np
17from numpy.random import normal
18import pytest
19
20import mindspore.nn as nn
21import mindspore.context as context
22from mindspore.ops.composite.base import core
23from mindspore.common.api import jit
24
25from mindspore import Tensor
26from mindspore.ops import functional as F
27from mindspore.ops import prim_attr_register, PrimitiveWithInfer
28
# Run every network and jit function in this module under GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
30
31
class FakeOp(PrimitiveWithInfer):
    """Test-only primitive whose shape inference records its second input.

    The output shape and dtype are simply those of the first input. As a side
    effect of shape inference, the second input's shape is stored on the
    primitive, so inferring with different ``y`` shapes leaves different
    attributes behind.
    """

    @prim_attr_register
    def __init__(self):
        """Create the primitive; it takes no constructor arguments."""

    def infer_shape(self, x, y):
        """Return ``x`` unchanged as the output shape.

        Side effect: ``y`` is saved both as the Python attribute
        ``second_shape`` and as the primitive attribute ``"second_shape"``.
        """
        self.second_shape = y
        self.add_prim_attr("second_shape", y)
        return x

    def infer_dtype(self, x, y):
        """Return the first input's dtype; ``y`` does not influence it."""
        return x
45
def test_conv2d_same_primitive():
    """Apply two identically configured Conv2d cells to inputs of different widths.

    This is the normal case in which inference generates different attributes
    for the two convolutions, so independent primitives should be created.
    The test only checks that the network runs without error.
    """
    class TwinConvNet(nn.Cell):
        def __init__(self):
            super().__init__()
            conv_args = (16, 64, (1, 41), (1, 4), "same", 0, 1)
            self.conv1 = nn.Conv2d(*conv_args, has_bias=True)
            self.conv2 = nn.Conv2d(*conv_args, has_bias=True)

        def construct(self, x, y):
            return self.conv1(x), self.conv2(y)

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    model = TwinConvNet()
    model(narrow, wide)
64
65
def test_remove_and_fv_2():
    """Pass a list of closures over a free variable as a function argument.

    ``out_loop`` defines two closures that capture ``input1`` and hands the
    same list to the ``core``-decorated ``inner_loop`` twice, once for each
    input tensor. ``inner_loop`` applies every closure to its data input and
    collects the results into a tuple.
    """
    # NOTE(review): the flag is spelled "loop_can_uroll" (not "unroll"); kept
    # unchanged here — confirm which spelling the compiler actually reads.
    @core(loop_can_uroll=True)
    def inner_loop(x, input_data, fv_func_list):
        # ``x`` is not used in the body.
        ret = ()
        for fv_fn in fv_func_list:
            ele = fv_fn(input_data)
            ret += (ele,)
        return ret

    @jit
    def out_loop(input1, input_data0, input_data1):
        # Both closures read ``input1`` from the enclosing scope (free variable).
        def fv_func1(y):
            return input1 * y

        def fv_func2(y):
            return input1 - y

        fv_func_list = [fv_func1, fv_func2]
        ele0 = inner_loop(input1, input_data0, fv_func_list)
        ele1 = inner_loop(input1, input_data1, fv_func_list)
        # The original initialised ``ret = ()`` and then overwrote it
        # unconditionally; the dead store is removed.
        return (ele0, ele1)

    input_data0 = Tensor(normal(0, 0.1, (3, 3)))
    input_data1 = Tensor(normal(0, 0.1, (3, 1)))
    input1 = Tensor(normal(0, 0.1, (3, 3)))
    out_loop(input1, input_data0, input_data1)
94
95
def test_conv2d_op_with_argi_1():
    """Pass a Conv2d cell as a high-order argument to another cell, twice.

    The same conv cell is routed through ``ApplyNet`` for two inputs of
    different widths.

    NOTE(review): an earlier comment here said that a graph with free
    variables used as an argument is not supported yet (a limit of the
    inference specialize system), yet this test is not skipped — confirm
    whether that remark is stale.
    """
    class ApplyNet(nn.Cell):
        def construct(self, op, x):
            return op(x)

    class CallerNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            conv_op = self.conv2
            first = self.opnet(conv_op, x)
            second = self.opnet(conv_op, y)
            return first, second

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    model = CallerNet(ApplyNet())
    model(narrow, wide)
120
121
def test_conv2d_op_with_arg():
    """Pass a cell wrapping ``FakeOp`` as a high-order argument with swapped inputs.

    ``FakeOp`` records its second input's shape during shape inference, so the
    calls ``op(x, y)`` and ``op(y, x)`` give it different second shapes.
    """
    class FakeOpNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class ApplyOpNet(nn.Cell):
        def construct(self, op, x, y):
            return op(x, y)

    class CallerNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            forward = self.opnet(op, x, y)
            swapped = self.opnet(op, y, x)
            return forward, swapped

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    model = CallerNet(ApplyOpNet())
    model(narrow, wide)
151
152
def test_conv2d_op_with_arg_same_input():
    """Pass a ``FakeOp`` cell as a high-order argument, once with a repeated input.

    The first call feeds ``x`` into both operands; the second call uses
    ``(y, x)``, so in both cases ``FakeOp`` sees ``x`` as its second input.
    """
    class FakeOpNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class ApplyOpNet(nn.Cell):
        def construct(self, op, x, y):
            return op(x, y)

    class CallerNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            doubled = self.opnet(op, x, x)
            mixed = self.opnet(op, y, x)
            return doubled, mixed

    narrow = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    wide = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    model = CallerNet(ApplyOpNet())
    model(narrow, wide)
182
183
def test_op_as_partial():
    """Bind ``FakeOp``'s first input with ``F.partial`` and call the partial twice.

    The single partial is applied to two different second inputs, so the
    primitive is reached with different second shapes through one partial.
    """
    class PartialTwiceNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            bound_op = F.partial(self.op, x)
            with_y = bound_op(y)
            with_z = bound_op(z)
            return with_y, with_z

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = PartialTwiceNet()
    model(first, second, third)
202
203
def test_op_as_partial_inside():
    """Same as ``test_op_as_partial`` but with the partial-using cell nested in another cell.

    ``WrapperNet`` forwards its inputs to the inner cell and re-packs the
    two results it gets back.
    """
    class PartialTwiceNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            bound_op = F.partial(self.op, x)
            with_y = bound_op(y)
            with_z = bound_op(z)
            return with_y, with_z

    class WrapperNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.net = PartialTwiceNet()

        def construct(self, x, y, z):
            left, right = self.net(x, y, z)
            return left, right

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = WrapperNet()
    model(first, second, third)
231
232
def test_op_as_partial_independent():
    """Build two separate partials of the same ``FakeOp`` and call each once.

    Unlike ``test_op_as_partial``, a fresh ``F.partial`` is created for each
    call, both binding ``x`` as the first input.
    """
    class TwoPartialsNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            bound_for_y = F.partial(self.op, x)
            with_y = bound_for_y(y)
            bound_for_z = F.partial(self.op, x)
            with_z = bound_for_z(z)
            return with_y, with_z

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = TwoPartialsNet()
    model(first, second, third)
252
253
def test_nest_partial():
    """Wrap ``FakeOp`` in an argument-less partial, then partial that again with ``x``.

    The two-level construction is built twice, independently, and each
    resulting callable is applied to a different second input.
    """
    class NestedPartialNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            outer_y = F.partial(self.op)
            inner_y = F.partial(outer_y, x)
            with_y = inner_y(y)
            outer_z = F.partial(self.op)
            inner_z = F.partial(outer_z, x)
            with_z = inner_z(z)
            return with_y, with_z

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = NestedPartialNet()
    model(first, second, third)
274
275
def test_op_with_arg_as_input():
    """High-order argument: pass a primitive and its operands as cell arguments.

    The same ``FakeOp`` is passed to ``ApplyOpNet`` twice, once with ``(x, z)``
    and once with ``(x, y)``.
    """
    class ApplyOpNet(nn.Cell):
        def construct(self, op, x, y):
            return op(x, y)

    class CallerNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.opnet = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            op = self.op
            with_z = self.opnet(op, x, z)
            with_y = self.opnet(op, x, y)
            return with_z, with_y

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = CallerNet(ApplyOpNet())
    model(first, second, third)
300
301
@pytest.mark.skip("poly in infer")
def test_partial_as_arg():
    """Pass a partial application of ``FakeOp`` as a cell argument, twice.

    Skipped: using a partial application as an argument is not supported yet
    because of the limit of the inference specialize system.
    """
    class ApplyPartialNet(nn.Cell):
        def construct(self, partial_op, y):
            return partial_op(y)

    class CallerNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.partial_net = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            bound_op = F.partial(self.op, x)
            with_z = self.partial_net(bound_op, z)
            with_y = self.partial_net(bound_op, y)
            return with_z, with_y

    first = Tensor(np.ones((1, 16, 1, 1918), dtype=np.float32))
    second = Tensor(np.ones((1, 16, 1, 3840), dtype=np.float32))
    third = Tensor(np.ones((1, 16, 1, 1234), dtype=np.float32))
    model = CallerNet(ApplyPartialNet())
    model(first, second, third)
327