# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore as ms
import mindspore.ops.composite as C
from mindspore import context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from tests.security_utils import security_off_wrap

# Gradient operators shared by the tests below:
# - grad_all_list: gradients w.r.t. all inputs and w.r.t. the given parameter list.
# - grad_by_list: gradients w.r.t. the given parameter list only.
grad_all_list = C.GradOperation(get_all=True, get_by_list=True)
grad_by_list = C.GradOperation(get_by_list=True)

# Every test in this module compiles its Cells in graph mode.
context.set_context(mode=context.GRAPH_MODE)


def test_load_grad():
    """Take input and parameter gradients of a net that only reads (loads) parameter ``z``."""
    class LoadNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * y * self.z
            return x

    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    load_net = LoadNet()
    grad_net = grad_all_list(
        load_net, ParameterTuple(load_net.trainable_params()))
    print(grad_net(x, y))


def test_assign_only_grad():
    """Take gradients of a net that rebinds parameter ``z`` to an input but never reads it."""
    class AssignOnlyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            self.z = x
            x = x * y
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            # Differentiate w.r.t. the wrapped net's trainable parameters, the same
            # way every other GradNet wrapper in this file collects them.
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    assign_net = AssignOnlyNet()
    net = GradNet(assign_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))


def test_load_assign_grad():
    """Take gradients of a net that reads ``z``, writes it with ``P.Assign``, then reads it again."""
    class AssignNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
            self.assign = P.Assign()

        def construct(self, x, y):
            x = x * self.z
            self.assign(self.z, x)
            out = y * self.z
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    assign_net = AssignNet()
    net = GradNet(assign_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))


def test_insert_gradient_of():
    """Compile the gradient of a net that applies ``InsertGradientOf`` twice.

    The gradient hook (``save_gradient``) updates the non-trainable ``cov_step``
    parameter and passes the incoming gradient through unchanged.
    """
    class InsertGradientNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gather = P.GatherV2()
            self.damping = Tensor(np.array([0.03, 0.03], np.float32))
            self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
            self.freq = Tensor(278, ms.int32)
            self.getG = P.InsertGradientOf(self.save_gradient)

        def save_gradient(self, dout):
            # Hook run on the backward pass: bump cov_step, keep the gradient as-is.
            self.cov_step = self.cov_step + self.freq
            return dout

        def construct(self, x):
            # NOTE(review): the gather result is discarded; presumably kept so the
            # graph reads cov_step in the forward pass — confirm intent.
            self.gather(self.damping, self.cov_step, 0)
            out = P.ReLU()(x)
            out = self.getG(out)
            out = self.getG(out)
            return out

    net = InsertGradientNet()
    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32)
    grad_net = grad_all_list(net, ParameterTuple(net.trainable_params()))
    print(grad_net(Tensor(input_data)))


@security_off_wrap
def test_user_defined_bprop():
    """Take gradients through a Cell whose backward pass is a user-defined ``bprop``.

    The ``bprop`` prints ``out``, a recomputed ``x * y`` and ``dout``, then returns
    ``(y, x)`` as the input gradients.
    """
    class UserDefinedNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            out = x * y
            return out

        def bprop(self, x, y, out, dout):
            self.print(out)
            out = x * y
            self.print(out)
            self.print(dout)
            return y, x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    user_defined_net = UserDefinedNet()
    net = GradNet(user_defined_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))


# The user-defined bprop does not take the same number of parameters as the
# primal construct (y is missing), so taking the gradient must fail.
@security_off_wrap
def test_user_defined_bad_bprop():
    """A ``bprop`` whose signature does not match ``construct`` must raise TypeError."""
    class UserDefinedNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            out = x * y
            return out

        def bprop(self, x, out, dout):
            self.print(out)
            out = x
            self.print(out)
            self.print(dout)
            return x, x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    user_defined_net = UserDefinedNet()
    net = GradNet(user_defined_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    with pytest.raises(TypeError):
        net(x, y)


# Should compile successfully, and the Print should be present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_unused_var():
    """Run a net whose loop count comes from a helper that also prints its input.

    ``get_shape`` prints ``x`` and returns its second dimension; ``construct``
    adds ``y`` to ``x`` that many times.
    """
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            channels = self.get_shape(x)
            acc = x
            for _ in range(channels):
                acc = acc + y
            return acc

        def get_shape(self, x):
            # The Print side effect is what this test expects to survive compilation.
            self.print(x)
            _, c, _, _ = F.shape(x)
            return c

    def make_input():
        return Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)

    cell = UnusedVar()
    lhs, rhs = make_input(), make_input()
    print(cell(lhs, rhs))


