1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import os
16import re
17import subprocess
18import pytest
19import numpy as np
20import mindspore as ms
21import mindspore.ops.operations as P
22from mindspore.nn import Cell
23from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
24from mindspore.nn import Momentum
25from mindspore.nn import SoftmaxCrossEntropyWithLogits
26from mindspore import amp
27from mindspore import context, Tensor
28from mindspore.common import ParameterTuple
29from mindspore.common.parameter import Parameter
30from mindspore.ops.composite import GradOperation
31from tests.security_utils import security_off_wrap
32
# Module-wide default: graph mode on GPU; individual tests below switch
# the mode (e.g. to PYNATIVE_MODE) as part of their own setup.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
34
35
class _Grad(Cell):
    """Wrap `network` with a gradient operation.

    If `real_inputs_count` is set and the grad op carries a sens
    parameter, the trailing call arguments are grouped into a tuple and
    passed as the sens value; otherwise all arguments go straight
    through.  With `wrt_params` the trainable parameters are included in
    the differentiation targets.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.grad = grad
        self.network = network
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        # Mirror the grad op's sens flag so construct can branch on it.
        self.sens_param = self.grad.sens_param
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        needs_sens_split = self.sens_param and self.real_inputs_count is not None
        if not needs_sens_split:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)

        real = inputs[:self.real_inputs_count]
        sens = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real, sens)
        return self.grad(self.network)(*real, sens)
58
59
class GradOfAllInputs(_Grad):
    """Gradient wrapper returning grads of every network input."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        grad_op = GradOperation(get_all=True, sens_param=sens_param)
        super().__init__(grad=grad_op, network=network,
                         real_inputs_count=real_inputs_count)
68
69
class GradOfAllInputsAndParams(_Grad):
    """Gradient wrapper returning grads of every input and all params."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        grad_op = GradOperation(get_all=True, get_by_list=True,
                                sens_param=sens_param)
        super().__init__(grad=grad_op, network=network, wrt_params=True,
                         real_inputs_count=real_inputs_count)
78
79
80def _count_unequal_element(data_expected, data_me, rtol, atol):
81    assert data_expected.shape == data_me.shape
82    total_count = len(data_expected.flatten())
83    error = np.abs(data_expected - data_me)
84    greater = np.greater(error, atol + np.abs(data_me)*rtol)
85    loss_count = np.count_nonzero(greater)
86    assert (loss_count/total_count) < rtol, \
87        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
88        format(data_expected[greater], data_me[greater], error[greater])
89
90
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert two arrays are close.

    If the expected data contains NaN, rely on numpy.allclose directly;
    otherwise, when allclose fails, fall back to the element-wise
    tolerance count in _count_unequal_element for a better message.
    """
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol,
                           equal_nan=equal_nan)
        return
    if np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        return
    _count_unequal_element(data_expected, data_me, rtol, atol)
99
100
def clear_files():
    """Delete everything directly under verbose_ir_files/.

    Replaces the previous `os.system("rm verbose_ir_files/*")`: no shell
    is spawned, and a missing directory or a subdirectory entry is
    silently ignored, matching the old best-effort behaviour.
    """
    import glob  # local import keeps the module-level import block unchanged
    for path in glob.glob("verbose_ir_files/*"):
        try:
            os.remove(path)
        except OSError:
            # e.g. a subdirectory; `rm` (without -r) skipped those too
            pass
103
104
def find_files(file, para):
    """Count lines containing `para` in IR files matching glob `file`.

    Looks under verbose_ir_files/ and returns the total as a *string*,
    for compatibility with callers comparing against literals like '2'.

    Replaces the previous `grep ... | wc -l` shell pipeline, which
    interpolated both arguments unescaped into a shell string.  `para`
    is matched as a plain substring; every existing caller passes a
    literal token, not a regex.
    """
    import glob  # local import keeps the module-level import block unchanged
    total = 0
    for path in glob.glob(os.path.join("verbose_ir_files", file)):
        with open(path, "r") as ir_file:
            total += sum(1 for line in ir_file if para in line)
    return str(total)
