# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_stop_gradient """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import stop_gradient
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.bprop_util import bprop

35grad_by_list = C.GradOperation(get_by_list=True)
36grad_all = C.GradOperation(get_all=True)
37
38
39def setup_module(module):
40    context.set_context(mode=context.PYNATIVE_MODE)
41
42
43def stop_func(x, y):
44    """ stop_func"""
45    c = x * y
46    c_s = x + y
47    return c_s, c
48
49
50def stop_test1(x, y):
51    """ stop_test1 """
52    c = x * y
53    c_s = stop_gradient(c)
54    return c_s
55
56
57def stop_test2(x, y):
58    """ stop_test2 """
59    c = x * y
60    c_s = stop_gradient(c)
61    d = c_s + x * y
62    return d * y
63
64
65def stop_test3(x, y):
66    """ stop_test3 """
67    x = x * y
68    z = stop_test1(x, y)
69    k = z * y
70    return k
71
72
73def stop_test5(x, y):
74    """ stop_test3 """
75    x = x + y
76    o1, o2 = stop_func(x, y)
77    c = stop_gradient(o1)
78    c = o2 + c
79    return c
80
81
82def stop_test4(x, y):
83    """ stop_test4 """
84    c = x + y
85    c_s = stop_gradient(c)
86    e = c + c_s
87    return e
88
89
90@ms_function
91def grad_stop_test(x, y):
92    """ grad_stop_test """
93    return grad_all(stop_test2)(x, y)
94
95
96@ms_function
97def grad_stop_test1(x, y):
98    """ grad_stop_test1 """
99    return grad_all(stop_test3)(x, y)
100
101
102@ms_function
103def grad_stop_test5(x, y):
104    """ grad_stop_test5 """
105    return grad_all(stop_test5)(x, y)
106
107
108def test_stop():
109    """ test_stop """
110    print("test_stop:", grad_stop_test(1, 1))
111
112
113def test_stop1():
114    """ test_stop1 """
115    print("test_stop1:", grad_stop_test1(2, 3))
116
117
118def test_stop5():
119    """ test_stop1 """
120    print("test_stop5:", grad_stop_test5(2, 3))
121
122
123class GradWrap(nn.Cell):
124    """ GradWrap definition """
125
126    def __init__(self, network):
127        super(GradWrap, self).__init__()
128        self.network = network
129        self.weights = ParameterTuple(network.get_parameters())
130
131    @ms_function
132    def construct(self, x, label):
133        weights = self.weights
134        return grad_by_list(self.network, weights)(x, label)
135
136
137@non_graph_engine
138def test_softmaxloss_grad():
139    """ test_softmaxloss_grad """
140
141    class NetWithLossClass(nn.Cell):
142        """ NetWithLossClass definition """
143
144        def __init__(self, network):
145            super(NetWithLossClass, self).__init__()
146            self.loss = nn.SoftmaxCrossEntropyWithLogits()
147            self.network = network
148
149        @ms_function
150        def construct(self, x, label):
151            predict = self.network(x)
152            return self.loss(predict, label)
153
154    class Net(nn.Cell):
155        """ Net definition """
156
157        def __init__(self):
158            super(Net, self).__init__()
159            self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
160            self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
161            self.fc = P.MatMul()
162            self.fc2 = nn.Dense(10, 10)
163            self.biasAdd = P.BiasAdd()
164            self.relu = nn.ReLU()
165            self.cast = P.Cast()
166
167        @ms_function
168        def construct(self, x):
169            x = self.fc(x, self.weight)
170            x = self.cast(x, mstype.float32)
171            x = self.relu(self.fc2(x))
172            x = self.fc2(x)
173            x = stop_gradient(x)
174            x = self.biasAdd(x, self.bias)
175            return x
176
177    net = GradWrap(NetWithLossClass(Net()))
178
179    predict = Tensor(np.ones([1, 64]).astype(np.float32))
180    label = Tensor(np.zeros([1, 10]).astype(np.float32))
181    print("pynative run")
182    out = net(predict, label)
183    print("out:", out)
184
185
186def test_stop_gradient_1():
187    class Mul(nn.Cell):
188        def __init__(self):
189            super(Mul, self).__init__()
190
191        @ms_function
192        def construct(self, x, y):
193            ret = x * y
194            ret = stop_gradient(ret)
195            return ret
196
197    dx, dy = bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
198                   Tensor(np.ones([2, 2]).astype(np.float32)), wrt=['inputs'])
199    expect = np.zeros([2, 2])
200    assert (dx.asnumpy() == expect).all()
201    assert (dy.asnumpy() == expect).all()
202
203
204def test_stop_gradient_2():
205    class Mul(nn.Cell):
206        def __init__(self):
207            super(Mul, self).__init__()
208
209        @ms_function
210        def construct(self, x, y):
211            c = x * y
212            z = x * y
213            return c, z
214
215    class MulAdd(nn.Cell):
216        def __init__(self):
217            super(MulAdd, self).__init__()
218            self.mul = Mul()
219
220        @ms_function
221        def construct(self, x, y):
222            u = x + y
223            v = x - y
224            c, z = self.mul(u, v)
225            c = stop_gradient(c)
226            ret1 = c + x + y
227            ret2 = z + y + y
228            return ret1, ret2
229
230    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
231               Tensor(np.ones([2, 2]).astype(np.float32)))
232    expect = np.array([[3.0, 3.0], [3.0, 3.0]])
233    assert (dx.asnumpy() == expect).all()
234
235
236def test_stop_gradient_3():
237    class TupleGetItem(nn.Cell):
238        def __init__(self):
239            super(TupleGetItem, self).__init__()
240
241        @ms_function
242        def construct(self, x1, x2, x3, x4, x5):
243            z1 = x1 + x1
244            z2 = x1 * x2
245            t = (z1, z2, x3, x4, x5)
246            z2 = t[1]
247            z2 = stop_gradient(z2)
248            return z1, z2, x3, x4, x5
249
250    dx = bprop(TupleGetItem(),
251               Tensor(np.ones([2]).astype(np.float32)),
252               Tensor(np.ones([2]).astype(np.float32)),
253               Tensor(np.ones([2]).astype(np.float32)),
254               Tensor(np.ones([2]).astype(np.float32)),
255               Tensor(np.ones([2]).astype(np.float32)))
256    expect = np.array([[2.0, 2.0], [2.0, 2.0]])
257    assert (dx.asnumpy() == expect).all()
258
259
260def test_stop_gradient_4():
261    def stop_test(x):
262        return stop_gradient(x)
263
264    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
265
266
267def test_stop_gradient_5():
268    def stop_test(x):
269        y = x + x
270        y = stop_gradient(y)
271        ret = x + y
272        return ret
273
274    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
275
276
277def test_stop_gradient_6():
278    def stop_test(x, y):
279        ret = x * y
280        ret = stop_gradient(ret)
281        return ret
282
283    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
284
285
286class PrimWithMultiOutputs(PrimitiveWithInfer):
287    @prim_attr_register
288    def __init__(self):
289        """init"""
290
291    def __call__(self, x, y):
292        """Implement by vm mode."""
293        return x, y
294
295    def infer_shape(self, x_shape, y_shape):
296        return x_shape, y_shape
297
298    def infer_dtype(self, x_type, y_type):
299        return x_type, y_type
300
301    def get_bprop(self):
302        def bprop(x, y, out, dout):
303            return (dout[0], dout[1])
304
305        return bprop
306
307
308def test_stop_gradient_7():
309    class PrimWithMultiOutputs_(nn.Cell):
310        def __init__(self):
311            super(PrimWithMultiOutputs_, self).__init__()
312            self.prim_with_multi_outputs = PrimWithMultiOutputs()
313
314        @ms_function
315        def construct(self, x1, x2):
316            x1, x2 = self.prim_with_multi_outputs(x1, x2)
317            x1 = stop_gradient(x1)
318            return x1, x2
319
320    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
321                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
322    expect_dx = np.zeros([2])
323    expect_dy = np.ones([2])
324    assert (dx.asnumpy() == expect_dx).all()
325    assert (dy.asnumpy() == expect_dy).all()
326
327
328def test_stop_gradient_8():
329    class PrimWithMultiOutputs_(nn.Cell):
330        def __init__(self):
331            super(PrimWithMultiOutputs_, self).__init__()
332            self.prim_with_multi_output = PrimWithMultiOutputs()
333
334        @ms_function
335        def construct(self, x1, x2):
336            x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
337            return x1, x2
338
339    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
340                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
341    expect_dx = np.zeros([2])
342    expect_dy = np.zeros([2])
343    assert (dx.asnumpy() == expect_dx).all()
344    assert (dy.asnumpy() == expect_dy).all()
345
346
347def test_stop_gradient_9():
348    class Mul(nn.Cell):
349        def __init__(self):
350            super(Mul, self).__init__()
351
352        @ms_function
353        def construct(self, x, y):
354            c = x * y
355            z = x * y
356            return c, z
357
358    class MulAdd(nn.Cell):
359        def __init__(self):
360            super(MulAdd, self).__init__()
361            self.mul = Mul()
362
363        @ms_function
364        def construct(self, x, y):
365            u = x + y
366            v = x - y
367            c, z = self.mul(u, v)
368            c1 = stop_gradient(c)
369            c2 = c
370            ret1 = c1 + x + y + c2
371            ret2 = z + y + y
372            return ret1, ret2
373
374    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
375               Tensor(np.ones([2, 2]).astype(np.float32)))
376    expect = np.array([[5.0, 5.0], [5.0, 5.0]])
377    assert (dx.asnumpy() == expect).all()
378
379
380class PrimWithNoBprop(PrimitiveWithInfer):
381    @prim_attr_register
382    def __init__(self):
383        """init"""
384
385    def __call__(self, x, y):
386        """Implement by vm mode."""
387        return x, y
388
389    def infer_shape(self, x_shape, y_shape):
390        return x_shape, y_shape
391
392    def infer_dtype(self, x_type, y_type):
393        return x_type, y_type
394
395
396def test_stop_gradient_10():
397    class PrimWithNoBprop_(nn.Cell):
398        def __init__(self):
399            super(PrimWithNoBprop_, self).__init__()
400            self.prim_with_no_bprop = PrimWithNoBprop()
401
402        @ms_function
403        def construct(self, x, y):
404            x = x * y
405            x, y = self.prim_with_no_bprop(x, y)
406            x = stop_gradient(x)
407            y = stop_gradient(y)
408            return x, y
409
410    dx = bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
411               Tensor(np.ones([2]).astype(np.float32)))
412    expect_dx = np.zeros([2])
413    assert (dx.asnumpy() == expect_dx).all()
414
415
416def test_stop_gradient_11():
417    class PrimWithNoBprop_(nn.Cell):
418        def __init__(self):
419            super(PrimWithNoBprop_, self).__init__()
420            self.prim_with_no_bprop = PrimWithNoBprop()
421
422        @ms_function
423        def construct(self, x, y):
424            x, y = self.prim_with_no_bprop(x, y)
425            x = stop_gradient(x)
426            return x, y
427
428    with pytest.raises(RuntimeError):
429        bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
430              Tensor(np.ones([2]).astype(np.float32)))
431
432
433@security_off_wrap
434def test_stop_print():
435    class StopPrint(nn.Cell):
436        def __init__(self):
437            super(StopPrint, self).__init__()
438            self.printm = P.Print()
439
440        def construct(self, x, y):
441            self.printm("StopPrint", x)
442            self.printm(y)
443            return x, y
444
445    grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
446                          Tensor(np.ones([2]).astype(np.float32)))
447