# Should compile successfully, and the Print should be present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_hof_unused_var():
    """Same scenario as a direct helper call, but get_shape is invoked through a higher-order helper."""
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            # hof_get_shape simply calls the function it is handed with x.
            channels = self.hof_get_shape(self.get_shape, x)
            acc = x
            for _ in range(channels):
                acc = acc + y
            return acc

        def hof_get_shape(self, func, x):
            return func(x)

        def get_shape(self, x):
            self.print(x)
            _, c, _, _ = F.shape(x)
            return c

    def make_input():
        return Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)

    cell = UnusedVar()
    lhs, rhs = make_input(), make_input()
    print(cell(lhs, rhs))


# Should compile successfully, and the Print should be present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_partial_hof_unused_var():
    """Run a net whose loop count comes from a ``F.partial``-wrapped helper that prints its input."""
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            # hof_get_shape returns partial(get_shape, x); calling it yields the count.
            channels = self.hof_get_shape(x)()
            acc = x
            for _ in range(channels):
                acc = acc + y
            return acc

        def hof_get_shape(self, x):
            return F.partial(self.get_shape, x)

        def get_shape(self, x):
            self.print(x)
            _, c, _, _ = F.shape(x)
            return c

    def make_input():
        return Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)

    cell = UnusedVar()
    lhs, rhs = make_input(), make_input()
    print(cell(lhs, rhs))


# Should compile successfully without an endless loop.
def test_while_if():
    """Compile and run a ``while`` loop whose body contains an ``if``/``else`` on tensors."""
    class WhileIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.zero = Tensor(np.zeros([1]).astype(np.float32))
            self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32)))

        def construct(self, idx, end, x):
            total = self.zero
            while idx < end:
                if x < end:
                    total = total + self.param * 2
                else:
                    total = total + self.param
                idx = idx + 1
            return total

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(5), dtype=ms.int32)
    probe = Tensor(np.zeros([1]).astype(np.float32))
    net = WhileIfNet()
    net(start, stop, probe)


# Should compile successfully without a zeros_like_tensor argument mismatch, and the
# generated graph files should contain neither env_getitem nor env_setitem.
# The InsertGradientOf primitive sets the BackPropAutoMonad flag on the bprop_construct
# func_graph, so every graph it uses is checked for side effects. As a result,
# hyper_map_zeros_like gets a U (monad) parameter, while the call site zeros_like(fv)
# passes no U argument.
def test_grad_fv_and_insert_gradient_of():
    """Take parameter gradients of a net that assigns ``self.z`` and uses ``InsertGradientOf``."""
    class FvAndInsertGradientNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gather = P.GatherV2()
            self.damping = Tensor(np.array([0.03, 0.03], np.float32))
            self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
            self.freq = Tensor(278, ms.int32)
            self.getG = P.InsertGradientOf(self.save_gradient)

            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def save_gradient(self, grad_out):
            # Backward hook: bump cov_step and forward the gradient unchanged.
            self.cov_step = self.cov_step + self.freq
            return grad_out

        def construct(self, *inputs):
            # self.z acts as a free variable coming from construct_wrapper.
            (first,) = inputs
            self.z = first

            # InsertGradientOf part.
            self.gather(self.damping, self.cov_step, 0)
            hooked = self.getG(first)
            return hooked

    model = FvAndInsertGradientNet()
    sample = Tensor(np.array([1.0], np.float32))
    # Using grad_all_list instead would put env_setitem into the generated graph:
    # the gradient for the inputs is a constant zero, so it would depend on the
    # result of grad.
    grad_fn = grad_by_list(model, ParameterTuple(model.trainable_params()))
    print(grad_fn(sample))


# Should compile successfully because a CNode with the Partial primitive does not
# bind an additional U monad.
def test_partial_parameter():
    """Call ``all(axis=())`` on a boolean Parameter from inside ``construct``."""
    flag = Parameter(Tensor(np.array([True], np.bool_)), name='z')

    class PartialNet(nn.Cell):
        def __init__(self, input_z):
            super().__init__()
            self.input = input_z

        def construct(self):
            # Attribute access to `all` is converted to a Partial.
            result = self.input.all(axis=())
            return result

    cell = PartialNet(flag)
    print(cell())