• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Utils for Cell related computation."""
17
18# pylint: disable=missing-docstring
19
20import numpy as np
21
22from mindspore import ParameterTuple
23from mindspore import nn, context
24from mindspore.common.api import _cell_graph_executor, ms_function
25from mindspore.common.tensor import Tensor
26from mindspore.ops import functional as F
27from mindspore.ops import operations as P
28from mindspore.ops.composite import GradOperation
29from . import keyword
30
31
def get_uniform_with_shape(shape):
    """Return a float32 array of `shape` drawn uniformly between -0.1 and 0.1.

    The global NumPy random generator is re-seeded with 1 on every call, so
    repeated calls with the same shape yield identical values.
    """
    np.random.seed(1)
    samples = np.random.uniform(low=-0.1, high=0.1, size=shape)
    return samples.astype(np.float32)
35
36
def set_block_param_with_rand(net, rand_func=None):
    """Overwrite each trainable parameter of `net` with data from `rand_func`.

    Nothing happens when `net` is not an `nn.Cell` or `rand_func` is None.
    Otherwise the parameters are initialized, then `rand_func` is called with
    every trainable parameter's shape and its result is stored as a `Tensor`.
    """
    if rand_func is None or not isinstance(net, nn.Cell):
        return
    net.init_parameters_data()
    for weight in net.trainable_params():
        weight_shape = weight.data.asnumpy().shape
        weight.set_data(Tensor(rand_func(weight_shape)))
43
44
def compile_block(net, *inputs, rand_func=None, training=True):
    """Set the training flag and parameters of `net`, then compile it.

    Args:
        net: Block to compile.
        *inputs: Inputs forwarded to the graph executor's `compile` call.
        rand_func: Optional callable used to re-fill trainable parameters.
        training: Training flag applied to `net` before compilation.

    Returns:
        Whatever `_cell_graph_executor.compile` returns.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)
    compiled = _cell_graph_executor.compile(net, *inputs)
    return compiled