111
112
class SideEffectCastAll(Cell):
    """Net with two Assign side effects followed by two float16 Casts.

    Used by test_side_effect_castall to check that the backend merges
    the two Cast ops into a CastAll node in the optimized IR.
    """

    def __init__(self):
        super().__init__()
        self.cast = P.Cast()
        self.dtype = ms.float16  # target dtype for both casts
        np.random.seed(5)  # fixed seed: reproducible parameter init
        inputs1 = np.random.randn(5, 5)
        inputs2 = np.random.randn(5, 5)
        self.parameter_a = Parameter(Tensor(inputs1, ms.float32), name="a")
        self.parameter_b = Parameter(Tensor(inputs2, ms.float32), name="b")
        self.assign = P.Assign()

    def construct(self, x, y):
        # Statement order matters: both assigns must take effect before
        # the casts read the parameters; the fusion pass must keep this.
        self.assign(self.parameter_a, x)
        self.assign(self.parameter_b, y)
        out_a = self.cast(self.parameter_a, self.dtype)
        out_b = self.cast(self.parameter_b, self.dtype)
        return out_a, out_b
131
@security_off_wrap
def test_side_effect_castall():
    """The CastAll fusion pass should report exactly two fused casts."""
    clear_files()
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = SideEffectCastAll()
    first = Tensor(np.random.randn(5, 5), ms.float32)
    second = Tensor(np.random.randn(5, 5), ms.float32)
    net(first, second)
    assert find_files('./hwopt*cast_all*.ir', 'CastAll') == '2'
142
143
class SideEffectControlFlowAssignDependWhileNet(Cell):
    """Net whose while-loop condition reads a parameter that an Assign
    initialises and an AssignAdd mutates inside the loop body."""

    def __init__(self):
        super().__init__()
        self.parameter1 = Parameter(
            Tensor([199.0], ms.float32), name="parameter1")
        self.assign = P.Assign()
        self.assignadd = P.AssignAdd()
        self.addn = P.AddN()

    def construct(self, x, y, z):
        # The assign must execute before the loop condition first reads
        # parameter1, and each assignadd before the next condition check.
        self.assign(self.parameter1, x)
        while self.parameter1 < y:
            x = self.addn((x, x))
            self.assignadd(self.parameter1, z)
        return x

    def grad_mindspore_impl(self, params1, params2, params3, grad_ys):
        # Gradients w.r.t. all three inputs plus trainable parameters.
        grad_net = GradOfAllInputsAndParams(self)
        grad_net.set_train()
        grad_out = grad_net(params1, params2, params3, grad_ys)
        return grad_out
165
166
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_side_effect_control_flow_assign_depend_while_net():
    """Graph-mode and pynative-mode results of the while net must agree."""
    x = Tensor([9.0], ms.float32)
    y = Tensor([99.0], ms.float32)
    z = Tensor([1.0], ms.float32)
    graph_net = SideEffectControlFlowAssignDependWhileNet()
    context.set_context(mode=context.GRAPH_MODE)
    graph_out = graph_net(x, y, z)
    pynative_net = SideEffectControlFlowAssignDependWhileNet()
    context.set_context(mode=context.PYNATIVE_MODE)
    pynative_out = pynative_net(x, y, z)
    allclose_nparray(graph_out.asnumpy(), pynative_out.asnumpy(), 0.001, 0.001)
180
181
class Addn(Cell):
    """Adds two constant parameters (1.0 and 3.0) to the input via AddN."""

    def __init__(self):
        super().__init__()
        self.parameter3 = Parameter(Tensor([1.0], ms.float32),
                                    name="parameter3")
        self.parameter4 = Parameter(Tensor([3.0], ms.float32),
                                    name="parameter4")
        self.addn = P.AddN()

    def construct(self, inputs):
        return self.addn((inputs, self.parameter3, self.parameter4))
194
195
class Relu(Cell):
    """Thin Cell wrapper around the ReLU primitive."""

    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    def construct(self, inputs):
        return self.relu(inputs)
