• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common import dtype as mstype
22from mindspore.common.parameter import ParameterTuple
23from mindspore.communication.management import init
24from mindspore.nn import Dense, Cell
25from mindspore.nn.loss.loss import LossBase
26from mindspore.nn.optim import Momentum
27from mindspore.ops import composite as C
28from mindspore.ops import operations as P
29from mindspore.train import Model
30from mindspore.context import ParallelMode
31from mindspore.communication._comm_helper import GlobalComm
32
# Simulated cluster layout shared by the tests below: ``device_number`` is
# passed as ``device_num`` to the auto-parallel context and also appears in
# several shard strategies.
device_number = 32
batch_size_per_device = 128
# Every network in this file runs in MindSpore GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
36
37
class Dataset():
    """Tiny in-memory dataset stub used to drive ``Model.train`` in tests.

    The object is its own iterator. Each ``next`` returns the one-element
    tuple ``(predict,)`` until ``length`` items have been produced. After that
    it raises ``StopIteration`` until ``reset`` rewinds it.
    """

    def __init__(self, predict, length=3):
        self.predict = predict
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return (self.predict,)
        raise StopIteration

    def reset(self):
        """Rewind so the next iteration starts from the first item again."""
        self.index = 0

    def get_dataset_size(self):
        # Fixed value reported to the caller; it does not track ``length``.
        return 128

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        # Both arguments are accepted for API compatibility and ignored.
        return self
64
65
class GatherV2(LossBase):
    """Loss that gathers two interleaved index sets on axis 0 and squares their gap.

    Args:
        index_dim: ``1`` builds flat even/odd index lists from
            ``range(index_size)``. ``2`` builds even/odd entries of
            ``arange(index_size * 16)`` reshaped to ``(index_size / 2, 16)``.
            Any other value falls back to the scalar indices ``21`` and ``2``.
        strategy: shard strategy handed to ``P.Gather().shard``.
        index_size: size of the index range before it is split in two.

    The Gather primitive is also tagged with ``data_parallel=True``.
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super().__init__()
        self.pow = P.Pow()
        # Scalar fallback indices used when index_dim is neither 1 nor 2.
        even_ids, odd_ids = 21, 2
        if index_dim == 1:
            flat_ids = list(range(index_size))
            even_ids, odd_ids = flat_ids[::2], flat_ids[1::2]
        elif index_dim == 2:
            flat_ids = np.arange(index_size * 16)
            shape = (int(index_size / 2), 16)
            even_ids = np.reshape(flat_ids[::2], shape)
            odd_ids = np.reshape(flat_ids[1::2], shape)
        self.emb1_param = Tensor(even_ids, dtype=mstype.int32)
        self.emb2_param = Tensor(odd_ids, dtype=mstype.int32)
        gather = P.Gather().shard(strategy)
        self.gatherv2 = gather.add_prim_attr("data_parallel", True)

    def construct(self, nembeddings):
        # Gather along axis 0 for both index sets, then square the difference.
        first = self.gatherv2(nembeddings, self.emb1_param, 0)
        second = self.gatherv2(nembeddings, self.emb2_param, 0)
        gap = first - second
        return self.pow(gap, 2.0)
88
89
def fc_with_initialize(input_channels, out_channels):
    """Build the ``Dense(input_channels, out_channels)`` layer used as the backbone.

    Only the two channel counts are passed, so every other ``Dense`` setting
    keeps its library default.
    """
    layer = Dense(input_channels, out_channels)
    return layer
92
93
class BuildTrainNetwork(nn.Cell):
    """Chain a backbone and a criterion into one cell that returns the loss.

    ``construct`` feeds ``input_data`` through ``network`` and passes the
    resulting embeddings straight to ``criterion``; no label is involved.
    """

    def __init__(self, network, criterion):
        super().__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data):
        # Forward pass, then score the embeddings directly.
        return self.criterion(self.network(input_data))
104
105
class TrainOneStepCell(Cell):
    """Run one forward / backward / optimizer step and return the loss.

    Args:
        network: cell called with ``data``; its output is treated as the loss.
        optimizer: cell called with the gradients of ``network``'s trainable
            parameters.
        sens: scalar used to fill the initial gradient fed into backprop.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super().__init__(auto_prefix=False)
        self.network = network
        network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        # Differentiate w.r.t. the weight list, with an explicit sens input.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data):
        params = self.weights
        loss = self.network(data)
        # Seed backprop with a tensor of loss's dtype/shape filled with sens.
        sens_tensor = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        gradients = self.grad(self.network, params)(data, sens_tensor)
        self.optimizer(gradients)
        return loss
