# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_fn_bprop """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
from mindspore.common.api import jit
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import vjp
from mindspore.ops.function.grad.grad_func import custom_vjp

context.set_context(mode=context.GRAPH_MODE)

# Shared GradOperation that returns the gradients w.r.t. all inputs.
grad_all = C.GradOperation(get_all=True)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_mul_add():
    """
    Features: Custom function bprop
    Description: Get the custom vjp of mul_add function.
    Expectation: No exception.
    """

    @custom_vjp
    def fn(x, y):
        return 2 * x + y

    # Deliberately NOT the analytic gradient (dy would be `dout`): returning
    # 2 * y lets the asserts detect that the registered bprop was used.
    def bprop_fn(x, y, out, dout):
        return 2 * dout, 2 * y

    fn.defbwd(bprop_fn)

    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    v = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(fn, x, y)
    grads = grad_fn(v)
    assert grads[0] == Tensor(2, dtype=ms.int32)
    assert grads[1] == Tensor(4, dtype=ms.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_mul_add():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when mul_add function is inline with other function.
    Expectation: No exception.
    """

    @custom_vjp
    def mul_add(x, y):
        return 2 * x + y

    def bprop_mul_add(x, y, out, dout):
        return 2 * dout, 2 * y

    mul_add.defbwd(bprop_mul_add)

    @jit
    def inline_mul_add(x, y):
        param = 2
        return mul_add(x, y) + x + param * y

    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    v = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(inline_mul_add, x, y)
    grads = grad_fn(v)
    # dx = 2 * dout (custom) + dout (the `+ x` term) = 3;
    # dy = 2 * y (custom) + param * dout = 6.
    assert grads[0] == Tensor(3, dtype=ms.int32)
    assert grads[1] == Tensor(6, dtype=ms.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_with_no_bprop():
    """
    Features: Custom function bprop
    Description: Get the vjp with no bprop.
    Expectation: No exception.
    """

    def with_no_bprop(x, y):
        return 2 * x + y

    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    v = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(with_no_bprop, x, y)
    grads = grad_fn(v)
    # Without a custom bprop the analytic gradients apply: dx = 2, dy = 1.
    assert grads[0] == Tensor(2, dtype=ms.int32)
    assert grads[1] == Tensor(1, dtype=ms.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_bprop_in_fn_2():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when bprop in fn_2.
    Expectation: No exception.
    """

    def fn_1(x, y):
        relu = P.ReLU()
        return relu(x)

    # fn_2's forward output is a 3-tuple: (fn_1 output, d fn_1/dx, d fn_1/dy).
    @custom_vjp
    def fn_2(x, y):
        grads = grad_all(fn_1)(x, y)
        return fn_1(x, y), grads[0], grads[1]

    # Custom bprop reuses the forward's dx (out[1]) and recomputes dy.
    def bprop_fn_2(x, y, out, dout):
        grads = grad_all(fn_1)(x, y)
        return out[1], grads[1]

    fn_2.defbwd(bprop_fn_2)

    @jit
    def fn_3(x, y):
        return fn_2(x, y)

    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))

    _, grad_fn = vjp(fn_3, x, y)
    grads = grad_fn(v, v, v)
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_bprop_in_fn3():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when bprop in fn_3.
    Expectation: No exception.
    """

    def fn_1(x, y):
        relu = P.ReLU()
        return relu(x)

    @custom_vjp
    def fn_2(x, y):
        grads = grad_all(fn_1)(x, y)
        return fn_1(x, y), grads[0], grads[1]

    def bprop_fn_2(x, y, out, dout):
        grads = grad_all(fn_1)(x, y)
        return out[1], grads[1]

    fn_2.defbwd(bprop_fn_2)

    # fn_3 is itself a custom_vjp function wrapping fn_2; its bprop should
    # take precedence over (and hide) fn_2's bprop.
    @custom_vjp
    def fn_3(x, y):
        return fn_2(x, y)

    def bprop_fn_3(x, y, out, dout):
        return x + y + y + out[0], x + x + y + y + dout[0]

    fn_3.defbwd(bprop_fn_3)

    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))
    _, grad_fn = vjp(fn_3, x, y)
    grads = grad_fn(v, v, v)
    # With all-ones inputs: dx = x + 2y + out[0] = 1 + 2 + 1 = 4;
    # dy = 2x + 2y + dout[0] = 2 + 2 + 1 = 5.
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_one_input_bprop():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function has only one input.
    Expectation: No exception.
    """

    # Single-input bprop must still return a tuple, hence the trailing comma.
    def bprop_fn(x, out, dout):
        return (5 * x,)

    @custom_vjp
    def fn(x):
        op = P.ReLU()
        return op(x)

    fn.defbwd(bprop_fn)
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(fn, input1)
    grads = grad_fn(v)
    assert (grads[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_bprop_two_input():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function has two inputs.
    Expectation: No exception.
    """

    def fn_1(x, y):
        return x * y

    @custom_vjp
    def fn_2(x, y):
        grads = grad_all(fn_1)(x, y)
        return fn_1(x, y), grads[0], grads[1]

    def bprop_fn_2(x, y, out, dout):
        grads = grad_all(fn_1)(x, y)
        return grads[0] * 2, grads[1] * 2

    fn_2.defbwd(bprop_fn_2)

    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(fn_2, input1, input2)
    grads = grad_fn(v, v, v)
    assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert len(grads) == 2


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_bprop_multi_input():
    """
    Features: Custom function bprop
    Description: Get the custom vjp of hybrid bprop function.
    Expectation: No exception.
    """

    def tensor_mul(x, y):
        return x * y

    @custom_vjp
    def two_input(x, y):
        op = P.Mul()
        return op(x, y)

    def two_input_bprop(x, y, out, dout):
        return 5 * x, 8 * y

    two_input.defbwd(two_input_bprop)

    # Plain function without a custom bprop, mixed in with the custom ones.
    def two_input_1(x, y):
        op = P.Mul()
        x = 1 + x
        return op(x, y)

    @custom_vjp
    def two_input_2(x, y):
        op = P.Mul()
        return op(x, y)

    def two_input_2_bprop(x, y, out, dout):
        return 5 * x, 8 * y

    two_input_2.defbwd(two_input_2_bprop)

    def inline_mutil_two_input(x, y):
        output = (
            two_input(x, y) + tensor_mul(x, y) + two_input_1(x, y) + two_input_2(x, y)
        )
        return output

    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(inline_mutil_two_input, input1, input2)
    grads = grad_fn(v)
    # dx: 5x (custom) + y (mul) + y (two_input_1) + 5x (custom) = 12 on ones;
    # dy: 8y + x + (1 + x) + 8y = 19 on ones.
    assert (
        grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)
    ).all()
    assert (
        grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)
    ).all()
    assert len(grads) == 2


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_custom_vjp_fn_with_net():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function contains Cell.
    Expectation: No exception.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name="z")

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    # Custom bprop (dx = 2x, dy = 2y) replaces the MatMul gradient entirely.
    def fn_bprop(x, y, out, dout):
        dx = x + x
        dy = y + y
        return dx, dy

    @custom_vjp
    def fn(x, y):
        net = Net()
        return net(x, y)

    fn.defbwd(fn_bprop)

    def grad_net(x, y):
        grad_f = grad_all(fn)
        return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor(
        [[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32
    )
    out = grad_net(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8], [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2], [0.2, 0.4, 2.6], [4.2, 2.4, 6.6]]).astype(
        np.float32
    )
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_custom_vjp_forward_net_call_fn():
    """
    Feature: Custom function bprop
    Description: Get the custom vjp when the forward net call the function.
    Expectation: No exception.
    """

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name="z")

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    @custom_vjp
    def fn(x, y):
        net = Net1()
        return net(x, y)

    def fn_bprop(x, y, out, dout):
        dx = x + x
        dy = y + y
        return dx, dy

    fn.defbwd(fn_bprop)

    # Here the custom_vjp function is invoked from inside a Cell's construct,
    # rather than being differentiated directly.
    class Net(nn.Cell):
        def construct(self, x, y):
            return fn(x, y)

    def grad_net(x, y):
        grad_f = grad_all(Net())
        return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor(
        [[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32
    )
    out = grad_net(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8], [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2], [0.2, 0.4, 2.6], [4.2, 2.4, 6.6]]).astype(
        np.float32
    )
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)