• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16from mindspore.nn import Cell
17from mindspore.ops.composite import GradOperation
18from mindspore.common import ParameterTuple
19import numpy as np
20
21
class _Grad(Cell):
    """
    base wrapper that applies a grad operation to a network

    Args:
        grad: grad operation object; it is called as ``grad(network)`` or
            ``grad(network, params)`` and must expose ``sens_param``.
        network: the cell to differentiate.
        wrt_params: when truthy, the network's trainable parameters are
            collected and handed to ``grad`` as its second argument.
        real_inputs_count: number of leading inputs that belong to the
            network; the remaining inputs are treated as sens values.
            ``None`` disables the split.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if wrt_params:
            self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        """
        run the grad function on ``inputs``

        When ``real_inputs_count`` is set and ``sens_param`` is not False,
        the trailing inputs are packed into one tuple and passed as the
        last positional argument; otherwise inputs are forwarded unchanged.
        """
        # No split requested (or no sens expected): forward inputs as they are.
        forward_as_is = self.real_inputs_count is None or self.sens_param is False
        if not self.wrt_params:
            if forward_as_is:
                return self.grad(self.network)(*inputs)
            net_inputs = inputs[:self.real_inputs_count]
            sens_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network)(*net_inputs, sens_inputs)
        if forward_as_is:
            return self.grad(self.network, self.params)(*inputs)
        net_inputs = inputs[:self.real_inputs_count]
        sens_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network, self.params)(*net_inputs, sens_inputs)
45
46
class GradOfFirstInput(_Grad):
    """
    get grad of first input

    Wraps ``network`` with a ``GradOperation`` that only receives
    ``sens_param``; parameters are not differentiated.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        first_input_grad = GradOperation(sens_param=sens_param)
        super().__init__(
            grad=first_input_grad,
            network=network,
            real_inputs_count=real_inputs_count,
        )
54
55
class GradOfAllInputs(_Grad):
    """
    get grads of all inputs

    Wraps ``network`` with ``GradOperation(get_all=True, ...)``;
    parameters are not differentiated.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        all_inputs_grad = GradOperation(get_all=True, sens_param=sens_param)
        super().__init__(
            grad=all_inputs_grad,
            network=network,
            real_inputs_count=real_inputs_count,
        )
63
64
class GradOfAllParams(_Grad):
    """
    get grads of all params

    Wraps ``network`` with ``GradOperation(get_by_list=True, ...)`` and
    asks the base class to collect the network's trainable parameters.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        all_params_grad = GradOperation(get_by_list=True, sens_param=sens_param)
        super().__init__(
            grad=all_params_grad,
            network=network,
            wrt_params=True,
            real_inputs_count=real_inputs_count,
        )
72
73
class GradOfAllInputsAndParams(_Grad):
    """
    get grads of all inputs and params

    Wraps ``network`` with ``GradOperation(get_all=True, get_by_list=True, ...)``
    and asks the base class to collect the network's trainable parameters.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        inputs_and_params_grad = GradOperation(
            get_all=True,
            get_by_list=True,
            sens_param=sens_param,
        )
        super().__init__(
            grad=inputs_and_params_grad,
            network=network,
            wrt_params=True,
            real_inputs_count=real_inputs_count,
        )
82
83
class GradOfDefault(_Grad):
    """
    get default grad

    Wraps ``network`` with ``GradOperation(get_all=False, get_by_list=False, ...)``.
    Unlike the sibling wrappers, ``sens_param`` defaults to False here.
    """

    def __init__(self, network, sens_param=False, real_inputs_count=None):
        default_grad = GradOperation(
            get_all=False,
            get_by_list=False,
            sens_param=sens_param,
        )
        super().__init__(
            grad=default_grad,
            network=network,
            wrt_params=False,
            real_inputs_count=real_inputs_count,
        )
93
94
class HighGrad(Cell):
    """
    get any order of grad

    Args:
        network: the cell to differentiate.
        grad_list: sequence of grad wrapper classes, applied in order; each
            one wraps the result of the previous one.
        sens_param: forwarded to the last wrapper only; every earlier
            wrapper is built with ``sens_param=False``.
        real_inputs_count: forwarded to the last wrapper only.
    """

    def __init__(self, network, grad_list, sens_param=False, real_inputs_count=None):
        super().__init__()
        self.grads = [network]
        # Every wrapper except the last one wraps the most recent entry.
        for grad_cls in grad_list[:-1]:
            self.grads.append(grad_cls(self.grads[-1], sens_param=False))
        outermost_cls = grad_list[-1]
        self.final_grad = outermost_cls(
            self.grads[-1],
            sens_param=sens_param,
            real_inputs_count=real_inputs_count,
        )

    def construct(self, *inputs):
        """forward all inputs to the outermost grad wrapper"""
        return self.final_grad(*inputs)
110
111
def _count_unequal_element(data_expected, data_me, rtol, atol):
    """
    assert that the share of out-of-tolerance elements stays below rtol

    An element is out of tolerance when
    ``|data_expected - data_me| > atol + |data_me| * rtol``.

    Args:
        data_expected: reference values (ndarray or array-like).
        data_me: values under test; must have the same shape as data_expected.
        rtol: relative tolerance; also used as the exclusive upper bound on
            the fraction of out-of-tolerance elements.
        atol: absolute tolerance.

    Raises:
        AssertionError: if the shapes differ, or if the fraction of
            out-of-tolerance elements is not below rtol.
    """
    # Coerce array-likes (lists, scalars) so a mismatch is reported as an
    # AssertionError rather than an AttributeError on `.shape`.
    data_expected = np.asarray(data_expected)
    data_me = np.asarray(data_me)
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # Nothing to compare; returning here also avoids a 0 / 0 division.
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
            format(data_expected[greater], data_me[greater], error[greater])
121
122
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """
    assert that data_me matches data_expected within tolerance

    Args:
        data_expected: reference values.
        data_me: values under test.
        rtol: relative tolerance passed to ``np.allclose``.
        atol: absolute tolerance passed to ``np.allclose``.
        equal_nan: passed to ``np.allclose``.

    Raises:
        AssertionError: if the comparison fails. When either input contains
            NaN the ``np.allclose`` result is asserted directly; otherwise a
            failed comparison is delegated to ``_count_unequal_element``.
    """
    contains_nan = np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me))
    is_close = np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    if contains_nan:
        assert is_close
        return
    if is_close:
        # np.allclose broadcasts its inputs, so compare the shapes explicitly.
        expected_shape = np.array(data_expected).shape
        actual_shape = np.array(data_me).shape
        assert expected_shape == actual_shape
        return
    _count_unequal_element(data_expected, data_me, rtol, atol)
130