# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gradient-wrapper Cells and numpy comparison helpers for MindSpore tests."""

import numpy as np

from mindspore.common import ParameterTuple
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation


class _Grad(Cell):
    """Base cell that applies a ``GradOperation`` to ``network``.

    When ``real_inputs_count`` is given and the grad op takes a sens
    parameter, the trailing ``inputs`` beyond that count are bundled into a
    tuple and passed as the single sensitivity argument.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        # Mirror the grad op's sens flag so construct() can branch on it.
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            # Differentiating w.r.t. parameters needs the trainable params.
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        # NOTE: the branches are kept flat and explicit — graph-mode
        # compilation of construct() is sensitive to control-flow shape.
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_part = inputs[:self.real_inputs_count]
            sens_part = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_part, sens_part)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_part = inputs[:self.real_inputs_count]
        sens_part = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_part, sens_part)


class GradOfFirstInput(_Grad):
    """Gradient of the network w.r.t. its first input only."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(sens_param=sens_param),
                         network=network,
                         real_inputs_count=real_inputs_count)


class GradOfAllInputs(_Grad):
    """Gradients of the network w.r.t. every input."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
                         network=network,
                         real_inputs_count=real_inputs_count)


class GradOfAllParams(_Grad):
    """Gradients of the network w.r.t. all trainable parameters."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_by_list=True, sens_param=sens_param),
                         network=network,
                         wrt_params=True,
                         real_inputs_count=real_inputs_count)


class GradOfAllInputsAndParams(_Grad):
    """Gradients w.r.t. all inputs and all trainable parameters."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, get_by_list=True,
                                            sens_param=sens_param),
                         network=network,
                         wrt_params=True,
                         real_inputs_count=real_inputs_count)


class GradOfDefault(_Grad):
    """Default gradient: first input only, no sens parameter by default."""

    def __init__(self, network, sens_param=False, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=False, get_by_list=False,
                                            sens_param=sens_param),
                         network=network,
                         wrt_params=False,
                         real_inputs_count=real_inputs_count)


class HighGrad(Cell):
    """Stack grad wrappers from ``grad_list`` to build a higher-order grad.

    Every wrapper except the last is applied without a sens parameter; the
    outermost wrapper receives ``sens_param`` and ``real_inputs_count``.
    """

    def __init__(self, network, grad_list, sens_param=False, real_inputs_count=None):
        super().__init__()
        self.grads = [network,]
        for idx in range(len(grad_list) - 1):
            self.grads.append(grad_list[idx](self.grads[idx], sens_param=False))
        self.final_grad = grad_list[-1](self.grads[-1],
                                        sens_param=sens_param,
                                        real_inputs_count=real_inputs_count)

    def construct(self, *inputs):
        return self.final_grad(*inputs)


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements is below rtol.

    Element-wise tolerance follows the ``allclose`` rule:
    ``|expected - actual| <= atol + rtol * |actual|``.
    """
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    abs_error = np.abs(data_expected - data_me)
    out_of_tol = np.greater(abs_error, atol + np.abs(data_me) * rtol)
    bad_count = np.count_nonzero(out_of_tol)
    # NOTE(review): rtol doubles as the allowed mismatch ratio here —
    # intentional in this test utility.
    assert (bad_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[out_of_tol], data_me[out_of_tol], abs_error[out_of_tol])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Compare two arrays, tolerating a small fraction of mismatches.

    If either side contains NaN, fall back to a strict ``np.allclose``;
    otherwise a near-miss is re-checked element-wise so that only an
    excessive mismatch ratio fails the assertion.
    """
    has_nan = np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me))
    if has_nan:
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert np.array(data_expected).shape == np.array(data_me).shape