1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16from mindspore import nn 17from mindspore.ops import operations as P 18from mindspore.ops import functional as F 19from mindspore.ops import composite as C 20from mindspore.common.tensor import Tensor 21from mindspore.common import dtype as mstype 22from mindspore.common.parameter import ParameterTuple 23 24 25class ClipByNorm(nn.Cell): 26 """ 27 Clips tensor values to a maximum :math:`L_2`-norm. 28 """ 29 30 def __init__(self): 31 super(ClipByNorm, self).__init__() 32 self.reduce_sum = P.ReduceSum(keep_dims=True) 33 self.select_ = P.Select() 34 self.greater_ = P.Greater() 35 self.cast = P.Cast() 36 self.sqrt = P.Sqrt() 37 self.max_op = P.Maximum() 38 self.shape = P.Shape() 39 self.reshape = P.Reshape() 40 self.fill = P.Fill() 41 self.expand_dims = P.ExpandDims() 42 self.dtype = P.DType() 43 44 def construct(self, x, clip_norm): 45 """add ms_function decorator for pynative mode""" 46 mul_x = F.square(x) 47 if mul_x.shape == (1,): 48 l2sum = self.cast(mul_x, mstype.float32) 49 else: 50 l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) 51 cond = self.greater_(l2sum, 0) 52 ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) 53 l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) 54 l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) 55 56 intermediate = x * clip_norm 57 58 max_norm = self.max_op(l2norm, clip_norm) 59 values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1) 60 values_clip = self.reshape(values_clip, self.shape(x)) 61 values_clip = F.identity(values_clip) 62 return values_clip 63 64 65clip_grad = C.MultitypeFuncGraph("clip_grad") 66 67 68@clip_grad.register("Number", "Number", "Tensor") 69def _clip_grad(clip_type, clip_value, grad): 70 """ 71 Clip gradients. 72 73 Inputs: 74 clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. 75 clip_value (float): Specifies how much to clip. 76 grad (tuple[Tensor]): Gradients. 77 78 Outputs: 79 tuple[Tensor], clipped gradients. 80 """ 81 if clip_type not in (0, 1): 82 return grad 83 dt = F.dtype(grad) 84 if clip_type == 0: 85 new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), 86 F.cast(F.tuple_to_array((clip_value,)), dt)) 87 else: 88 new_grad = ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) 89 return new_grad 90 91 92grad_scale = C.MultitypeFuncGraph("grad_scale") 93reciprocal = P.Reciprocal() 94 95 96@grad_scale.register("Tensor", "Tensor") 97def tensor_grad_scale(scale, grad): 98 return grad * reciprocal(scale) 99 100 101class ClipGradients(nn.Cell): 102 """ 103 Clip gradients. 104 105 Inputs: 106 grads (list): List of gradient tuples. 107 clip_type (Tensor): The way to clip, 'value' or 'norm'. 108 clip_value (Tensor): Specifies how much to clip. 109 110 Returns: 111 List, a list of clipped_grad tuples. 112 """ 113 def __init__(self): 114 super(ClipGradients, self).__init__() 115 self.clip_by_norm = nn.ClipByNorm() 116 self.cast = P.Cast() 117 self.dtype = P.DType() 118 119 def construct(self, 120 grads, 121 clip_type, 122 clip_value): 123 """clip gradients""" 124 if clip_type not in (0, 1): 125 return grads 126 new_grads = () 127 for grad in grads: 128 dt = self.dtype(grad) 129 if clip_type == 0: 130 t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt), 131 self.cast(F.tuple_to_array((clip_value,)), dt)) 132 else: 133 t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt)) 134 new_grads = new_grads + (t,) 135 return new_grads 136 137 138class CrossEntropy(nn.Cell): 139 """ 140 Cross Entropy loss 141 """ 142 def __init__(self, num_labels): 143 super(CrossEntropy, self).__init__() 144 self.onehot = P.OneHot() 145 self.on_value = Tensor(1.0, mstype.float32) 146 self.off_value = Tensor(0.0, mstype.float32) 147 self.reduce_sum = P.ReduceSum() 148 self.reduce_mean = P.ReduceMean() 149 self.reshape = P.Reshape() 150 self.last_idx = (-1,) 151 self.neg = P.Neg() 152 self.cast = P.Cast() 153 self.num_labels = num_labels 154 155 def construct(self, logits, label_ids): 156 label_ids = self.reshape(label_ids, self.last_idx) 157 one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, self.off_value) 158 per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) 159 loss = self.reduce_mean(per_example_loss, self.last_idx) 160 return_value = self.cast(loss, mstype.float32) 161 return return_value 162 163 164class NetworkWithCLSLoss(nn.Cell): 165 def __init__(self, network): 166 super(NetworkWithCLSLoss, self).__init__(auto_prefix=False) 167 self.cls_network = network 168 self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 169 170 def construct(self, input_ids, input_mask, token_type_id, label_ids): 171 logits = self.cls_network(input_ids, input_mask, token_type_id) 172 cls_loss = self.loss_fct(logits, label_ids) 173 return cls_loss 174 175 176class NetworkWithMLMLoss(nn.Cell): 177 def __init__(self, network, vocab_size=21128): 178 super(NetworkWithMLMLoss, self).__init__(auto_prefix=False) 179 self.mlm_network = network 180 self.vocab_size = vocab_size 181 self.reshape = P.Reshape() 182 self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 183 184 def construct(self, input_ids, input_mask, token_type_id, label_ids): 185 prediction_scores = self.mlm_network(input_ids, input_mask, token_type_id) 186 prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size)) 187 label_ids = self.reshape(label_ids, (-1,)) 188 mlm_loss = self.loss_fct(prediction_scores, label_ids) 189 return mlm_loss 190 191 192class NetworkTrainCell(nn.Cell): 193 def __init__(self, network, optimizer, sens=1.0): 194 super(NetworkTrainCell, self).__init__(auto_prefix=False) 195 self.network = network 196 self.network.set_grad() 197 self.weights = optimizer.parameters 198 self.optimizer = optimizer 199 self.sens = sens 200 self.grad = C.GradOperation(get_by_list=True, 201 sens_param=True) 202 self.clip_type = 1 203 self.clip_value = 1.0 204 self.cast = P.Cast() 205 self.hyper_map = C.HyperMap() 206 207 self.get_weights_by_key = P.PullWeight() 208 self.over_weights_by_key = P.PushWeight() 209 210 self.get_weights_by_key_input_1, \ 211 self.get_weights_by_key_input_2, \ 212 self.get_weights_by_key_input_3 = self._get_weights_by_key_inputs(self.network.parameters_and_names()) 213 214 self.over_weights_by_key_input_1, \ 215 self.over_weights_by_key_input_2, \ 216 self.over_weights_by_key_input_3 = self._over_weights_by_key_inputs(self.network.parameters_and_names()) 217 218 def _communication_with_server_1(self, weights): 219 result = self.hyper_map(F.partial(self.get_weights_by_key), weights, 220 self.get_weights_by_key_input_2, self.get_weights_by_key_input_3) 221 return result 222 223 def _communication_with_server_2(self, weights): 224 result = self.hyper_map(F.partial(self.over_weights_by_key), weights, 225 self.over_weights_by_key_input_2, 226 self.over_weights_by_key_input_3) 227 return result 228 229 def _get_weights_by_key_inputs(self, weights): 230 filtered_weights = [] 231 weight_names = [] 232 weight_indices = [] 233 index = 0 234 for weight in weights: 235 if weight[1].pull_weight_from_server: 236 filtered_weights.append(weight[1]) 237 weight_names.append(weight[1].name) 238 weight_indices.append(index) 239 index += 1 240 return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices) 241 242 def _over_weights_by_key_inputs(self, weights): 243 filtered_weights = [] 244 weight_names = [] 245 weight_indices = [] 246 index = 0 247 for weight in weights: 248 if weight[1].push_weight_to_server: 249 filtered_weights.append(weight[1]) 250 weight_names.append(weight[1].name) 251 weight_indices.append(index) 252 index += 1 253 return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices) 254 255 def construct(self, input_ids, input_mask, token_type_id, label_ids): 256 weights = self.weights 257 res = self._communication_with_server_1(self.get_weights_by_key_input_1) 258 input_ids = F.depend(input_ids, res) 259 loss = self.network(input_ids, input_mask, token_type_id, label_ids) 260 grads = self.grad(self.network, weights)(input_ids, 261 input_mask, 262 token_type_id, 263 label_ids, 264 self.cast(F.tuple_to_array((self.sens,)), 265 mstype.float32)) 266 grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) 267 loss = F.depend(loss, self.optimizer(grads)) 268 weights1 = F.depend(self.over_weights_by_key_input_1, loss) 269 loss = F.depend(loss, self._communication_with_server_2(weights1)) 270 return loss 271 272 273class NetworkNoClientTrainCell(nn.Cell): 274 def __init__(self, network, optimizer, sens=1.0): 275 super(NetworkNoClientTrainCell, self).__init__(auto_prefix=False) 276 self.network = network 277 self.network.set_grad() 278 self.weights = optimizer.parameters 279 self.optimizer = optimizer 280 self.sens = sens 281 self.grad = C.GradOperation(get_by_list=True, 282 sens_param=True) 283 self.clip_type = 1 284 self.clip_value = 1.0 285 self.cast = P.Cast() 286 self.hyper_map = C.HyperMap() 287 288 def construct(self, input_ids, input_mask, token_type_id, label_ids): 289 weights = self.weights 290 loss = self.network(input_ids, input_mask, token_type_id, label_ids) 291 grads = self.grad(self.network, weights)(input_ids, 292 input_mask, 293 token_type_id, 294 label_ids, 295 self.cast(F.tuple_to_array((self.sens,)), 296 mstype.float32)) 297 grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) 298 self.optimizer(grads) 299 return loss 300