# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_stop_gradient """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import stop_gradient
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.bprop_util import bprop


grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)


def stop_func(x, y):
    """ stop_func """
    c = x * y
    c_s = x + y
    return c_s, c


def stop_test1(x, y):
    """ stop_test1 """
    c = x * y
    c_s = stop_gradient(c)
    return c_s


def stop_test2(x, y):
    """ stop_test2 """
    c = x * y
    c_s = stop_gradient(c)
    d = c_s + x * y
    return d * y


def stop_test3(x, y):
    """ stop_test3 """
    x = x * y
    z = stop_test1(x, y)
    k = z * y
    return k


def stop_test5(x, y):
    """ stop_test5 """
    x = x + y
    o1, o2 = stop_func(x, y)
    c = stop_gradient(o1)
    c = o2 + c
    return c


def stop_test4(x, y):
    """ stop_test4 """
    c = x + y
    c_s = stop_gradient(c)
    e = c + c_s
    return e


@ms_function
def grad_stop_test(x, y):
    """ grad_stop_test """
    return grad_all(stop_test2)(x, y)


@ms_function
def grad_stop_test1(x, y):
    """ grad_stop_test1 """
    return grad_all(stop_test3)(x, y)


@ms_function
def grad_stop_test5(x, y):
    """ grad_stop_test5 """
    return grad_all(stop_test5)(x, y)


def test_stop():
    """ test_stop """
    print("test_stop:", grad_stop_test(1, 1))


def test_stop1():
    """ test_stop1 """
    print("test_stop1:", grad_stop_test1(2, 3))


def test_stop5():
    """ test_stop5 """
    print("test_stop5:", grad_stop_test5(2, 3))
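

# Illustrative value check -- a hedged sketch added next to the originals, not
# part of the original suite. stop_test2 returns (c_s + x*y) * y, where
# c_s = stop_gradient(x*y) is treated as a constant in the backward pass, so
# d/dx = y*y and d/dy = c_s + 2*x*y; with x=2, y=3 that is (9, 18).
def test_stop2_values():
    assert grad_all(stop_test2)(Tensor(2, dtype=ms.int32),
                                Tensor(3, dtype=ms.int32)) == (9, 18)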


class GradWrap(nn.Cell):
    """ GradWrap definition """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())

    @ms_function
    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)


@non_graph_engine
def test_softmaxloss_grad():
    """ test_softmaxloss_grad """

    class NetWithLossClass(nn.Cell):
        """ NetWithLossClass definition """

        def __init__(self, network):
            super(NetWithLossClass, self).__init__()
            self.loss = nn.SoftmaxCrossEntropyWithLogits()
            self.network = network

        @ms_function
        def construct(self, x, label):
            predict = self.network(x)
            return self.loss(predict, label)

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
            self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
            self.fc = P.MatMul()
            self.fc2 = nn.Dense(10, 10)
            self.biasAdd = P.BiasAdd()
            self.relu = nn.ReLU()
            self.cast = P.Cast()

        @ms_function
        def construct(self, x):
            x = self.fc(x, self.weight)
            x = self.cast(x, mstype.float32)
            x = self.relu(self.fc2(x))
            x = self.fc2(x)
            x = stop_gradient(x)
            x = self.biasAdd(x, self.bias)
            return x

    net = GradWrap(NetWithLossClass(Net()))

    predict = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    print("pynative run")
    out = net(predict, label)
    print("out:", out)


def test_stop_gradient_1():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            ret = x * y
            ret = stop_gradient(ret)
            return ret

    dx, dy = bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
                   Tensor(np.ones([2, 2]).astype(np.float32)), wrt=['inputs'])
    expect = np.zeros([2, 2])
    assert (dx.asnumpy() == expect).all()
    assert (dy.asnumpy() == expect).all()


def test_stop_gradient_2():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c = stop_gradient(c)
            ret1 = c + x + y
            ret2 = z + y + y
            return ret1, ret2

    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    # ret1 contributes 1 (c is stopped) and ret2 contributes d(u*v)/dx = 2x = 2,
    # so dx is 3 everywhere.
    expect = np.array([[3.0, 3.0], [3.0, 3.0]])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_3():
    class TupleGetItem(nn.Cell):
        def __init__(self):
            super(TupleGetItem, self).__init__()

        @ms_function
        def construct(self, x1, x2, x3, x4, x5):
            z1 = x1 + x1
            z2 = x1 * x2
            t = (z1, z2, x3, x4, x5)
            z2 = t[1]
            z2 = stop_gradient(z2)
            return z1, z2, x3, x4, x5

    dx = bprop(TupleGetItem(),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    # Only z1 = x1 + x1 back-propagates to x1; the stopped z2 contributes
    # nothing, so dx has the same shape as x1, i.e. [2].
    expect = np.array([2.0, 2.0])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_4():
    def stop_test(x):
        return stop_gradient(x)

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)


def test_stop_gradient_5():
    def stop_test(x):
        y = x + x
        y = stop_gradient(y)
        ret = x + y
        return ret

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
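

# Contrast sketch (an added illustration, not one of the original cases): the
# same graph as test_stop_gradient_5 but without the stop_gradient call is
# fully differentiable, so the gradient is d(x + (x + x))/dx = 3 instead of 1.
def test_no_stop_gradient_reference():
    def no_stop_test(x):
        y = x + x
        ret = x + y
        return ret

    assert grad_all(no_stop_test)(Tensor(1, dtype=ms.int32)) == (3,)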


def test_stop_gradient_6():
    def stop_test(x, y):
        ret = x * y
        ret = stop_gradient(ret)
        return ret

    assert grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)


class PrimWithMultiOutputs(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type

    def get_bprop(self):
        def bprop(x, y, out, dout):
            return (dout[0], dout[1])

        return bprop


def test_stop_gradient_7():
    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_outputs = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = self.prim_with_multi_outputs(x1, x2)
            x1 = stop_gradient(x1)
            return x1, x2

    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.ones([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_8():
    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_output = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
            return x1, x2

    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_9():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c1 = stop_gradient(c)
            c2 = c
            ret1 = c1 + x + y + c2
            ret2 = z + y + y
            return ret1, ret2

    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    # c2 keeps the differentiable path through c, so ret1 contributes
    # 1 + d(u*v)/dx = 1 + 2x = 3 and ret2 contributes 2x = 2, giving 5.
    expect = np.array([[5.0, 5.0], [5.0, 5.0]])
    assert (dx.asnumpy() == expect).all()


class PrimWithNoBprop(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type


def test_stop_gradient_10():
    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x = x * y
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            y = stop_gradient(y)
            return x, y

    dx = bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    expect_dx = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()
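

# Mirror sketch (an added illustration, not one of the original cases): the
# complement of test_stop_gradient_7 -- stopping the second output instead of
# the first should zero dy while leaving dx intact, since the custom bprop of
# PrimWithMultiOutputs routes dout[0] back to x and dout[1] back to y.
def test_stop_gradient_second_output():
    class PrimWithMultiOutputsSecond(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputsSecond, self).__init__()
            self.prim_with_multi_outputs = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = self.prim_with_multi_outputs(x1, x2)
            x2 = stop_gradient(x2)
            return x1, x2

    dx, dy = bprop(PrimWithMultiOutputsSecond(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    assert (dx.asnumpy() == np.ones([2])).all()
    assert (dy.asnumpy() == np.zeros([2])).all()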


def test_stop_gradient_11():
    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            return x, y

    # y is still on a differentiable path, so the missing bprop of
    # PrimWithNoBprop is required and building the gradient fails.
    with pytest.raises(RuntimeError):
        bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
              Tensor(np.ones([2]).astype(np.float32)))


@security_off_wrap
def test_stop_print():
    class StopPrint(nn.Cell):
        def __init__(self):
            super(StopPrint, self).__init__()
            self.printm = P.Print()

        def construct(self, x, y):
            self.printm("StopPrint", x)
            self.printm(y)
            return x, y

    grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
                          Tensor(np.ones([2]).astype(np.float32)))
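

# Parameter-gradient sketch (an added illustration, not one of the original
# cases; ParamNet is a hypothetical helper): when a Cell's whole output is
# wrapped in stop_gradient, gradients taken by parameter list should come
# back all-zero as well.
def test_stop_gradient_param_grad():
    class ParamNet(nn.Cell):
        def __init__(self):
            super(ParamNet, self).__init__()
            self.weight = Parameter(Tensor(np.ones([2]).astype(np.float32)), name="weight")

        @ms_function
        def construct(self, x):
            out = x * self.weight
            return stop_gradient(out)

    net = ParamNet()
    grads = grad_by_list(net, ParameterTuple(net.trainable_params()))(
        Tensor(np.ones([2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.zeros([2])).all()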