# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple


class ClipByNorm(nn.Cell):
    """
    Clips tensor values to a maximum :math:`L_2`-norm.
    """

    def __init__(self):
        super(ClipByNorm, self).__init__()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select_ = P.Select()
        self.greater_ = P.Greater()
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.max_op = P.Maximum()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.fill = P.Fill()
        self.expand_dims = P.ExpandDims()
        self.dtype = P.DType()

    def construct(self, x, clip_norm):
        """Clip the input tensor so that its L2-norm does not exceed `clip_norm`."""
        mul_x = F.square(x)
        if mul_x.shape == (1,):
            l2sum = self.cast(mul_x, mstype.float32)
        else:
            l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
        cond = self.greater_(l2sum, 0)
        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
        # Guard against sqrt(0): use 1.0 wherever the squared sum is zero.
        l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
        l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

        intermediate = x * clip_norm

        # Dividing by max(l2norm, clip_norm) leaves x unchanged when its norm
        # is already within the bound.
        max_norm = self.max_op(l2norm, clip_norm)
        values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
        values_clip = self.reshape(values_clip, self.shape(x))
        values_clip = F.identity(values_clip)
        return values_clip
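
# A minimal usage sketch (illustrative only, not part of the original module;
# assumes numpy is available):
#
#     import numpy as np
#     clip = ClipByNorm()
#     x = Tensor(np.random.rand(2, 3).astype(np.float32))
#     out = clip(x, Tensor(1.0, mstype.float32))  # same shape as x, L2-norm <= 1.0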


clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): Gradient to clip.

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
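
# A minimal usage sketch (illustrative only): apply norm clipping (clip_type=1,
# clip_value=1.0) across a tuple of gradients with HyperMap, mirroring the
# train cells below.
#
#     hyper_map = C.HyperMap()
#     clipped = hyper_map(F.partial(clip_grad, 1, 1.0), grads)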


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Scale a gradient by the reciprocal of the loss-scale factor."""
    return grad * reciprocal(scale)


class ClipGradients(nn.Cell):
    """
    Clip gradients.

    Inputs:
        grads (tuple[Tensor]): Gradients to clip.
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.

    Returns:
        tuple[Tensor], the clipped gradients.
    """
    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self,
                  grads,
                  clip_type,
                  clip_value):
        """Clip each gradient according to `clip_type` and `clip_value`."""
        if clip_type not in (0, 1):
            return grads
        new_grads = ()
        for grad in grads:
            dt = self.dtype(grad)
            if clip_type == 0:
                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
                                    self.cast(F.tuple_to_array((clip_value,)), dt))
            else:
                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
            new_grads = new_grads + (t,)
        return new_grads
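
# A minimal usage sketch (illustrative only): clip every gradient in a tuple
# by value to the range [-1.0, 1.0].
#
#     clipper = ClipGradients()
#     clipped = clipper(grads, 0, 1.0)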


class CrossEntropy(nn.Cell):
    """
    Cross-entropy loss: -sum(one_hot(label) * logits) per example, averaged
    over the batch. `logits` are therefore expected to be log-probabilities
    (e.g. log-softmax outputs).
    """
    def __init__(self, num_labels):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.num_labels = num_labels

    def construct(self, logits, label_ids):
        label_ids = self.reshape(label_ids, self.last_idx)
        one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
        loss = self.reduce_mean(per_example_loss, self.last_idx)
        return_value = self.cast(loss, mstype.float32)
        return return_value
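
# A minimal usage sketch (illustrative only; `log_probs` stands in for a
# (batch, num_labels) tensor of log-softmax outputs):
#
#     loss_fn = CrossEntropy(num_labels=10)
#     loss = loss_fn(log_probs, label_ids)  # scalar float32 Tensor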


class NetworkWithCLSLoss(nn.Cell):
    """Wraps a classification backbone so the cell returns its training loss."""
    def __init__(self, network):
        super(NetworkWithCLSLoss, self).__init__(auto_prefix=False)
        self.cls_network = network
        self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        logits = self.cls_network(input_ids, input_mask, token_type_id)
        cls_loss = self.loss_fct(logits, label_ids)
        return cls_loss


class NetworkWithMLMLoss(nn.Cell):
    """Wraps a masked-language-model backbone so the cell returns its training loss."""
    def __init__(self, network, vocab_size=21128):
        super(NetworkWithMLMLoss, self).__init__(auto_prefix=False)
        self.mlm_network = network
        self.vocab_size = vocab_size
        self.reshape = P.Reshape()
        self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        prediction_scores = self.mlm_network(input_ids, input_mask, token_type_id)
        # Flatten (batch, seq_len, vocab_size) scores and (batch, seq_len)
        # labels so every token position is scored independently.
        prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
        label_ids = self.reshape(label_ids, (-1,))
        mlm_loss = self.loss_fct(prediction_scores, label_ids)
        return mlm_loss
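
# A minimal usage sketch (illustrative only; `backbone` stands in for any cell
# mapping (input_ids, input_mask, token_type_id) to logits):
#
#     loss_net = NetworkWithCLSLoss(backbone)
#     loss = loss_net(input_ids, input_mask, token_type_id, label_ids)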


class NetworkTrainCell(nn.Cell):
    """
    Federated-learning train cell: pull the shared weights from the server,
    run one forward/backward step with gradient clipping, apply the optimizer,
    then push the updated weights back to the server.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(NetworkTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.clip_type = 1
        self.clip_value = 1.0
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

        self.get_weights_by_key = P.PullWeight()
        self.over_weights_by_key = P.PushWeight()

        self.get_weights_by_key_input_1, \
        self.get_weights_by_key_input_2, \
        self.get_weights_by_key_input_3 = self._get_weights_by_key_inputs(self.network.parameters_and_names())

        self.over_weights_by_key_input_1, \
        self.over_weights_by_key_input_2, \
        self.over_weights_by_key_input_3 = self._over_weights_by_key_inputs(self.network.parameters_and_names())

    def _communication_with_server_1(self, weights):
        """Pull the server-held weights into the local parameters."""
        result = self.hyper_map(F.partial(self.get_weights_by_key), weights,
                                self.get_weights_by_key_input_2, self.get_weights_by_key_input_3)
        return result

    def _communication_with_server_2(self, weights):
        """Push the locally updated weights back to the server."""
        result = self.hyper_map(F.partial(self.over_weights_by_key), weights,
                                self.over_weights_by_key_input_2,
                                self.over_weights_by_key_input_3)
        return result

    def _get_weights_by_key_inputs(self, weights):
        """Collect the parameters, names and indices to pull from the server."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].pull_weight_from_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1
        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def _over_weights_by_key_inputs(self, weights):
        """Collect the parameters, names and indices to push to the server."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].push_weight_to_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1
        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        weights = self.weights
        # Pull the latest weights before the forward pass; F.depend enforces
        # the ordering in graph mode.
        res = self._communication_with_server_1(self.get_weights_by_key_input_1)
        input_ids = F.depend(input_ids, res)
        loss = self.network(input_ids, input_mask, token_type_id, label_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
        loss = F.depend(loss, self.optimizer(grads))
        # Push the updated weights only after the optimizer step has run.
        weights1 = F.depend(self.over_weights_by_key_input_1, loss)
        loss = F.depend(loss, self._communication_with_server_2(weights1))
        return loss
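
# A minimal usage sketch (illustrative only; assumes a federated-learning job
# in which the backbone's parameters have pull_weight_from_server /
# push_weight_to_server set):
#
#     optimizer = nn.AdamWeightDecay(loss_net.trainable_params(), learning_rate=1e-5)
#     train_net = NetworkTrainCell(loss_net, optimizer)
#     loss = train_net(input_ids, input_mask, token_type_id, label_ids)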


class NetworkNoClientTrainCell(nn.Cell):
    """
    Standalone (non-federated) train cell: one forward/backward step with
    norm-based gradient clipping and no server communication.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(NetworkNoClientTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.clip_type = 1
        self.clip_value = 1.0
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        weights = self.weights
        loss = self.network(input_ids, input_mask, token_type_id, label_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
        self.optimizer(grads)
        return loss
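
# A minimal end-to-end sketch (illustrative only; `backbone` and `dataset` are
# assumed to be defined elsewhere):
#
#     loss_net = NetworkWithMLMLoss(backbone)
#     optimizer = nn.AdamWeightDecay(loss_net.trainable_params(), learning_rate=1e-5)
#     train_net = NetworkNoClientTrainCell(loss_net, optimizer)
#     train_net.set_train()
#     for batch in dataset.create_tuple_iterator():
#         input_ids, input_mask, token_type_id, label_ids = batch
#         loss = train_net(input_ids, input_mask, token_type_id, label_ids)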