# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


class TrainOneStepWithLossScaleCell(nn.Cell):
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self, x, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
        # apply grad reducer on grads
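        # Note: self.grad_reducer is F.identity in this single-process test; in a
        # real multi-device run it would typically be an nn.DistributedGradReducer
        # that all-reduces the gradients across devices before clipping.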
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)


class DatasetLenet(MindData):
    def __init__(self, predict, label, length=3):
        super(DatasetLenet, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class LoopLayer(nn.Cell):
    def __init__(self):
        super(LoopLayer, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.exp = P.Exp()
        self.mean = P.ReduceMean()
        layers = []
        for _ in range(3):
            layer = LoopLayer()
            layers.append(layer)
        self.layers = nn.CellList(layers)

    def construct(self, x):
        out = self.exp(x)
        for layer in self.layers:
            layer_out = layer(out)
            out = layer_out
        out = self.mean(out, -1)
        return out


class Net2(nn.Cell):
    def __init__(self):
        super(Net2, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x, b):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out


def test_loss_scale():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)


def test_loss_scale2():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net2()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = nn.TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)
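

# The helper below is an illustrative sketch, not part of the original tests: it
# emulates in plain Python the dynamic loss-scale policy configured above
# (loss_scale_value=65536, scale_factor=2, scale_window=1000). On an overflow
# step the scale shrinks by scale_factor; after scale_window consecutive clean
# steps it grows by scale_factor. The function name and the >= 1 clamp are
# assumptions for demonstration; MindSpore's DynamicLossScaleUpdateCell is the
# authoritative implementation.
def _emulate_dynamic_loss_scale(overflow_flags, scale=65536.0, factor=2.0, window=1000):
    """Yield the loss scale after each step for the given per-step overflow flags."""
    good_steps = 0
    for overflow in overflow_flags:
        if overflow:
            scale = max(scale / factor, 1.0)  # shrink on overflow, keep scale >= 1
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == window:  # a full stable window passed: try a larger scale
                scale *= factor
                good_steps = 0
        yield scale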