# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import time
import stat
import os
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple


class _Grad(Cell):
    """
    Cell wrapper that applies a gradient operation to a network.

    Args:
        grad: gradient operation object; it must expose a ``sens_param``
            attribute and be callable as ``grad(network[, params])``.
        network: the network whose gradients are computed.
        wrt_params: when true, the network's trainable parameters are
            collected and passed to ``grad`` together with the network.
        real_inputs_count: number of leading positional inputs that belong to
            the network; everything after them is treated as sens values.
            ``None`` means the inputs are forwarded without being split.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def __call__(self, *inputs):
        """
        Run the cell, optionally dumping the elapsed wall time to a file.

        When the environment variable ``perf`` equals ``'1'``, the duration of
        the underlying ``Cell.__call__`` (in seconds, as text) is written to
        the file named by the environment variable ``PHASE``.

        Returns:
            Whatever ``Cell.__call__`` returns for ``inputs``.
        """
        if self.sens_param and self._dynamic_shape_inputs is not None:
            # Dynamic-shape sens is not supported: keep the dynamic inputs
            # registered for the real inputs, but replace the sens slot(s)
            # with the concrete values passed to this call.
            if self.real_inputs_count is None:
                dyn_inputs = self._dynamic_shape_inputs[:-1]
                real_sens = inputs[-1:]
            else:
                idx = self.real_inputs_count
                dyn_inputs = self._dynamic_shape_inputs[:idx]
                real_sens = inputs[idx:]
            static_sens = list(dyn_inputs) + list(real_sens)
            super().set_inputs(*static_sens)

        start = time.perf_counter()
        out = super().__call__(*inputs)
        end = time.perf_counter()
        if os.environ.get("perf") == '1':
            # NOTE: PHASE must be set when perf == '1'; os.open() raises
            # TypeError for a None path.
            phase = os.environ.get("PHASE")
            # O_TRUNC discards any previous content, so a shorter timing
            # string never leaves stale trailing digits from an earlier run.
            flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
            # Owner read/write only (0o600) for a newly created file.
            modes = stat.S_IWUSR | stat.S_IRUSR
            with os.fdopen(os.open(phase, flags, modes), 'w') as f:
                f.write(str(end - start))
        return out

    def construct(self, *inputs):
        """
        Apply the gradient operation to the network.

        If ``real_inputs_count`` is ``None`` or ``sens_param`` is ``False``,
        all inputs are forwarded unchanged. Otherwise the first
        ``real_inputs_count`` inputs are forwarded individually and the
        remaining ones are passed together as a single tuple argument.
        """
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sense_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sense_param_inputs)


class GradOfAllInputsAndParams(_Grad):
    """
    get grads of all inputs and params

    Builds a ``GradOperation(get_all=True, get_by_list=True)`` and enables
    ``wrt_params`` so the network's trainable parameters are included.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, get_by_list=True,
                                            sens_param=sens_param),
                         network=network, wrt_params=True, real_inputs_count=real_inputs_count)


class GradOfFirstInput(_Grad):
    """
    get grad of first input

    Builds a ``GradOperation`` with only ``sens_param`` set; the remaining
    options keep their defaults.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)