204
205
class SideEffectTwoAssignTwoAddnDependencyNet(Cell):
    """Interleaves two Assigns with two AddNs so each AddN must observe
    the parameter state written by the Assign immediately before it."""

    def __init__(self):
        super().__init__()
        self.parameter1 = Parameter(Tensor([1.0], ms.float32),
                                    name="parameter1")
        self.parameter2 = Parameter(Tensor([3.0], ms.float32),
                                    name="parameter2")
        self.assign = P.Assign()
        self.addN = P.AddN()

    def construct(self, inputs):
        # Statement order is the contract under test: assign -> read ->
        # assign -> read; reordering would change the result.
        self.assign(self.parameter1, inputs)
        out = self.addN((inputs, self.parameter1, self.parameter2))
        self.assign(self.parameter2, inputs)
        out = self.addN((out, self.parameter1, self.parameter2))
        return out

    def grad_mindspore_impl(self, params, grad_ys):
        # Gradients w.r.t. the input and the trainable parameters.
        grad_net = GradOfAllInputsAndParams(self)
        grad_net.set_train()
        grad_out = grad_net(params, grad_ys)
        return grad_out
228
229
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_if_in_first_while():
    """Two sequential while loops where the first contains an if; all
    conditions read parameters that the loop bodies themselves mutate.
    Smoke test: only checks that graph mode compiles and runs the net,
    no output values are asserted."""
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            # First loop: increments parameter a until it reaches 7,
            # with a nested if on a < c.
            while self.a < 7:
                if self.a < self.c:
                    out = self.relu(x)
                self.a += 1
            # Second loop: decrements parameter c, doubling the output
            # each iteration.
            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    net(input_me_a)
264
265
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_while_in_first_while():
    """Nested while inside a while, followed by another while, with all
    conditions reading parameters mutated inside the loops.
    Smoke test: only checks graph-mode compilation/execution."""
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            # Outer loop on a < c with an inner loop draining b to 1.
            while self.a < self.c:
                out = self.relu(x)
                while self.b > 1:
                    self.b -= 1
                self.a += 1

            # Trailing loop: doubles the output while c > 5.
            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    net(input_me_a)
302
303
class InplaceNet(Cell):
    """Two conv+BN branches merged by Add/ReLU, feeding two more convs
    whose outputs are added; exercises the cudnn inplace-fusion pass."""

    def __init__(self):
        super().__init__()
        bn_kwargs = dict(num_features=4, eps=1e-4, momentum=0.9,
                         gamma_init=1, beta_init=0, moving_mean_init=0,
                         moving_var_init=1, data_format="NHWC")
        conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=2,
                           data_format="NHWC")
        self.bn1 = BatchNorm2d(**bn_kwargs)
        self.bn2 = BatchNorm2d(**bn_kwargs)
        self.conv2d1 = Conv2d(**conv_kwargs)
        self.conv2d2 = Conv2d(**conv_kwargs)
        self.conv2d3 = Conv2d(**conv_kwargs)
        self.conv2d4 = Conv2d(**conv_kwargs)
        self.add = P.Add()
        self.relu = ReLU()

    def construct(self, input_x):
        branch_a = self.bn1(self.conv2d1(input_x))
        branch_b = self.bn2(self.conv2d2(input_x))
        merged = self.relu(self.add(branch_a, branch_b))
        return self.add(self.conv2d3(merged), self.conv2d4(merged))
336
@security_off_wrap
def test_ir_fusion_inplace_bn_conv_conv():
    """cudnn inplace fusion should pick one accumulation and one cover
    algorithm on the InplaceNet training graph."""
    clear_files()
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    data = np.random.uniform(0.0, 255.0,
                             size=[4, 4, 4, 4]).astype(np.float32)
    label = np.ones([4, 4, 4, 4]).astype(np.float32)
    net = InplaceNet()
    loss = SoftmaxCrossEntropyWithLogits(sparse=False)
    trainable = filter(lambda x: x.requires_grad, net.get_parameters())
    opt = Momentum(learning_rate=0.01, momentum=0.9, params=trainable)
    train_net = amp.build_train_network(net, opt, loss, level="O2",
                                        keep_batchnorm_fp32=False)
    train_net.set_train()
    train_net(Tensor(data), Tensor(label))
    assert find_files("./hwopt*cudnn_inplace*ir",
                      "inplace_algo: accumulation") == '1'
    assert find_files("./hwopt*cudnn_inplace*ir",
                      "inplace_algo: cover") == '1'
