• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test ops """
16import numpy as np
17
18import mindspore.nn as nn
19import mindspore.ops.composite as C
20import mindspore.ops.functional as F
21import mindspore.ops.operations as P
22from mindspore import Tensor
23from mindspore.common.api import _cell_graph_executor
24
25
# Shared gradient operator used by InputBackward. get_all=True yields the
# gradients with respect to every network input; sens_param=True means the
# caller passes the output sensitivity ("sens") as an extra final argument.
grad_all_with_sens = C.GradOperation(sens_param=True, get_all=True)
27
28
class InputBackward(nn.Cell):
    """Cell that differentiates ``network`` with respect to all of its inputs.

    ``construct`` is only a placeholder: ``gen_backward_net`` binds one of
    ``construct1`` .. ``construct7`` as ``construct`` depending on how many
    forward inputs the wrapped network takes. Every variant forwards its
    inputs, followed by the trailing ``sens`` tensor, to
    ``grad_all_with_sens(network)``.

    Args:
        network: forward cell to differentiate; it is put in training mode
            with ``set_train()``.
        c1, c2: stored on the instance; no method of this class reads them.
    """

    def __init__(self, network, c1=None, c2=None):
        super().__init__()
        self.network = network
        self.network.set_train()
        self.grad = grad_all_with_sens
        self.c1 = c1
        self.c2 = c2

    def construct(self, *inputs):
        # Placeholder; the real entry point is bound per instance.
        pass

    def construct1(self, x1, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, sens)

    def construct2(self, x1, x2, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, sens)

    def construct3(self, x1, x2, x3, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, x3, sens)

    def construct4(self, x1, x2, x3, x4, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, x3, x4, sens)

    def construct5(self, x1, x2, x3, x4, x5, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, x3, x4, x5, sens)

    def construct6(self, x1, x2, x3, x4, x5, x6, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, x3, x4, x5, x6, sens)

    def construct7(self, x1, x2, x3, x4, x5, x6, x7, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(x1, x2, x3, x4, x5, x6, x7, sens)
63
64
class InputOpNet(nn.Cell):
    """Single-operator cell used to build forward test graphs.

    ``construct`` is only a placeholder: ``gen_net`` binds one of the variants
    below as ``construct``. Variant names encode the call shape:

    * ``construct<N>_c<K>`` calls ``op`` with the N graph inputs followed by
      the first K stored constants (``c1`` .. ``cK``).
    * ``constructc1_1`` calls ``op`` with ``c1`` first, then the graph input.
    * ``construct0_c<K>_fack`` calls ``op`` with only the K constants and adds
      the single dummy ("fack") graph input to the result.

    When ``get_first`` is true every variant returns element ``[0]`` of the
    op output instead of the whole output.

    Args:
        op: the operator under test.
        get_first: whether to index the op output with ``[0]``.
        c1, c2, c3, c4: constant arguments passed to ``op`` alongside the
            graph inputs.
    """

    def __init__(self, op, get_first=False,
                 c1=None, c2=None, c3=None, c4=None):
        super().__init__()
        self.op = op
        self.get_first = get_first
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        # Placeholder; gen_net rebinds ``construct`` on each instance.
        pass

    # --- no graph inputs; a dummy input is added to the op result ---

    def construct0_c0_fack(self, data):
        x = self.op() + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1_fack(self, data):
        x = self.op(self.c1) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2_fack(self, data):
        x = self.op(self.c1, self.c2) + data
        if self.get_first:
            x = x[0]
        return x

    # --- no graph inputs; op called with constants only ---

    def construct0_c0(self):
        x = self.op()
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    # --- one graph input ---

    def construct1_c0(self, x1):
        x = self.op(x1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        if self.get_first:
            x = x[0]
        return x

    # One graph input, with the constant passed *before* it (const_first).
    def constructc1_1(self, x1):
        x = self.op(self.c1, x1)
        if self.get_first:
            x = x[0]
        return x

    # --- two graph inputs ---

    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        if self.get_first:
            x = x[0]
        return x

    # Fills the gap between construct2_c1 and construct2_c3 so gen_net can
    # build nets with two inputs and two constants.
    def construct2_c2(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    # --- three graph inputs ---

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        if self.get_first:
            x = x[0]
        return x

    # --- four graph inputs ---

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        if self.get_first:
            x = x[0]
        return x

    # --- five graph inputs ---

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        if self.get_first:
            x = x[0]
        return x

    # --- six graph inputs ---

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        if self.get_first:
            x = x[0]
        return x
212
213
class NetOutputAsLoss(nn.Cell):
    """Cell exposing a single element of a multi-output network.

    ``construct`` is only a placeholder: ``get_loss_fun`` binds
    ``construct<N>`` (N = number of network inputs, 1..5) as ``construct``.
    Each variant runs ``network`` and returns ``outputs[output_index]``.

    Args:
        network: the wrapped cell whose output is indexed.
        output_index: index of the output element to return.
    """

    def __init__(self, network, output_index):
        super().__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        # Placeholder; get_loss_fun rebinds ``construct`` per instance.
        pass

    def construct1(self, x1):
        outputs = self.network(x1)
        return outputs[self.output_index]

    def construct2(self, x1, x2):
        outputs = self.network(x1, x2)
        return outputs[self.output_index]

    def construct3(self, x1, x2, x3):
        outputs = self.network(x1, x2, x3)
        return outputs[self.output_index]

    def construct4(self, x1, x2, x3, x4):
        outputs = self.network(x1, x2, x3, x4)
        return outputs[self.output_index]

    def construct5(self, x1, x2, x3, x4, x5):
        outputs = self.network(x1, x2, x3, x4, x5)
        return outputs[self.output_index]
244
245
def get_loss_fun(construct_net, num_input, output_index):
    """Wrap ``construct_net`` so that it returns only output ``output_index``.

    Args:
        construct_net: network whose output tuple is indexed.
        num_input: number of inputs ``construct_net`` takes; selects the
            ``NetOutputAsLoss.construct<num_input>`` variant.
        output_index: index of the output element to return.

    Returns:
        The ``NetOutputAsLoss`` wrapper with its ``construct`` rebound.
    """
    wrapper = NetOutputAsLoss(construct_net, output_index)
    wrapper.construct = getattr(wrapper, 'construct%d' % num_input)
    return wrapper
251
252
def build_construct_graph(net, *inputs, execute=True):
    """Compile ``net`` in training mode for ``inputs`` and optionally run it.

    Args:
        net: cell to compile.
        *inputs: positional graph inputs.
        execute: when true, also run the compiled graph on ``inputs``.
    """
    net.set_train()
    _cell_graph_executor.compile(net, *inputs)
    if execute:
        # Unpack exactly as for compile(); passing the tuple itself would hand
        # the executor one argument instead of len(inputs) arguments.
        _cell_graph_executor(net, *inputs)
258
259
def build_backward_graph(net, output_shapes, inputs, execute=True):
    """Compile (and optionally run) the input-gradient graph of ``net``.

    A random ``sens`` tensor built from ``output_shapes`` is appended to
    ``inputs``, and ``net`` is wrapped with ``gen_backward_net`` for the
    original number of inputs.

    Args:
        net: forward cell.
        output_shapes: shape used for the ``sens`` tensor.
        inputs: list of forward inputs.
        execute: when true, also run the compiled graph.
    """
    inputs = append_sens_to_inputs(output_shapes, inputs)
    net = gen_backward_net(net, len(inputs) - 1)
    net.set_train()
    # Pass each tensor as its own positional argument, as build_construct_graph
    # does; passing the list would give the graph a single argument.
    _cell_graph_executor.compile(net, *inputs)
    if execute:
        _cell_graph_executor(net, *inputs)
267
268
def convert(shp, dtype=np.float32, scale=6):
    """Turn a shape list into a random Tensor; pass any other value through.

    Args:
        shp: a list is treated as a tensor shape (``[]`` gives a 0-d scalar
            tensor); any other value is returned unchanged.
        dtype: numpy dtype of the generated data.
        scale: values are drawn uniformly from ``[0, scale)``.

    Returns:
        A ``Tensor`` of shape ``shp`` for list input, otherwise ``shp`` itself.
    """
    if not isinstance(shp, list):
        return shp
    # np.random.rand() with no dimensions returns a plain Python float, which
    # has no ``astype``; np.asarray turns that 0-d case into an ndarray too.
    data = np.asarray(np.random.rand(*shp) * scale).astype(dtype)
    return Tensor(data)
275
276
def gen_inputs(input_shapes, config):
    """Build the forward inputs described by ``input_shapes``.

    Every entry is passed through ``convert``, which turns shape lists into
    random tensors. If there are no input shapes and
    ``config['add_fack_input']`` is set, a single dummy tensor ``[1.0]`` is
    returned instead. Its dtype comes from ``config['fack_input_type']``
    (default ``np.float32``).
    """
    use_dummy = config.get('add_fack_input', False)
    if input_shapes or not use_dummy:
        return [convert(item) for item in input_shapes]
    dummy_dtype = config.get('fack_input_type', np.float32)
    return [Tensor(np.array([1.0]).astype(dummy_dtype))]
282
283
def gen_backward_inputs(input_shapes, output_shapes, config):
    """Build forward inputs plus a trailing ``sens`` tensor for a backward test.

    If there are no input shapes and ``config['add_fack_input']`` is set, the
    forward inputs are a single dummy tensor ``[1.0]``. Its dtype comes from
    ``config['fack_input_type']`` (default ``np.float32``). Otherwise every
    entry of ``input_shapes`` goes through ``convert``.

    Args:
        input_shapes: shapes (or pass-through values) of the forward inputs.
        output_shapes: output shapes; ``output_shapes[0]`` sizes ``sens``.
        config: test configuration dict.

    Returns:
        List of forward inputs followed by the ``sens`` tensor.
    """
    add_fack_input = config.get('add_fack_input', False)
    if not input_shapes and add_fack_input:
        # Use the same dummy dtype that gen_inputs uses for the forward pass so
        # forward and backward graphs see the same input signature (this used
        # to be an un-typed np.array, i.e. float64).
        fack_dtype = config.get('fack_input_type', np.float32)
        inputs = [Tensor(np.array([1.0]).astype(fack_dtype))]
    else:
        inputs = [convert(shp) for shp in input_shapes]
    sens = convert(output_shapes[0])
    return inputs + [sens]
293
294
def append_sens_to_inputs(output_shapes, inputs):
    """Return ``inputs`` with a random float32 ``sens`` tensor appended.

    Args:
        output_shapes: shape of the ``sens`` tensor; its values are drawn from
            a standard normal distribution.
        inputs: list or tuple of forward inputs; it is not modified.

    Returns:
        A new list: the elements of ``inputs`` followed by ``sens``.
    """
    sens = Tensor(np.random.normal(0, 1, output_shapes).astype(np.float32))
    # list() accepts tuples as well as lists (``tuple + list`` raises TypeError).
    return list(inputs) + [sens]
299
300
def gen_net(shapes, config, get_first=False):
    """
    Build an ``InputOpNet`` whose ``construct`` matches ``shapes`` and ``config``.

    Args:
        shapes: forward input shapes; only their count is used here.
        config: dict with ``'op'`` (required) and optional ``'const'`` (list
            of constant arguments), ``'const_first'`` and ``'add_fack_input'``.
        get_first: forwarded to ``InputOpNet``.

    Returns:
        The ``InputOpNet`` with the matching ``construct*`` variant bound.

    Raises:
        KeyError: if ``config`` has no ``'op'``.
        AttributeError: if no ``InputOpNet`` variant matches the request.
    """
    op = config['op']
    const_input = config.get('const', [])
    const_first = config.get('const_first', False)
    add_fack_input = config.get('add_fack_input', False)

    net = InputOpNet(op, get_first, *const_input)
    if const_first:
        fn_name = 'constructc%d_%d' % (len(const_input), len(shapes))
    else:
        fn_name = 'construct%d_c%d' % (len(shapes), len(const_input))
    # gen_inputs only substitutes a dummy input when there are no real input
    # shapes, and only the zero-input variants have a ``_fack`` form, so the
    # suffix is gated the same way.
    if add_fack_input and not shapes:
        fn_name += '_fack'
    net.construct = getattr(net, fn_name)
    return net
325
326
def gen_backward_net(construct_net, input_num):
    """Wrap ``construct_net`` in ``InputBackward`` for ``input_num`` inputs.

    Binds ``InputBackward.construct<input_num>``, which takes the forward
    inputs followed by ``sens``, as the wrapper's ``construct``.
    """
    backward = InputBackward(construct_net)
    backward.construct = getattr(backward, 'construct%d' % input_num)
    return backward
332
333
def batch_tuple_tensor(data, batch_size):
    """Tile every tensor in ``data`` with ``np.tile(..., (batch_size, 1))``.

    Each element is converted with ``asnumpy()`` before tiling. A 1-D tensor
    of length k becomes shape ``(batch_size, k)``. A 2-D tensor of shape
    ``(m, k)`` becomes ``(batch_size * m, k)``.

    Returns:
        Tuple of new Tensors, in the same order as ``data``.
    """
    # NOTE(review): for inputs with more than 2 dims numpy prepends 1s to the
    # reps, so tiling happens along axis -2 rather than axis 0 -- confirm that
    # is intended.
    return tuple(
        Tensor(np.tile(item.asnumpy(), (batch_size, 1)))
        for item in data
    )
337
338
class OutPutWrap(nn.Cell):
    """
    Re-emit a network's outputs after multiplying each one by ``self.one``.

    ``construct`` is only a placeholder: ``get_output_wrap`` binds
    ``construct<N>`` (N = number of network inputs, 1..6) as ``construct``.
    Each variant runs ``network`` and multiplies every output by ``self.one``
    cast to that output's dtype.

    Args:
        network: the cell whose outputs are wrapped.
        num_output: number of outputs to wrap.
        output_is_tuple: when ``num_output == 1`` and this equals 0, the
            network output is treated as a single tensor rather than a tuple.
    """

    def __init__(self, network, num_output, output_is_tuple):
        super().__init__()
        self.network = network
        self.num_output = num_output
        # NOTE(review): this is a shape-(1,) integer tensor, so a 0-d output
        # would broadcast to shape (1,) when multiplied -- confirm acceptable.
        self.one = Tensor(np.array([1]))
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.output_is_tuple = output_is_tuple

    def construct(self, *inputs):
        # Placeholder; get_output_wrap rebinds ``construct`` per instance.
        pass

    def construct1(self, x1):
        outputs = self.network(x1)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped

    def construct2(self, x1, x2):
        outputs = self.network(x1, x2)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped

    def construct3(self, x1, x2, x3):
        outputs = self.network(x1, x2, x3)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped

    def construct4(self, x1, x2, x3, x4):
        outputs = self.network(x1, x2, x3, x4)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped

    def construct5(self, x1, x2, x3, x4, x5):
        outputs = self.network(x1, x2, x3, x4, x5)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped

    def construct6(self, x1, x2, x3, x4, x5, x6):
        outputs = self.network(x1, x2, x3, x4, x5, x6)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return outputs * self.cast(self.one, self.dtype(outputs))
        wrapped = F.make_tuple()
        for idx in range(self.num_output):
            item = outputs[idx]
            wrapped = wrapped + F.make_tuple(item * self.cast(self.one, self.dtype(item)))
        return wrapped
410
def get_output_wrap(network, num_input, num_output, output_is_tuple=0):
    """Wrap ``network`` in ``OutPutWrap`` with the matching ``construct``.

    Args:
        network: cell whose outputs are re-emitted.
        num_input: number of inputs ``network`` takes; selects
            ``OutPutWrap.construct<num_input>``.
        num_output: number of outputs to wrap.
        output_is_tuple: forwarded to ``OutPutWrap``.

    Returns:
        The ``OutPutWrap`` instance with its ``construct`` rebound.
    """
    wrapper = OutPutWrap(network, num_output, output_is_tuple)
    wrapper.construct = getattr(wrapper, 'construct%d' % num_input)
    return wrapper
416