# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_bprop """
import numpy as np
import pytest
import mindspore as ms
from mindspore import grad
import mindspore.nn as nn
from mindspore import context
from mindspore.common import Tensor
from mindspore.common.api import jit
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import GradOperation
from tests.mindspore_test_framework.utils.bprop_util import bprop
from tests.st.pynative.utils import GradOfFirstInput, GradOfAllInputs, GradOfAllInputsAndParams


def setup_module():
    context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    """ Net definition """

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    @jit
    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return x, out


def test_bprop_no_sens():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)),
                  Tensor(np.ones([3, 2]).astype(np.float32)), wrt=['inputs'])
    print(grads)


def test_bprop_sens():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))), wrt=['inputs'])
    print(grads)


def test_bprop_first_only():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))))
    print(grads)


def test_bprop_wrt_params():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))),
                  wrt=['params'],
                  params=net.trainable_params())
    print(grads)


def test_bprop_wrt_params_no_sens():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  wrt=['params'],
                  params=net.trainable_params())
    print(grads)


def test_bprop_wrt_inputs_and_params():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))),
                  wrt=['inputs', 'params'],
                  params=net.trainable_params())
    print(grads)
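

# `bprop` above comes from the test framework utilities. For reference, a
# gradient w.r.t. all inputs can also be taken directly with GradOperation;
# this sketch is illustrative only and assumes nothing about the helper's
# internals beyond the values it differentiates (not collected by pytest):
def _grad_wrt_inputs_sketch():
    net = Net()
    grad_all = GradOperation(get_all=True)
    return grad_all(net)(Tensor(np.ones([2, 3]).astype(np.float32)),
                         Tensor(np.ones([3, 2]).astype(np.float32)))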


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_network_with_dict_output():
    """
    Feature: Test sens dict
    Description: Net output is a dict
    Expectation: Success
    """

    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        def construct(self, x):
            y = self.relu(x)
            out = {Tensor(True): y}
            return out

    x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    ms_net = DicNet()
    # No sens
    ms_grad = GradOfFirstInput(ms_net, False)
    grad_out = ms_grad(Tensor(x))
    assert np.allclose(np.ones_like(x), grad_out.asnumpy())

    # With sens
    out = ms_net(Tensor(x))
    ms_grad = GradOfFirstInput(ms_net, True)
    grad_out = ms_grad(Tensor(x), out)
    assert np.allclose(x, grad_out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_jit_network_with_dict_output():
    """
    Feature: Test sens dict in jit
    Description: Net output is a dict under jit
    Expectation: Success
    """

    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        @jit
        def construct(self, x):
            y = self.relu(x)
            out = {'a': y}
            return out

    x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    ms_net = DicNet()
    # No sens
    ms_grad = GradOfFirstInput(ms_net, False)
    grad_out = ms_grad(Tensor(x))
    assert np.allclose(np.ones_like(x), grad_out.asnumpy())

    # With sens
    ms_net = DicNet()
    out = ms_net(Tensor(x))
    ms_grad = GradOfFirstInput(ms_net, True)
    grad_out = ms_grad(Tensor(x), out)
    assert np.allclose(x, grad_out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_synchronize():
    """
    Feature: Test pynative synchronize
    Description: Test the code for the synchronous branch.
    Expectation: Success
    """
    try:
        context.set_context(pynative_synchronize=True)

        # Cell object to be differentiated
        class MulNet(nn.Cell):
            def construct(self, x, y, z):
                return x * y * z

        x = Tensor([1, 2], ms.float32)
        y = Tensor([-2, 3], ms.float32)
        z = Tensor([0, 3], ms.float32)
        net = MulNet()
        net.set_inputs(Tensor(shape=[None], dtype=ms.float32), y, z)
        output = grad(net, grad_position=(1, 2))(x, y, z)
        assert (output[0].asnumpy() == np.array([0, 6], dtype=np.float32)).all()
        assert (output[1].asnumpy() == np.array([-2, 6], dtype=np.float32)).all()
    finally:
        context.set_context(pynative_synchronize=False)
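

# A minimal standalone sketch of the functional `grad` API exercised above,
# using a plain Python function and default arguments (names here are
# illustrative, not collected by pytest):
def _functional_grad_sketch():
    def mul3(x, y, z):
        return x * y * z

    x = Tensor([1, 2], ms.float32)
    y = Tensor([-2, 3], ms.float32)
    z = Tensor([0, 3], ms.float32)
    # With the default grad_position=0, this returns d(x*y*z)/dx = y * z.
    return grad(mul3)(x, y, z)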


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_multi_grad():
    """
    Feature: Test pynative multi grad
    Description: Test the code for PyNative multi grad.
    Expectation: Success
    """

    class ForwardNetMul(nn.Cell):
        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def construct(self, x, y):
            a = x + x + x
            b = y + y
            return a * b

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=ms.float32)
    y = Tensor(np.ones([32]) * 2, dtype=ms.float32)
    sens = Tensor(np.ones([32]), dtype=ms.float32)
    mulnet.set_grad()
    addnet.set_grad()
    mulnet(x, y)
    addnet(x, y)
    grad_mul = GradOfAllInputs(mulnet)
    grad_add = GradOfAllInputs(addnet)
    grad_mul(x, y, sens)
    grad_add(x, y, sens)


class GradFactory:
    """Builds nn.ForwardValueAndGrad wrappers around a network and replays
    them on two sets of inputs."""

    def __init__(self, net_me, get_all, get_by_list, sens_param, net_params=None,
                 default_para=False):
        self.net_me = net_me
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.net_params = net_params
        self.default_para = default_para

    def get_grad(self, ms_input):
        """Build sens values matching the network output: random values when
        sens_param is set, all-ones otherwise."""
        output_grad_me = []
        out = self.net_me(*ms_input)
        if isinstance(out, tuple):
            for it in out:
                if self.sens_param:
                    grad_np = np.random.randn(*it.shape).astype(np.float32)
                else:
                    grad_np = np.ones(it.shape).astype(np.float32)
                output_grad_me.append(Tensor(grad_np))
            output_grad_me = tuple(output_grad_me)
        else:
            if self.sens_param:
                grad_np = np.random.randn(*out.shape).astype(np.float32)
            else:
                grad_np = np.ones(out.shape).astype(np.float32)
            output_grad_me = Tensor(grad_np)
        return output_grad_me

    def one_backnet_call_twice(self, first_ms_input, second_ms_input):
        """Build one grad net and call it on both sets of inputs."""
        grad_input = self.get_grad(first_ms_input)
        if self.default_para:
            back_net = nn.ForwardValueAndGrad(self.net_me)
            back_net(*first_ms_input)
        else:
            weight = self.net_params if self.get_by_list else None
            back_net = nn.ForwardValueAndGrad(self.net_me,
                                              weights=weight, get_all=self.get_all,
                                              get_by_list=self.get_by_list,
                                              sens_param=self.sens_param)
            if self.sens_param:
                back_net(*first_ms_input, grad_input[0])
            else:
                back_net(*first_ms_input)

        # second call on the same grad net
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net(*second_ms_input)
        else:
            if self.sens_param:
                back_net(*second_ms_input, grad_input[0])
            else:
                back_net(*second_ms_input)

    def two_backnet_call_twice(self, first_ms_input, second_ms_input):
        """Build a fresh grad net for each set of inputs."""
        weight = self.net_params if self.get_by_list else None
        grad_input = self.get_grad(first_ms_input)
        if self.default_para:
            back_net = nn.ForwardValueAndGrad(self.net_me)
            back_net(*first_ms_input)
        else:
            back_net = nn.ForwardValueAndGrad(self.net_me,
                                              weights=weight, get_all=self.get_all,
                                              get_by_list=self.get_by_list,
                                              sens_param=self.sens_param)
            if self.sens_param:
                back_net(*first_ms_input, grad_input[0])
            else:
                back_net(*first_ms_input)

        # second call on a second grad net
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net2 = nn.ForwardValueAndGrad(self.net_me)
            back_net2(*second_ms_input)
        else:
            back_net2 = nn.ForwardValueAndGrad(self.net_me,
                                               weights=weight, get_all=self.get_all,
                                               get_by_list=self.get_by_list,
                                               sens_param=self.sens_param)
            if self.sens_param:
                back_net2(*second_ms_input, grad_input[0])
            else:
                back_net2(*second_ms_input)

    def first_forward_second_backnet(self, first_ms_input, second_ms_input):
        """Only build a grad net for the second set of inputs."""
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net2 = nn.ForwardValueAndGrad(self.net_me)
            back_net2(*second_ms_input)
        else:
            weight = self.net_params if self.get_by_list else None
            back_net2 = nn.ForwardValueAndGrad(self.net_me,
                                               weights=weight, get_all=self.get_all,
                                               get_by_list=self.get_by_list,
                                               sens_param=self.sens_param)
            if self.sens_param:
                back_net2(*second_ms_input, grad_input[0])
            else:
                back_net2(*second_ms_input)
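

# nn.ForwardValueAndGrad, driven by GradFactory above, returns the forward
# value together with the gradients in a single call. A minimal sketch under
# the default settings plus get_all=True (the network name is illustrative,
# not collected by pytest):
def _forward_value_and_grad_sketch():
    class SquareNet(nn.Cell):
        def construct(self, x):
            return x * x

    fv_and_grad = nn.ForwardValueAndGrad(SquareNet(), get_all=True)
    value, grads = fv_and_grad(Tensor([3.0], ms.float32))
    # value == [9.0]; grads[0] == d(x*x)/dx == [6.0]
    return value, grads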


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_value_and_grad_0():
    """
    Feature: Test pynative value and grad
    Description: Test the code for pynative value and grad.
    Expectation: Success
    """

    class Net0(nn.Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor([2, 3, 4], ms.float32), name="para")

        def construct(self):
            x = self.para * self.para
            return x

    net_me = Net0()
    fact = GradFactory(net_me=net_me,
                       get_all=True,
                       get_by_list=True,
                       sens_param=False,
                       net_params=ParameterTuple(net_me.trainable_params()))

    first_input = ()
    second_input = ()
    fact.one_backnet_call_twice(first_input, second_input)
    fact.two_backnet_call_twice(first_input, second_input)
    fact.first_forward_second_backnet(first_input, second_input)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_value_and_grad_1():
    """
    Feature: Test pynative value and grad
    Description: Test the code for pynative value and grad.
    Expectation: Success
    """

    class Net1(nn.Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor([1], ms.float32), name="para")

        def construct(self, x):
            y = x + self.para
            return y

    net_me = Net1()
    fact = GradFactory(net_me=net_me,
                       get_all=False,
                       get_by_list=False,
                       sens_param=False,
                       default_para=True)

    input_1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    first_input = (input_1,)

    input_1 = Tensor(np.random.randn(1, 2, 3, 4).astype(np.float32))
    second_input = (input_1,)
    fact.one_backnet_call_twice(first_input, second_input)
    fact.two_backnet_call_twice(first_input, second_input)
    fact.first_forward_second_backnet(first_input, second_input)


class CustomNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.p1 = Parameter(Tensor(np.array([1.0], np.float32)), name='p1')
        self.p2 = Parameter(Tensor(np.array([1.0], np.float32)), name='p2')
        self.p3 = Parameter(Tensor(np.array([1.0], np.float32)), name='p3')
        self.p1.requires_grad = False
        self.p2.requires_grad = False
        self.p3.requires_grad = True

    def construct(self, x):
        out = self.p1 * x
        out = out * self.p2
        out = out + self.p3
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad():
    """
    Feature: Test pynative requires grad
    Description: Test the code for requires grad
    Expectation: Success
    """
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOfAllInputsAndParams(net, sens_param=False)(x)
    assert (output[1][0].asnumpy() == np.array([1.0], dtype=np.float32)).all()
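

# The remaining cases take gradients with GradOperation directly. The
# convention they check: parameters with requires_grad=False still show up in
# the returned weight gradients when passed explicitly, but their gradient is
# zero.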


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_use_grad_operation():
    """
    Feature: Test pynative requires grad use grad operation
    Description: Test the code for requires grad
    Expectation: Success
    """
    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net, [net.p1, net.p2, net.p3])(x)
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][1].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][2].asnumpy() == np.array([1.0], dtype=np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_without_params():
    """
    Feature: Test pynative requires grad without params
    Description: Test the code for requires grad
    Expectation: Success
    """
    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net)(x)
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][1].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][2].asnumpy() == np.array([1.0], dtype=np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_case2():
    """
    Feature: Test pynative requires grad case2
    Description: Test the code for requires grad
    Expectation: Success
    """
    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net, [net.p1])(x)
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert len(output[1]) == 1