358
359
def clean_all_ir_files(folder_path):
    """Delete every .ir/.dot/.dat file directly under folder_path.

    Silently does nothing when the folder does not exist.
    """
    if not os.path.exists(folder_path):
        return
    for file_name in os.listdir(folder_path):
        if file_name.endswith(('.ir', '.dot', '.dat')):
            os.remove(os.path.join(folder_path, file_name))
366
367
def find_newest_validateir_file(folder_path):
    """Return the most recently created '<n>_validate_<m>.ir' file path.

    Raises ValueError (from max on an empty sequence) when no such file
    exists in `folder_path`.

    The previous pattern left the '.' unescaped and the end unanchored,
    so names like '3_validate_4Xir' or '1_validate_2.ir.bak' matched too;
    the pattern now requires a literal '.ir' suffix.
    """
    pattern = re.compile(r'\d+_validate_\d+\.ir$')
    candidates = [os.path.join(folder_path, name)
                  for name in os.listdir(folder_path)
                  if pattern.match(name)]
    return max(candidates, key=os.path.getctime)
373
374
def read_file():
    """Return the text of the newest validate IR file in the cwd, then
    remove all IR/dot/dat files so the next test starts clean."""
    newest = find_newest_validateir_file('./')
    with open(newest, 'r') as ir_file:
        content = ir_file.read()
    clean_all_ir_files('./')
    return content
381
382
class Add(Cell):
    """Element-wise addition Cell."""

    def __init__(self):
        super().__init__()
        self.add = P.Add()

    def construct(self, x, y):
        out = self.add(x, y)
        return out
390
391
class MixControlNet(Cell):
    """Net combining nested while/if control flow with parameter side
    effects (Assign/AssignAdd/ParameterUpdate) and mixed-precision
    sensitive ops (Conv2d, BatchNorm2d, BiasAdd).  Used to count the
    Cast nodes amp inserts (see test_auto_mixed_precision_controlflow_auto).
    """

    def __init__(self, in_channel, x):
        super().__init__()
        self.biasadd = P.BiasAdd()
        self.equal = P.Equal()
        self.addn = P.AddN()
        self.conv = Conv2d(in_channels=in_channel, out_channels=in_channel,
                           kernel_size=1, stride=1, has_bias=False,
                           weight_init='ones', pad_mode='same')
        self.bn = BatchNorm2d(num_features=in_channel)
        self.assignadd = P.AssignAdd()
        self.assign = P.Assign()
        self.relu = ReLU()
        self.mean = P.ReduceMean(keep_dims=False)
        # bias/bias2/value all have length 3, matching the channel dim
        # of the inputs the test feeds in.
        self.bias = Parameter(
            Tensor(np.random.randint(2, size=(3,)).astype((np.float32))),
            name="bias")
        self.bias2 = Parameter(Tensor(np.ones([3]).astype(np.float32)),
                               name="bias2")
        self.parameterupdate = ParameterUpdate(self.bias)
        self.value = Tensor(np.random.randn(*(3,)), ms.float32)
        # x is a plain Python int used as the loop counter's start value.
        self.x = x

    def construct(self, input_x):
        x = self.x
        z = self.x
        out = self.biasadd(input_x, self.bias)
        # Outer loop: x counts from self.x up to 20; the branch taken
        # depends on x and on x % 2, mixing control flow with the
        # AssignAdd/ParameterUpdate side effects below.
        while x < 20:
            update = self.parameterupdate(self.bias2)
            out = self.biasadd(out, update)
            if x < 10:
                out = self.addn((input_x, out))
                # Inner loop only runs on the first outer iteration,
                # since z is never reset afterwards.
                while z < 20:
                    out = self.conv(out)
                    z = z + 1
            if x < 20:
                out = self.biasadd(out, self.bias)
                if x % 2 == 0:
                    # Even iterations mutate bias before reading it again.
                    self.assignadd(self.bias, self.value)
                    out = self.biasadd(out, self.bias)
                    out = self.bn(out)
                else:
                    out = self.conv(out)
            x = x + 1
        out = self.addn((out, out))
        out = self.mean(out, (2, 3))
        return out
