# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

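"""Parallel unit tests for dynamic loss scaling: a hand-written
TrainOneStepWithLossScaleCell and the built-in nn wrapper are each trained
for two epochs under semi-auto-parallel mode."""
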
import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


# Gradient clipping configuration: type 1 clips each gradient by its norm,
# type 0 clips element-wise by value (see _clip_grad below).
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Undo the loss-scale multiplication by dividing the gradient by the scale."""
    return grad * reciprocal(scale)


# Dynamic loss scale manager shared by the tests below: start at 65536,
# shrink on overflow and grow by scale_factor after scale_window clean steps.
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """Clip a single gradient by value (clip_type == 0) or by norm (otherwise)."""
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


class TrainOneStepWithLossScaleCell(nn.Cell):
    """One training step with gradient clipping, dynamic loss scaling and
    NPU float-status based overflow detection."""

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self, x, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # Allocating and clearing the float status must happen right before the grad operation.
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
        # Apply the grad reducer (identity in this test) and clip the gradients.
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        # Read back the float status: a non-zero flag sum means an overflow occurred.
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # Skip the optimizer update when an overflow is detected.
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)


class DatasetLenet(MindData):
    """Mock dataset that yields the same (predict, label) pair for a fixed number of steps."""

    def __init__(self, predict, label, length=3):
        super(DatasetLenet, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class LoopLayer(nn.Cell):
    """MatMul followed by ReLU, repeated inside Net's layer list."""

    def __init__(self):
        super(LoopLayer, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out


class Net(nn.Cell):
    """Exp the input, pass it through three LoopLayer blocks, then reduce-mean over the last axis."""

    def __init__(self):
        super(Net, self).__init__()
        self.exp = P.Exp()
        self.mean = P.ReduceMean()
        layers = []
        for _ in range(3):
            layer = LoopLayer()
            layers.append(layer)
        self.layers = nn.CellList(layers)

    def construct(self, x):
        out = self.exp(x)
        for layer in self.layers:
            layer_out = layer(out)
            out = layer_out
        out = self.mean(out, -1)
        return out


class Net2(nn.Cell):
    """Single MatMul + ReLU block; the second input ``b`` is accepted but unused."""

    def __init__(self):
        super(Net2, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x, b):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out


def test_loss_scale():
    """Train the custom TrainOneStepWithLossScaleCell for two epochs under semi-auto-parallel."""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)


def test_loss_scale2():
    """Same scenario using the built-in nn.TrainOneStepWithLossScaleCell wrapper."""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net2()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = nn.TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)