49
50
def run_block(net, *inputs, rand_func=None, training=True):
    """Set the training flag and parameters of `net`, then execute it.

    Args:
        net: Block to run.
        *inputs: Inputs passed to `net`.
        rand_func: Optional callable used to re-fill trainable parameters.
        training: Training flag applied to `net` before running.

    Returns:
        The result of calling `net` on `inputs`.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)

    if context.get_context("mode") != context.PYNATIVE_MODE:
        return net(*inputs)

    # PyNative mode: route the call through an ms_function-decorated wrapper.
    def func_pynative(*inputs):
        @ms_function
        def _func_pynative(*inputs):
            return net(*inputs)

        return _func_pynative(*inputs)

    return func_pynative(*inputs)
64
65
class IthOutputCell(nn.Cell):
    """Cell wrapper that returns a single indexed output of `network`."""

    def __init__(self, network, output_index):
        # auto_prefix=False is passed only when the wrapped object is a Cell;
        # otherwise the base class is initialized with its own defaults.
        base_kwargs = {'auto_prefix': False} if isinstance(network, nn.Cell) else {}
        super(IthOutputCell, self).__init__(**base_kwargs)
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        # Run the wrapped network and keep only the requested element.
        return self.network(*inputs)[self.output_index]
78
79
def get_output_cell(network, num_input, output_index, training=True):
    """Wrap `network` in an `IthOutputCell` selecting `output_index`.

    `num_input` is accepted but not used. The training flag is applied to
    the wrapper before it is returned.
    """
    _ = num_input  # intentionally unused
    wrapped = IthOutputCell(network, output_index)
    set_block_training(wrapped, training)
    return wrapped
85
86
class OutputReduceSumCell(nn.Cell):
    """Cell wrapper that applies `ReduceSum` to the output(s) of `network`.

    With `output_num == 1` the reduced network result is returned directly;
    otherwise a tuple with one reduced value per output index is returned.
    """

    def __init__(self, network, output_num):
        super(OutputReduceSumCell, self).__init__()
        self.output_num = output_num
        self.network = network
        self.reduce_sum = P.ReduceSum()

    def construct(self, *inputs):
        # Single output: reduce the network result as a whole.
        if self.output_num == 1:
            return self.reduce_sum(self.network(*inputs), None)
        # Several outputs: reduce each indexed output and collect the results.
        # Note that the network is invoked once per output index.
        reduced = F.make_tuple()
        for idx in range(self.output_num):
            item = self.network(*inputs)[idx]
            item_sum = self.reduce_sum(item, None)
            reduced = reduced + F.make_tuple(item_sum)
        return reduced
103
104
def get_output_reduce_cell(network, output_num, training=True):
    """Wrap `network` in an `OutputReduceSumCell` and apply the training flag."""
    reduce_cell = OutputReduceSumCell(network, output_num)
    set_block_training(reduce_cell, training)
    return reduce_cell
109
110
class InputOpNet(nn.Cell):
    """Cell that calls `op` with a mix of runtime inputs and stored constants.

    Method naming:
      * `construct<N>_c<M>`: N runtime inputs followed by the first M constants.
      * `constructc1_1`: constant `c1` first, followed by one runtime input.
      * `construct0_c<M>_fake`: no runtime inputs reach `op`; the op result is
        added to the single `data` argument.

    `construct` itself raises `NotImplementedError`; `gen_net` in this file
    replaces it with one of the variants on the instance.
    """

    def __init__(self, op, c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        raise NotImplementedError

    # ---- zero runtime inputs, result added to a placeholder input ----

    def construct0_c0_fake(self, data):
        return self.op() + data

    def construct0_c1_fake(self, data):
        return self.op(self.c1) + data

    def construct0_c2_fake(self, data):
        return self.op(self.c1, self.c2) + data

    def construct0_c3_fake(self, data):
        return self.op(self.c1, self.c2, self.c3) + data

    # ---- zero runtime inputs ----

    def construct0_c0(self):
        return self.op()

    def construct0_c1(self):
        return self.op(self.c1)

    def construct0_c2(self):
        return self.op(self.c1, self.c2)

    # ---- one runtime input ----

    def construct1_c0(self, x1):
        return self.op(x1)

    def construct1_c1(self, x1):
        return self.op(x1, self.c1)

    def construct1_c2(self, x1):
        return self.op(x1, self.c1, self.c2)

    def construct1_c3(self, x1):
        return self.op(x1, self.c1, self.c2, self.c3)

    def construct1_c4(self, x1):
        return self.op(x1, self.c1, self.c2, self.c3, self.c4)

    # Constant placed before the runtime input.
    def constructc1_1(self, x1):
        return self.op(self.c1, x1)

    # ---- two runtime inputs ----

    def construct2_c0(self, x1, x2):
        return self.op(x1, x2)

    def construct2_c1(self, x1, x2):
        return self.op(x1, x2, self.c1)

    def construct2_c3(self, x1, x2):
        return self.op(x1, x2, self.c1, self.c2, self.c3)

    # ---- three runtime inputs ----

    def construct3_c0(self, x1, x2, x3):
        return self.op(x1, x2, x3)

    def construct3_c1(self, x1, x2, x3):
        return self.op(x1, x2, x3, self.c1)

    # ---- four runtime inputs ----

    def construct4_c0(self, x1, x2, x3, x4):
        return self.op(x1, x2, x3, x4)

    def construct4_c1(self, x1, x2, x3, x4):
        return self.op(x1, x2, x3, x4, self.c1)

    def construct4_c2(self, x1, x2, x3, x4):
        return self.op(x1, x2, x3, x4, self.c1, self.c2)

    def construct4_c4(self, x1, x2, x3, x4):
        return self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4)

    # ---- five and six runtime inputs ----

    def construct5_c0(self, x1, x2, x3, x4, x5):
        return self.op(x1, x2, x3, x4, x5)

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        return self.op(x1, x2, x3, x4, x5, x6)

    def construct5_c1(self, x1, x2, x3, x4, x5):
        return self.op(x1, x2, x3, x4, x5, self.c1)

    def construct5_c4(self, x1, x2, x3, x4, x5):
        return self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
226
227
def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
    """Build an `InputOpNet` around `op` with the matching construct variant.

    Args:
        op: Operation to wrap. An `nn.Cell` is returned unchanged.
        input_num: Number of runtime inputs the net will receive.
        training: Training flag applied to the generated net.
        desc_const: Constants passed positionally to `InputOpNet`.
        const_first: Select the variant that puts constants before inputs.
        add_fake_input: Select the `_fake` variant of the construct method.

    Returns:
        `op` itself when it is an `nn.Cell`, otherwise the generated net whose
        `construct` attribute is bound to the selected variant.
    """
    if isinstance(op, nn.Cell):
        return op

    const_num = len(desc_const)
    fn_name = ('constructc%d_%d' % (const_num, input_num) if const_first
               else 'construct%d_c%d' % (input_num, const_num))
    if add_fake_input:
        fn_name = fn_name + '_fake'

    net = InputOpNet(op, *desc_const)
    # Rebind the instance's construct to the variant chosen above.
    setattr(net, "construct", getattr(net, fn_name))
    set_block_training(net, training)
    return net
242
243
class OperationBackward(nn.Cell):
    """Cell that applies `grad_op` to `network` and passes `sens` as the last argument."""

    def __init__(self, network, grad_op, sens):
        # auto_prefix=False is passed only when the wrapped object is a Cell.
        base_kwargs = {'auto_prefix': False} if isinstance(network, nn.Cell) else {}
        super(OperationBackward, self).__init__(**base_kwargs)
        self.network = network
        self.grad = grad_op
        self.sens = sens

    def construct(self, *inputs):
        # The stored sens value is appended after the runtime inputs.
        return self.grad(self.network)(*inputs, self.sens)
256
257
class OperationBackwardWithNoSens(nn.Cell):
    """Cell that applies `grad_op` to `network` without a sens argument."""

    def __init__(self, network, grad_op):
        # auto_prefix=False is passed only when the wrapped object is a Cell.
        base_kwargs = {'auto_prefix': False} if isinstance(network, nn.Cell) else {}
        super(OperationBackwardWithNoSens, self).__init__(**base_kwargs)
        self.network = network
        self.grad = grad_op

    def construct(self, *inputs):
        # Only the runtime inputs are forwarded to the gradient function.
        return self.grad(self.network)(*inputs)
269
270
class NNBackward(nn.Cell):
    """Cell that applies `grad_op` to `network` and its trainable parameters, with sens."""

    def __init__(self, network, grad_op, sens):
        # auto_prefix=False is passed only when the wrapped object is a Cell.
        base_kwargs = {'auto_prefix': False} if isinstance(network, nn.Cell) else {}
        super(NNBackward, self).__init__(**base_kwargs)
        self.network = network
        self.grad = grad_op
        # Trainable parameters are captured once, at construction time.
        self.params = ParameterTuple(network.trainable_params())
        self.sens = sens

    def construct(self, *inputs):
        # The stored sens value is appended after the runtime inputs.
        return self.grad(self.network, self.params)(*inputs, self.sens)
284
285
class NNBackwardWithNoSens(nn.Cell):
    """Cell that applies `grad_op` to `network` and its trainable parameters, without sens."""

    def __init__(self, network, grad_op):
        # auto_prefix=False is passed only when the wrapped object is a Cell.
        base_kwargs = {'auto_prefix': False} if isinstance(network, nn.Cell) else {}
        super(NNBackwardWithNoSens, self).__init__(**base_kwargs)
        self.network = network
        self.grad = grad_op
        # Trainable parameters are captured once, at construction time.
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        # Only the runtime inputs are forwarded to the gradient function.
        return self.grad(self.network, self.params)(*inputs)
298
299
def gen_grad_net(net, grad_op, input_num, sens=None, training=True, desc_const=(),
                 const_first=False, add_fake_input=False):
    """Wrap `net` in the backward cell that matches the flags of `grad_op`.

    A non-Cell `net` is first turned into a net via `gen_net`. The wrapper is
    chosen from `grad_op.get_by_list` (gradients w.r.t. parameters or not) and
    `grad_op.sens_param` (whether `sens` is passed), and the training flag is
    applied to the wrapper before it is returned.
    """
    if not isinstance(net, nn.Cell):
        net = gen_net(net, input_num, desc_const=desc_const, const_first=const_first,
                      add_fake_input=add_fake_input)

    with_params = grad_op.get_by_list
    with_sens = grad_op.sens_param
    if with_params and with_sens:
        grad_net = NNBackward(net, grad_op, sens)
    elif with_params:
        grad_net = NNBackwardWithNoSens(net, grad_op)
    elif with_sens:
        grad_net = OperationBackward(net, grad_op, sens)
    else:
        grad_net = OperationBackwardWithNoSens(net, grad_op)

    set_block_training(grad_net, training)
    return grad_net
316
317
def set_block_training(net, training=True):
    """Call `net.set_train(training)` when `net` is an `nn.Cell`; otherwise do nothing."""
    if not isinstance(net, nn.Cell):
        return
    net.set_train(training)
321
322
def set_block_phase(net, phase='train'):
    """Assign `phase` to `net.phase` when `net` is an `nn.Cell`; otherwise do nothing."""
    if not isinstance(net, nn.Cell):
        return
    net.phase = phase
326
327
def create_funcs(verification_set, block_generator, block_runner, grad_op=None, default_rand_func=None):
    """Replace each configured block with a callable that builds and runs it.

    Args:
        verification_set: Mapping whose `keyword.function` entry is a list of
            config mappings; each config's `keyword.block` entry is replaced.
        block_generator: Callable that turns a block into a runnable net.
        block_runner: Callable that executes the generated net on the inputs.
        grad_op: Optional gradient operation. When truthy the created
            callables run the gradient path, otherwise the forward path.
        default_rand_func: Parameter initializer used when a config does not
            provide `keyword.init_param_with`.

    Returns:
        The list of configs, modified in place.
    """

    def create_func(block, num_outputs, rand_func, desc_const, const_first, add_fake_input, split_outputs):
        # Keyword arguments shared by every block_generator call below.
        gen_kwargs = {
            'desc_const': desc_const,
            'const_first': const_first,
            'add_fake_input': add_fake_input,
        }

        def function(*inputs):
            if not grad_op:
                # forward
                # A single input combined with add_fake_input means the input is faked.
                n_inputs = 0 if (add_fake_input and len(inputs) == 1) else len(inputs)
                fwd_block = block_generator(block, n_inputs, **gen_kwargs)
                return block_runner(fwd_block, *inputs, rand_func=rand_func)

            # gradient
            if num_outputs == 0:
                # No sens inputs: rebuild the grad operation with sens_param disabled.
                grad_op_ = GradOperation(get_all=grad_op.get_all,
                                         get_by_list=grad_op.get_by_list, sens_param=False)
                grad_block = block_generator(block, grad_op_, len(inputs), **gen_kwargs)
                return block_runner(grad_block, *inputs, rand_func=rand_func)

            if num_outputs == 1:
                # The last input is the sens value; the rest feed the block.
                grad_block = block_generator(block, grad_op, len(inputs) - 1, inputs[-1], **gen_kwargs)
                return block_runner(grad_block, *(inputs[:-1]), rand_func=rand_func)

            # The trailing num_outputs inputs are sens values, the leading ones feed the block.
            split_at = len(inputs) - num_outputs
            block_inputs = inputs[:split_at]

            if not split_outputs:
                sens_inputs = tuple(inputs[split_at:])
                grad_block = block_generator(block, grad_op, len(block_inputs), sens_inputs, **gen_kwargs)
                return block_runner(grad_block, *block_inputs, rand_func=rand_func)

            # One gradient net per output, each with its own sens value.
            sens_inputs = inputs[split_at:]
            collected = []
            for out_idx in range(num_outputs):
                per_output_inputs = list(block_inputs)
                out_cell = get_output_cell(block, len(block_inputs), out_idx)
                out_cell = block_generator(out_cell, grad_op, len(per_output_inputs),
                                           sens_inputs[out_idx], **gen_kwargs)
                grads = block_runner(out_cell, *per_output_inputs, rand_func=rand_func)
                if isinstance(grads, tuple):
                    collected.extend(grads)
                else:
                    collected.append(grads)
            return collected

        return function

    bc_configs = verification_set[keyword.function]
    for config in bc_configs:
        config[keyword.block] = create_func(
            config[keyword.block],
            config.get(keyword.num_outputs, 0),
            config.get(keyword.init_param_with, default_rand_func),
            config.get(keyword.desc_const, []),
            config.get(keyword.const_first, False),
            config.get(keyword.add_fake_input, False),
            config.get(keyword.split_outputs, True),
        )
    return bc_configs
386