440
441
def use_build_train_network_controlflow_check_cast_num(network, level, input_x,
                                                       label, cast_num,
                                                       sparse=False,
                                                       loss_flag=True,
                                                       **kwargs):
    """Build an amp train network at `level`, run one step, and — in
    graph mode only — assert the validate IR contains `cast_num` Cast
    nodes.  Returns the training step's output."""
    momentum_opt = Momentum(learning_rate=0.0001, momentum=0.009,
                            params=network.trainable_params())
    loss_fn = None
    if loss_flag:
        loss_fn = SoftmaxCrossEntropyWithLogits(sparse=sparse,
                                                reduction='mean')

    train_network = ms.amp.build_train_network(network, momentum_opt, loss_fn,
                                               level=level, **kwargs)
    out_me = train_network(input_x, label)
    # Mode 0 is GRAPH_MODE; only then are IR files dumped for inspection.
    if context.get_context("mode") == 0:
        ir_text = read_file()
        assert len(re.findall('Cast', ir_text)) == cast_num
    return out_me
461
@security_off_wrap
def test_auto_mixed_precision_controlflow_auto():
    """Auto mixed precision on a control-flow net inserts the expected,
    backend-dependent number of Cast ops.

    Previously `cast_num` was only bound for Ascend/GPU targets, so any
    other device target crashed with NameError; now the test skips with
    a clear message instead.
    """
    context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
    net = MixControlNet(3, 5)
    input_x = Tensor(
        np.random.randint(2, size=(1, 3, 2, 2)).astype((np.float32)))
    label = Tensor(np.zeros([1, 3]).astype(np.float32))
    expected_casts = {"Ascend": 77, "GPU": 73}
    device_target = ms.context.get_context("device_target")
    if device_target not in expected_casts:
        pytest.skip("no expected cast count for device target "
                    + str(device_target))
    cast_num = expected_casts[device_target]
    use_build_train_network_controlflow_check_cast_num(net, "auto", input_x,
                                                       label, cast_num)
475
@security_off_wrap
def test_updatestate_between_assigns():
    """Two consecutive parameter assigns should be chained by exactly
    one UpdateState node in the validate IR."""
    class UpdateState_Assigns(Cell):
        def __init__(self):
            super().__init__()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')

        def construct(self, value1, value2):
            self.para1 = value1
            self.para2 = value2
            return self.para2

    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = UpdateState_Assigns()
    result = net(Tensor(10, dtype=ms.int32), Tensor(30, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.array(30, np.int32))
    if ms.context.get_context('mode') == 0:
        ir_text = read_file()
        assert len(re.findall('UpdateState', ir_text)) == 1
500
@security_off_wrap
def test_updatestate_between_maketuple_assign():
    """A tuple-unpacking assign followed by a plain assign should still
    produce exactly one UpdateState node in the validate IR."""
    class UpdateState_MakeTuple_Assign(Cell):
        def __init__(self):
            super().__init__()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')
            self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3')

        def construct(self, value1, value2, value3):
            (self.para1, self.para2) = (value1, value2)
            self.para3 = value3
            return self.para3

    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = UpdateState_MakeTuple_Assign()
    result = net(Tensor(10, dtype=ms.int32), Tensor(30, dtype=ms.int32),
                 Tensor(50, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.array(50, np.int32))
    if ms.context.get_context('mode') == 0:
        ir_text = read_file()
        assert len(re.findall('UpdateState', ir_text)) == 1
527
@security_off_wrap
def test_updatestate_between_assign_maketuple():
    """A plain assign followed by a tuple-unpacking assign should still
    produce exactly one UpdateState node in the validate IR."""
    class UpdateState_Assign_MakeTuple(Cell):
        def __init__(self):
            super().__init__()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')
            self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3')

        def construct(self, value1, value2, value3):
            self.para1 = value1
            (self.para2, self.para3) = (value2, value3)
            return self.para3

    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = UpdateState_Assign_MakeTuple()
    result = net(Tensor(10, dtype=ms.int32), Tensor(30, dtype=ms.int32),
                 Tensor(50, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.array(50, np.int32))
    if ms.context.get_context('mode') == 0:
        ir_text = read_file()
        assert len(re.findall('UpdateState', ir_text)) == 1
554