# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_stop_gradient """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import stop_gradient
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.bprop_util import bprop

grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)


def stop_func(x, y):
    """ stop_func """
    c = x * y
    c_s = x + y
    return c_s, c


def stop_test1(x, y):
    """ stop_test1 """
    c = x * y
    c_s = stop_gradient(c)
    return c_s


def stop_test2(x, y):
    """ stop_test2 """
    c = x * y
    c_s = stop_gradient(c)
    d = c_s + x * y
    return d * y


def stop_test3(x, y):
    """ stop_test3 """
    x = x * y
    z = stop_test1(x, y)
    k = z * y
    return k


def stop_test5(x, y):
    """ stop_test5 """
    x = x + y
    o1, o2 = stop_func(x, y)
    c = stop_gradient(o1)
    c = o2 + c
    return c


def stop_test4(x, y):
    """ stop_test4 """
    c = x + y
    c_s = stop_gradient(c)
    e = c + c_s
    return e


@ms_function
def grad_stop_test(x, y):
    """ grad_stop_test """
    return grad_all(stop_test2)(x, y)


@ms_function
def grad_stop_test1(x, y):
    """ grad_stop_test1 """
    return grad_all(stop_test3)(x, y)


@ms_function
def grad_stop_test5(x, y):
    """ grad_stop_test5 """
    return grad_all(stop_test5)(x, y)


def test_stop():
    """ test_stop """
    print("test_stop:", grad_stop_test(1, 1))


def test_stop1():
    """ test_stop1 """
    print("test_stop1:", grad_stop_test1(2, 3))


def test_stop5():
    """ test_stop5 """
    print("test_stop5:", grad_stop_test5(2, 3))


class GradWrap(nn.Cell):
    """ GradWrap definition """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())

    @ms_function
    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)


@non_graph_engine
def test_softmaxloss_grad():
    """ test_softmaxloss_grad """

    class NetWithLossClass(nn.Cell):
        """ NetWithLossClass definition """

        def __init__(self, network):
            super(NetWithLossClass, self).__init__()
            self.loss = nn.SoftmaxCrossEntropyWithLogits()
            self.network = network

        @ms_function
        def construct(self, x, label):
            predict = self.network(x)
            return self.loss(predict, label)

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
            self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
            self.fc = P.MatMul()
            self.fc2 = nn.Dense(10, 10)
            self.biasAdd = P.BiasAdd()
            self.relu = nn.ReLU()
            self.cast = P.Cast()

        @ms_function
        def construct(self, x):
            x = self.fc(x, self.weight)
            x = self.cast(x, mstype.float32)
            x = self.relu(self.fc2(x))
            x = self.fc2(x)
            x = stop_gradient(x)
            x = self.biasAdd(x, self.bias)
            return x

    net = GradWrap(NetWithLossClass(Net()))
    predict = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    print("pynative run")
    out = net(predict, label)
    print("out:", out)


def test_stop_gradient_1():
    """ Gradients of a fully stopped output are zero for both inputs. """

    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            ret = x * y
            ret = stop_gradient(ret)
            return ret

    dx, dy = bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
                   Tensor(np.ones([2, 2]).astype(np.float32)), wrt=['inputs'])
    expect = np.zeros([2, 2])
    assert (dx.asnumpy() == expect).all()
    assert (dy.asnumpy() == expect).all()


def test_stop_gradient_2():
    """ Stopping one output of a sub-cell leaves the other branch differentiable. """

    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c = stop_gradient(c)
            ret1 = c + x + y
            ret2 = z + y + y
            return ret1, ret2

    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    expect = np.array([[3.0, 3.0], [3.0, 3.0]])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_3():
    """ stop_gradient applied to one tuple item stops only that item. """

    class TupleGetItem(nn.Cell):
        def __init__(self):
            super(TupleGetItem, self).__init__()

        @ms_function
        def construct(self, x1, x2, x3, x4, x5):
            z1 = x1 + x1
            z2 = x1 * x2
            t = (z1, z2, x3, x4, x5)
            z2 = t[1]
            z2 = stop_gradient(z2)
            return z1, z2, x3, x4, x5

    dx = bprop(TupleGetItem(),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    expect = np.array([2.0, 2.0])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_4():
    """ Gradient of an identity wrapped in stop_gradient is zero. """

    def stop_test(x):
        return stop_gradient(x)

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)


def test_stop_gradient_5():
    """ The stopped branch contributes nothing; the direct path keeps gradient 1. """

    def stop_test(x):
        y = x + x
        y = stop_gradient(y)
        ret = x + y
        return ret

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)


def test_stop_gradient_6():
    """ Stopping the only output zeroes the gradients of all inputs. """

    def stop_test(x, y):
        ret = x * y
        ret = stop_gradient(ret)
        return ret

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)


class PrimWithMultiOutputs(PrimitiveWithInfer):
    """ Identity-like primitive with two outputs and a pass-through bprop. """

    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type

    def get_bprop(self):
        def bprop(x, y, out, dout):
            return (dout[0], dout[1])

        return bprop


def test_stop_gradient_7():
    """ Stopping one output of a multi-output primitive keeps the other differentiable. """

    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_outputs = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = self.prim_with_multi_outputs(x1, x2)
            x1 = stop_gradient(x1)
            return x1, x2

    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.ones([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_8():
    """ Stopping the whole multi-output tuple zeroes the gradients of all inputs. """

    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_output = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
            return x1, x2

    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_9():
    """ A stopped alias and an unstopped alias of the same value can coexist. """

    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c1 = stop_gradient(c)
            c2 = c
            ret1 = c1 + x + y + c2
            ret2 = z + y + y
            return ret1, ret2

    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    expect = np.array([[5.0, 5.0], [5.0, 5.0]])
    assert (dx.asnumpy() == expect).all()


class PrimWithNoBprop(PrimitiveWithInfer):
    """ Identity-like primitive with two outputs and no bprop defined. """

    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type


def test_stop_gradient_10():
    """ Stopping every output of a primitive without bprop yields zero gradients. """

    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x = x * y
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            y = stop_gradient(y)
            return x, y

    dx = bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    expect_dx = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()


def test_stop_gradient_11():
    """ A primitive without bprop raises RuntimeError if a gradient path remains. """

    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            return x, y

    with pytest.raises(RuntimeError):
        bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
              Tensor(np.ones([2]).astype(np.float32)))


@security_off_wrap
def test_stop_print():
    """ Gradient computation still works when Print side-effect ops are present. """

    class StopPrint(nn.Cell):
        def __init__(self):
            super(StopPrint, self).__init__()
            self.printm = P.Print()

        def construct(self, x, y):
            self.printm("StopPrint", x)
            self.printm(y)
            return x, y

    grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
                          Tensor(np.ones([2]).astype(np.float32)))
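

# A minimal supplementary sketch of the behaviour exercised above, assuming
# only names already defined in this file (grad_all, stop_gradient, Tensor,
# ms): gradients flow through the un-stopped branch of a computation, while
# an otherwise identical branch wrapped in stop_gradient contributes nothing.
def test_stop_gradient_branch_sketch():
    """ Only the branch not wrapped by stop_gradient contributes gradients. """

    def func(x, y):
        frozen = stop_gradient(x * y)  # stopped branch: contributes no gradient
        live = x * y                   # live branch: d/dx = y, d/dy = x
        return frozen + live

    # Only the live branch contributes: d/dx = y = 3, d/dy = x = 2.
    assert grad_all(func)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)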