• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import pytest
16import numpy as np
17from mindspore import RowTensor
18from mindspore import context, nn, Tensor, ParameterTuple
19from mindspore.common import dtype as mstype
20from mindspore.common import ms_function
21from mindspore.ops import operations as P
22from mindspore.ops import composite as C
23
24
def setup_module():
    """pytest module-level setup hook.

    Runs the tests in this module in PyNative mode with ``enable_sparse``
    switched off.
    """
    context.set_context(enable_sparse=False, mode=context.PYNATIVE_MODE)
27
28
class _Grad(nn.Cell):
    """Cell that applies a composite gradient operation to a wrapped network.

    Args:
        grad: a ``C.GradOperation`` instance; its ``sens_param`` flag is
            copied onto this cell.
        network: the Cell being differentiated.
        wrt_params: when True, the gradient operation is also given the
            network's trainable parameters (as a ``ParameterTuple``).
        real_inputs_count: number of leading ``construct`` inputs that are
            real network inputs. The remaining inputs are bundled into a
            single tuple and passed as the sensitivity argument. This split
            only happens when ``sens_param`` is not ``False``.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.grad = grad
        self.network = network
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        self.sens_param = grad.sens_param
        if wrt_params:
            self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        split = self.real_inputs_count
        # Without a split point (or without sens), every input goes straight through.
        forward_all = split is None or self.sens_param is False
        if not self.wrt_params:
            if forward_all:
                return self.grad(self.network)(*inputs)
            # Real inputs are unpacked; the trailing sens inputs travel as one tuple.
            return self.grad(self.network)(*inputs[:split], inputs[split:])
        if forward_all:
            return self.grad(self.network, self.params)(*inputs)
        return self.grad(self.network, self.params)(*inputs[:split], inputs[split:])
53
54
class GradOfFirstInput(_Grad):
    """Gradient wrapper built on ``C.GradOperation`` with default ``get_all``.

    It yields the gradient of the network's first input only. ``sens_param``
    defaults to True, so callers pass the output sensitivity after the
    ``real_inputs_count`` real inputs.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        grad_op = C.GradOperation(sens_param=sens_param)
        super().__init__(grad_op, network, real_inputs_count=real_inputs_count)
63
64
class GradOfAllInputs(_Grad):
    """Gradient wrapper that returns the gradients of all network inputs.

    Built on ``C.GradOperation(get_all=True)``, so the result holds one
    gradient per network input rather than only the first input's.
    ``sens_param`` defaults to True, so callers pass the output sensitivity
    after the ``real_inputs_count`` real inputs.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)
73
74
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_row_tensor_in_while():
    """Thread a RowTensor through an ms_function while loop.

    Indices and dense shape must come out unchanged. Values are shifted
    by +2 once, then doubled on every loop iteration.
    """

    class RowTensorValuesDouble(nn.Cell):
        # Rebuild the RowTensor with every value multiplied by 2.
        def construct(self, x):
            return RowTensor(x.indices, x.values * 2, x.dense_shape)

    class RowTensorValuesAdd2(nn.Cell):
        # Rebuild the RowTensor with 2 added to every value.
        def construct(self, x):
            return RowTensor(x.indices, x.values + 2, x.dense_shape)

    class RowTensorWithControlWhile(nn.Cell):
        def __init__(self, dense_shape):
            super().__init__()
            self.op1 = RowTensorValuesDouble()
            self.op2 = RowTensorValuesAdd2()
            self.dense_shape = dense_shape

        @ms_function
        def construct(self, a, b, indices, values):
            x = RowTensor(indices, values, self.dense_shape)
            x = self.op2(x)
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x.indices, x.values, x.dense_shape

    loop_upper = Tensor(np.array(3).astype(np.int32))
    loop_start = Tensor(np.array(0).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)

    net = RowTensorWithControlWhile(dense_shape)
    out_indices, out_values, out_shape = net(loop_upper, loop_start, indices, values)

    # (1 + 2) doubled on each of the 3 iterations: 3 * 2**3 == 24.
    expected_values = values.asnumpy() * 24
    assert np.allclose(indices.asnumpy(), out_indices.asnumpy(), rtol=0.0, atol=0.0)
    assert np.allclose(expected_values, out_values.asnumpy(), rtol=0.0, atol=0.0)
    assert dense_shape == out_shape
121
122
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_switch_layer_inputs_tuple():
    """Select a Cell from a tuple of Cells with an int32 Tensor index.

    Inside an ms_function, the inputs are packed as a tuple. Index 1 selects
    ``Mul``, whose result must equal ``MulTwoInput`` on the same tensors.
    """
    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class MulTwoInput(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        @ms_function
        def construct(self, x, y):
            y = self.op(x, y)
            return self.op(x, y)

    class TwoInputTupleFinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        @ms_function
        def construct(self, i, inputa, inputb):
            inputs = (inputa, inputb)
            x = self.funcs[i](inputs)
            return x

    net = TwoInputTupleFinalNet((Add(), Mul()))

    shape = (2, 3, 4, 5)
    input_data = Tensor(np.random.randn(*shape).astype(np.float32))
    input2 = Tensor(np.random.randn(*shape).astype(np.float32))
    selector = Tensor(1, mstype.int32)

    netout = net(selector, input_data, input2)
    goodout = MulTwoInput()(input_data, input2)
    assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0)
180
181
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_imagenet():
    """Backpropagate through ``nn.ImageGradients`` in PyNative mode.

    Despite the name, no ImageNet data is involved. The test differentiates a
    Cell wrapping ``nn.ImageGradients`` with respect to its input image and
    checks the result against a NumPy reference. Previously, the computed
    gradient was discarded, so only a crash could fail this test.
    """
    class ImageGradients(nn.Cell):
        def __init__(self):
            super().__init__()
            self.imagegradients = nn.ImageGradients()

        def construct(self, inputs):
            return self.imagegradients(inputs)

    shape = (32, 16, 8, 8)
    net = ImageGradients()
    net_me = GradOfFirstInput(net, real_inputs_count=1)
    net_me.set_train()
    input_data = Tensor(np.ones(shape), dtype=mstype.float32)
    # ImageGradients returns (dy, dx); supply an all-ones sensitivity for each.
    output_grad = (Tensor(np.ones(shape), dtype=mstype.float32),
                   Tensor(np.ones(shape), dtype=mstype.float32))
    grad = net_me(input_data, *output_grad)

    # Reference per the nn.ImageGradients definition:
    #   dy[h] = x[h+1] - x[h] for h < H-1, else 0 (and likewise dx along width).
    # With unit sensitivities, each pixel gets +1 for every difference where it
    # is the minuend and -1 for every difference where it is the subtrahend.
    expected = np.zeros(shape, dtype=np.float32)
    expected[:, :, 1:, :] += 1
    expected[:, :, :-1, :] -= 1
    expected[:, :, :, 1:] += 1
    expected[:, :, :, :-1] -= 1

    assert tuple(grad.shape) == shape
    assert np.allclose(grad.asnumpy(), expected)
202