# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph-mode tests that apply GradOperation to cells taking ``*args`` inputs.

None of the tests compares numerical results: each one builds a gradient
network, calls it once, and passes as long as that call does not raise.
"""
import numpy as np

import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


# Shared gradient operations; each name spells out the flags it was built with.
grad_by_list = C.GradOperation(get_by_list=True)
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True)
grad_all = C.GradOperation(get_all=True)
grad_with_sens = C.GradOperation(sens_param=True)


def test_net_vargs_expand():
    """Run grad_all_with_sens on a two-input add network with an extra sens input."""

    class AddNet(Cell):
        """Return ``x + y``; also owns a trainable parameter that construct never uses."""

        def __init__(self):
            super(AddNet, self).__init__()
            self.w = Parameter(
                Tensor(np.ones((3, 4, 5), np.float32)), "w2", requires_grad=True)

        def construct(self, x, y):
            return x + y

    x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    net = AddNet()
    # The trainable parameters are passed as the second argument even though
    # this grad op was built with get_all=True rather than get_by_list=True.
    _ = grad_all_with_sens(net, net.trainable_params())(x, y, sens)


class VarNet(Cell):
    """Wrap ``net`` and return ``net(*args) * w + b`` with trainable ``w`` and ``b``."""

    def __init__(self, net):
        super(VarNet, self).__init__()
        self.b = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
        self.w = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True)
        self.net = net

    def construct(self, *args):
        return self.net(*args) * self.w + self.b


class SecondNet(Cell):
    """Return ``args[0] + args[1] + b2`` with trainable ``b2``."""

    def __init__(self):
        super(SecondNet, self).__init__()
        self.b2 = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)

    def construct(self, *args):
        res = args[0] + args[1]
        return res + self.b2


class Bprop(Cell):
    """Apply a gradient operation to ``func``, optionally with weights and a fixed sens.

    Args:
        func: Callable handed to ``grad_op`` as the function to differentiate.
        wrt_params: When truthy, ``construct`` passes ``self.params`` to
            ``grad_op`` as its second argument.
        params: Parameters wrapped in a ``ParameterTuple`` when both
            ``wrt_params`` and ``params`` are truthy; otherwise ``self.params``
            stays ``None``.
        grad_op: The gradient operation to call.
        sens: Optional value appended as the last positional input of every
            call. A value that is not already a ``Tensor`` is converted to a
            float32 ``Tensor``.
    """

    def __init__(self, func, wrt_params, params, grad_op, sens=None):
        super(Bprop, self).__init__(auto_prefix=False)
        self.func = func
        self.wrt_params = wrt_params
        self.params = None
        if self.wrt_params and params:
            self.params = ParameterTuple(params)
        self.grad = grad_op
        self.with_sens = False
        self.sens = sens
        if sens is not None:
            self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
            self.with_sens = True

    def construct(self, *inputs):
        # pylint: disable=no-else-return
        if self.wrt_params:
            if self.with_sens:
                return self.grad(self.func, self.params)(*inputs, self.sens)
            else:
                return self.grad(self.func, self.params)(*inputs)
        elif self.with_sens:
            return self.grad(self.func)(*inputs, self.sens)
        else:
            return self.grad(self.func)(*inputs)


def test_all_var_args_grad_with_sens():
    """Test grad_by_list_with_sens with all var args input."""

    class GradNet(Cell):
        """Forward ``*inputs`` (which include sens) to grad_by_list_with_sens."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_by_list_with_sens(self.net, self.weights)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y, sens)


def test_grad_list_var_args():
    """Run grad_by_list on VarNet(SecondNet()) through a ``*inputs`` construct."""

    class GradNet(Cell):
        """Forward ``*inputs`` to grad_by_list with the net's trainable parameters."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y)


def test_grad_all_var_args():
    """Run grad_all on VarNet(SecondNet()) through a ``*inputs`` construct."""

    class GradNet(Cell):
        """Forward ``*inputs`` to grad_all; ``weights`` is stored but not passed on."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y)


def test_grad_all_var_args_with_sens():
    """Run grad_all_with_sens on VarNet(SecondNet()) with sens as the last input."""

    class GradNet(Cell):
        """Forward ``*inputs`` to grad_all_with_sens; ``weights`` is stored but not passed on."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_all_with_sens(self.net)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y, sens)


def test_grad_var_args_with_sens():
    """Run grad_with_sens on VarNet(SecondNet()) with sens as the last input."""

    class GradNet(Cell):
        """Forward ``*inputs`` to grad_with_sens; ``weights`` is stored but not passed on."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_with_sens(self.net)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y, sens)


