# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

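"""Shard-strategy tests for a one-hot based softmax cross-entropy style network under semi_auto_parallel."""
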
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.cell import Cell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


device_num = 16
device_id = 2

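# Shard strategies for the primitives in SemiAutoOneHotNet. Each attribute holds
# one inner tuple per operator input, giving the number of slices for every
# dimension of that input. The two classes differ only in onehot_strategy:
# StrategyModel splits the 2-D one-hot strategy along the class axis,
# StrategyBatch along the batch axis.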
class StrategyModel():
    onehot_strategy = ((1, device_num), (), ())
    twod_strategy = ((1, device_num),)
    twod_strategy_m = ((device_num, 1),)
    scalar_twod_strategy = ((), (1, device_num))
    twod_scalar_strategy = ((1, device_num), ())
    scalar_strategy = ((),)
    oned_strategy = ((1,),)
    scalar_scalar_strategy = ((), ())
    twod_twod_strategy = ((1, device_num), (1, device_num))
    twod_twodbc_strategy = ((1, device_num), (1, 1))
    twodbc_twod_strategy = ((1, 1), (device_num, 1))


class StrategyBatch():
    onehot_strategy = ((device_num, 1), (), ())
    twod_strategy = ((1, device_num),)
    twod_strategy_m = ((device_num, 1),)
    scalar_twod_strategy = ((), (1, device_num))
    twod_scalar_strategy = ((1, device_num), ())
    scalar_strategy = ((),)
    oned_strategy = ((1,),)
    scalar_scalar_strategy = ((), ())
    twod_twod_strategy = ((1, device_num), (1, device_num))
    twod_twodbc_strategy = ((1, device_num), (1, 1))
    twodbc_twod_strategy = ((1, 1), (device_num, 1))

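# Scalar hyperparameters consumed by SemiAutoOneHotNet: a-e feed the constant
# tensors used in the loss, num_classes/emb_size size the FC weight.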
class Args():
    a = 1
    b = 2
    c = 3
    d = 4
    e = 5
    num_classes = 512
    emb_size = 512

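# L2-normalized FC layer followed by a one-hot based softmax cross-entropy style
# loss, built from elementwise primitives so that every operator can carry an
# explicit shard strategy from StrategyModel/StrategyBatch.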
class SemiAutoOneHotNet(Cell):
    def __init__(self, args, strategy):
        super(SemiAutoOneHotNet, self).__init__()
        self.a = args.a
        self.b = args.b
        self.c = args.c
        self.d = args.d
        self.e = args.e
        self.cast = P.Cast()
        self.cast.shard(strategy=strategy.twod_strategy)
        self.cast1 = P.Cast()
        self.cast1.shard(strategy=strategy.twod_strategy)
        self.cast2 = P.Cast()
        self.cast2.shard(strategy=strategy.twod_strategy)
        self.cast3 = P.Cast()
        self.cast3.shard(strategy=strategy.scalar_strategy)
        self.cast4 = P.Cast()
        self.cast4.shard(strategy=strategy.scalar_strategy)
        self.a_const = Tensor(self.a, dtype=mstype.float32)
        self.b_const = Tensor(self.b, dtype=mstype.float32)
        self.c_const = Tensor(self.c, dtype=mstype.float32)
        self.d_const = Tensor(self.d, dtype=mstype.float32)
        self.e_const = Tensor(self.e, dtype=mstype.float32)
        self.m_const_zero = Tensor(0, dtype=mstype.float32)
        self.a_const_one = Tensor(1, dtype=mstype.float32)
        self.onehot = P.OneHot()
        self.onehot.shard(strategy=strategy.onehot_strategy)
        self.exp = P.Exp()
        self.exp.shard(strategy=strategy.twod_strategy)
        self.exp2 = P.Exp()
        self.exp2.shard(strategy=strategy.twod_strategy)
        self.exp3 = P.Exp()
        self.exp3.shard(strategy=strategy.twod_strategy)
        self.mul_const = P.Mul()
        self.mul_const.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const2 = P.Add()
        self.mul_const2.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const3 = P.Sub()
        self.mul_const3.shard(strategy=strategy.twod_scalar_strategy)
        self.mul_const4 = P.Sub()
        self.mul_const4.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const5 = P.Mul()
        self.mul_const5.shard(strategy=strategy.twod_scalar_strategy)
        self.mul = P.Mul()
        self.mul.shard(strategy=strategy.twod_twod_strategy)
        self.mul2 = P.Mul()
        self.mul2.shard(strategy=strategy.twod_twod_strategy)
        self.mul3 = P.Add()
        self.mul3.shard(strategy=strategy.twod_twod_strategy)
        self.mul4 = P.Sub()
        self.mul4.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul5 = P.RealDiv()
        self.mul5.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul6 = P.Mul()
        self.mul6.shard(strategy=strategy.twod_twod_strategy)
        self.mul7 = P.Mul()
        self.mul7.shard(strategy=strategy.twod_scalar_strategy)
        self.mul8 = P.RealDiv()
        self.mul8.shard(strategy=strategy.scalar_scalar_strategy)
        self.mul9 = P.Add()
        self.mul9.shard(strategy=strategy.twod_scalar_strategy)

        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.reduce_max.shard(strategy=strategy.twod_strategy)

        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reduce_sum.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_2 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_2.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_3 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_3.shard(strategy=strategy.oned_strategy)

        self.reshape = P.Reshape()
        self.log = P.Log()
        self.log.shard(strategy=strategy.twod_strategy)

        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.normalize = P.L2Normalize(axis=1)
        self.normalize.shard(strategy=strategy.twod_strategy_m)
        self.normalize2 = P.L2Normalize(axis=1)
        self.normalize2.shard(strategy=strategy.twod_strategy_m)
        self.fc = P.MatMul(transpose_b=True)
        self.fc.shard(strategy=strategy.twodbc_twod_strategy)
        weight_shape = [args.num_classes, args.emb_size]
        weight_np = np.zeros(weight_shape, np.float32)
        self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight')

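    # Forward pass: L2-normalize the input and the weight, project with MatMul,
    # one-hot encode the label, then compute a numerically stable softmax
    # (subtract the row max before Exp) and reduce to a cross-entropy style
    # scalar averaged over the batch. The constants a-e scale intermediate terms.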
    def construct(self, input_, label):
        input_n = self.normalize(input_)
        w = self.normalize2(self.weight)
        fc_o = self.fc(input_n, w)
        fc_o_shape = F.shape(fc_o)
        one_hot_float = self.onehot(label, fc_o_shape[1], self.on_value, self.off_value)
        local_label = self.cast(one_hot_float, mstype.int32)

        exp_o = self.exp(fc_o)
        mul_const_o = self.mul_const(self.a_const, exp_o)
        mul_const2_o = self.mul_const2(self.b_const, mul_const_o)
        exp2_o = self.exp2(mul_const2_o)
        mul_const3_o = self.mul_const3(exp2_o, self.c_const)
        mul_const4_o = self.mul_const4(F.scalar_to_array(1), local_label)
        mul6_o = self.mul6(self.mul(mul_const3_o, one_hot_float),
                           self.mul2(fc_o, self.cast2(mul_const4_o, mstype.float32)))
        mul_const5_o = self.mul_const5(mul6_o, self.d_const)

        max_o = self.reduce_max(mul_const5_o, -1)
        mul4_o = self.mul4(mul_const5_o, max_o)
        exp3_o = self.exp3(mul4_o)
        sum_o = self.reduce_sum(exp3_o, -1)
        reshape_o = self.reshape(sum_o, (F.shape(sum_o)[0], 1))
        mul5_o = self.mul5(exp3_o, reshape_o)
        log_o = self.log(self.mul9(mul5_o, self.e_const))
        mul3_o = self.mul3(log_o, one_hot_float)
        mul7_o = self.mul7(mul3_o, self.cast3(F.scalar_to_array(-1), mstype.float32))
        sum2_o = self.reduce_sum_2(mul7_o, -1)
        loss = self.mul8(self.reduce_sum_3(sum2_o, -1),
                         self.cast4(F.scalar_to_array(F.shape(mul_const5_o)[0]), mstype.float32))
        return loss

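# Minimal mock dataset (built on the tests' MindData) that yields the same
# (predict, label) pair `length` times; used to drive Model.train below.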
class Dataset(MindData):
    def __init__(self, predict, label, length=3, input_num=2):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        self.index = 0

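# Attaches the test-only VirtualLoss to the wrapped network so GradWrap below
# has a loss-style output to differentiate.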
class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, b):
        predict = self.network(x, b)
        return self.loss(predict)

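# Returns the gradients of the wrapped network with respect to all of its inputs
# via grad_all; the tests only compile this grad graph, they do not execute it.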
class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, b):
        return grad_all(self.network)(x, b)


def bn_with_initialize(out_channels):
    bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
    return bn


def fc_with_initialize(input_channels, out_channels):
    return nn.Dense(input_channels, out_channels)

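# BatchNorm2d -> Reshape -> Dense -> BatchNorm1d feeding the sharded one-hot
# loss. The reshape hard-codes a batch size of 16 to match the test input.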
class BNReshapeDenseBNNet(nn.Cell):
    def __init__(self):
        super(BNReshapeDenseBNNet, self).__init__()
        self.batch_norm = bn_with_initialize(2)
        self.reshape = P.Reshape()
        self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
        self.fc = fc_with_initialize(2 * 32 * 32, 512)
        self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())

    def construct(self, x, label):
        x = self.batch_norm(x)
        x = self.reshape(x, (16, 2 * 32 * 32))
        x = self.fc(x)
        x = self.batch_norm2(x)
        loss = self.loss(x, label)
        return loss

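# Compiles the gradient graph of BNReshapeDenseBNNet (batch-split one-hot loss)
# under semi_auto_parallel with 16 devices; no numerical execution.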
def test_bn_reshape_dense_bn_train_loss():
    batch_size = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
    net.set_auto_parallel()

    net.set_train()
    _cell_graph_executor.compile(net, input_, label)

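# Same compile-only check, but directly on SemiAutoOneHotNet with the
# StrategyBatch (batch-axis) shard strategies.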
def test_semi_one_hot_net_batch():
    batch_size = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * 1, 512]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)

    net = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
    net = GradWrap(NetWithLoss(net))
    net.set_auto_parallel()

    net.set_train()
    _cell_graph_executor.compile(net, input_, label)

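# End-to-end Model.train pass (two epochs, Momentum optimizer, dataset sink
# disabled) over the mocked dataset, using the StrategyModel (class-axis)
# shard strategies.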
def test_semi_one_hot_net_model():
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    predict = Tensor(np.ones([batch_size, 512]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2, input_num=2)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=16)
    context.set_context(mode=context.GRAPH_MODE)
    net = SemiAutoOneHotNet(args=Args(), strategy=StrategyModel())
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)