125
126
def net_trains(criterion, rank):
    """Train a Dense backbone with ``criterion`` under SEMI_AUTO_PARALLEL.

    Args:
        criterion: loss cell applied to the Dense layer's output.
        rank: passed as ``global_rank`` to the auto-parallel context.

    ``GlobalComm.CHECK_ENVS`` is disabled only around ``init()`` and is set
    back to True even if ``init()`` raises. The auto-parallel context is reset
    even if set-up or training fails. Without these guards, one failing case
    would leak global configuration into the tests that run after it.
    """
    GlobalComm.CHECK_ENVS = False
    try:
        init()
    finally:
        GlobalComm.CHECK_ENVS = True

    lr = 0.1
    momentum = 0.9
    max_epoch = 20
    input_channels = 256
    out_channels = 512

    context.set_context(mode=context.GRAPH_MODE)
    try:
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
                                          global_rank=rank)
        predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
        dataset = Dataset(predict, 4)

        network = fc_with_initialize(input_channels, out_channels)
        network.set_train()

        train_network = BuildTrainNetwork(network, criterion)
        train_network.set_train()
        opt = Momentum(train_network.trainable_params(), lr, momentum)
        train_net = TrainOneStepCell(train_network, opt).set_train()

        model = Model(train_net)
        model.train(max_epoch, dataset, dataset_sink_mode=False)
    finally:
        # Always restore the default parallel context for subsequent tests.
        context.reset_auto_parallel_context()
154
155
def test_auto_batch_parallel():
    """1-D indices with no explicit Gather strategy, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    net_trains(GatherV2(1, strategy=None, index_size=total_indices), rank=2)
161
162
def test_2d_index_auto_batch_parallel():
    """2-D index matrices with no explicit Gather strategy, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(2, strategy=None, index_size=total_indices)
    net_trains(loss_cell, rank=2)
168
169
def test_batch_parallel():
    """1-D indices; strategy ``((device_number, 1),)``, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(1, strategy=((device_number, 1),), index_size=total_indices)
    net_trains(loss_cell, rank=2)
175
176
def test_strategy1():
    """1-D indices; strategy ``((16, 2),)``, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(1, strategy=((16, 2),), index_size=total_indices)
    net_trains(loss_cell, rank=2)
182
183
def test_strategy2():
    """1-D indices; strategy ``((1, device_number),)``, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(1, strategy=((1, device_number),), index_size=total_indices)
    net_trains(loss_cell, rank=2)
189
190
def test_strategy3():
    """1-D indices; strategy ``((8, 1),)``, simulated on rank 2."""
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(1, strategy=((8, 1),), index_size=total_indices)
    net_trains(loss_cell, rank=2)
196
197
class GatherV2Axis1(LossBase):
    """Loss that gathers two interleaved index sets on axis 1 and squares their gap.

    Args:
        index_dim: ``1`` builds flat even/odd index lists from
            ``range(index_size)``. ``2`` builds even/odd entries of
            ``arange(index_size * index_size)`` reshaped to
            ``(index_size / 2, index_size)``. Any other value falls back to the
            scalar indices ``21`` and ``2``.
        strategy: shard strategy handed to ``P.Gather().shard``.
        index_size: size of the index range before it is split in two.
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super().__init__()
        self.pow = P.Pow()
        # Scalar fallback indices used when index_dim is neither 1 nor 2.
        even_ids, odd_ids = 21, 2
        if index_dim == 1:
            flat_ids = list(range(index_size))
            even_ids, odd_ids = flat_ids[::2], flat_ids[1::2]
        elif index_dim == 2:
            flat_ids = np.arange(index_size * index_size)
            shape = (int(index_size / 2), index_size)
            even_ids = np.reshape(flat_ids[::2], shape)
            odd_ids = np.reshape(flat_ids[1::2], shape)
        self.emb1_param = Tensor(even_ids, dtype=mstype.int32)
        self.emb2_param = Tensor(odd_ids, dtype=mstype.int32)
        self.gatherv2 = P.Gather().shard(strategy)

    def construct(self, nembeddings):
        # Gather along axis 1 for both index sets, then square the difference.
        first = self.gatherv2(nembeddings, self.emb1_param, 1)
        second = self.gatherv2(nembeddings, self.emb2_param, 1)
        gap = first - second
        return self.pow(gap, 2.0)
220
221
def test_axis1_auto_batch_parallel():
    """Axis-1 gather of 512 indices with no explicit strategy, rank 2."""
    net_trains(GatherV2Axis1(1, strategy=None, index_size=512), rank=2)
227
228
def test_axis1_batch_parallel():
    """Axis-1 gather; strategy ``((device_number, 1), (1,))``, rank 2."""
    shard_plan = ((device_number, 1), (1,))
    loss_cell = GatherV2Axis1(1, strategy=shard_plan, index_size=512)
    net_trains(loss_cell, rank=2)
234
235
def test_axis1_strategy1():
    """Axis-1 gather; strategy ``((16, 2), (1,))``, simulated on rank 17."""
    shard_plan = ((16, 2), (1,))
    loss_cell = GatherV2Axis1(1, strategy=shard_plan, index_size=512)
    net_trains(loss_cell, rank=17)
241