def test_grad_with_param_sens():
    """Test grad_with_sens parameter."""

    class GradNet(Cell):
        """Supply sens from a non-trainable Parameter held by the cell itself."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False)
            self.grad = C.GradOperation(get_by_list=True, sens_param=True)

        def construct(self, x, y):
            return self.grad(self.net, self.weights)(x, y, self.sens)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = SecondNet()
    grad_net = GradNet(net)
    _ = grad_net(x, y)


def test_var_args_grad():
    """Run grad_by_list_with_sens from a positional construct over ``*args`` nets."""

    # These local classes intentionally shadow the module-level VarNet and
    # SecondNet; this VarNet has no ``w`` parameter and only adds ``b``.
    class VarNet(Cell):
        """Wrap ``net`` and return ``net(*args) + b`` with trainable ``b``."""

        def __init__(self, net):
            super(VarNet, self).__init__()
            self.b = Parameter(
                Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
            self.net = net

        def construct(self, *args):
            return self.net(*args) + self.b

    class SecondNet(Cell):
        """Return ``args[0] + args[1] + b2`` with trainable ``b2``."""

        def __init__(self):
            super(SecondNet, self).__init__()
            self.b2 = Parameter(
                Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)

        def construct(self, *args):
            res = args[0] + args[1]
            return res + self.b2

    class GradNet(Cell):
        """Call grad_by_list_with_sens with explicitly named ``x``, ``y`` and ``sens``."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y, sens):
            return grad_by_list_with_sens(self.net, self.weights)(x, y, sens)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y, sens)


def test_var_args_positional():
    """Test grad_all with var args in inner graph."""

    # These local classes intentionally shadow the module-level VarNet and
    # SecondNet; neither of them owns any parameter.
    class VarNet(Cell):
        """Take positional ``x`` and ``y`` and return ``net(x, y) * x``."""

        def __init__(self, net):
            super(VarNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            return self.net(x, y) * x

    class SecondNet(Cell):
        """Return ``args[0] + args[1]`` from a ``*args`` construct."""

        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, *args):
            return args[0] + args[1]

    class GradNet(Cell):
        """Call grad_all on the wrapped net with positional ``x`` and ``y``."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y)


def test_grad_within_if_else():
    """Run a by-list gradient with a fixed sens through Bprop's if/else construct."""

    class GradNet(Cell):
        """Delegate to a Bprop cell configured with weights and a sens tensor."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            grad_op = C.GradOperation(get_all=False, get_by_list=True, sens_param=True)
            sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
            self.grad = Bprop(self.net, True, self.weights, grad_op, sens)

        def construct(self, *inputs):
            return self.grad(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    out = grad_net(x, y)
    # Label corrected: it previously named test_grad_var_args_with_sens.
    print("test_grad_within_if_else out=", out)


def test_grad_for_concat():
    """Run a get_all gradient with sens over a Concat cell fed through ``*inputs``."""

    class GradNet(Cell):
        """Delegate to a Bprop cell built without weights and without a stored sens."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            grad_op = C.GradOperation(get_all=True, get_by_list=False, sens_param=True)
            # No sens is given to Bprop here; the caller appends it to the inputs.
            self.grad = Bprop(self.net, False, self.weights, grad_op)

        def construct(self, *inputs):
            return self.grad(*inputs)

    class Concat(Cell):
        """Concatenate all positional inputs along ``axis`` with P.Concat."""

        def __init__(self, axis):
            super().__init__()
            self.concat = P.Concat(axis=axis)

        def construct(self, *input1):
            return self.concat(input1)

    class ConcatFactory:
        """Hold random numpy inputs for Concat and drive the gradient network.

        Args:
            input_shape: Iterable of shapes, one random array is drawn per shape.
            axis: Axis used for both np.concatenate and the Concat cell.
            dtype: Numpy dtype the random arrays are cast to.
        """

        def __init__(self, input_shape, axis, dtype=np.float32):
            super(ConcatFactory, self).__init__()
            self.inputs_np = [np.random.randn(*s).astype(dtype) for s in input_shape]
            self.axis = axis
            self.out_numpy = np.concatenate(self.inputs_np, axis=self.axis)
            # The forward result doubles as the sens value passed to the grad net.
            self.out_grad_np = self.out_numpy

        def grad_mindspore_impl(self):
            """Build the gradient network and call it on the inputs plus sens."""
            inputs = [Tensor(i) for i in self.inputs_np]
            net = Concat(axis=self.axis)
            grad_net = GradNet(net)
            grad_net.set_train()
            _ = grad_net(*inputs, Tensor(self.out_grad_np))

        def grad_cmp(self):
            """Run the MindSpore gradient; no comparison is performed."""
            self.grad_mindspore_impl()

    fact = ConcatFactory(input_shape=(
        (2, 184320, 1), (2, 46080, 1), (2, 11520, 1), (2, 2880, 1), (2, 720, 1)), axis=1)
    fact.grad_cmp()