# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np

import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore.nn import Adam
from mindspore.nn import MultiFieldEmbeddingLookup as embedding
from mindspore import Tensor
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import release
from mindspore.communication.management import get_rank
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode

# Every test in this file runs in graph mode on a GPU target.
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
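
# A typical launch for the 8-device cases below (hypothetical invocation, assuming an
# OpenMPI + NCCL environment; the actual CI command is not recorded in this file):
#   mpirun -n 8 pytest -sv <this_test_file>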


def _count_unequal_element(data_expected, data_me, rtol, atol):
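    """Assert that the fraction of elements whose error exceeds atol + |data_me| * rtol stays below rtol."""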
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
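    """Compare two arrays with np.allclose; on mismatch, tolerate only a small fraction of unequal elements."""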
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


def clean_all_ckpt_files(folder_path):
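    """Delete stale .ckpt and .meta files so every run starts from a clean checkpoint directory."""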
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
                os.remove(os.path.join(folder_path, file_name))


def find_newest_ckpt_file(folder_path):
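    """Return the most recently created .ckpt file under folder_path."""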
    ckpt_files = [os.path.join(folder_path, f)
                  for f in os.listdir(folder_path) if f.endswith('.ckpt')]
    return max(ckpt_files, key=os.path.getctime)


class FakeDataInitMode:
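    """Fill strategies for the images generated by FakeData."""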
    RandomInit = 0
    OnesInit = 1
    UniqueInit = 2
    ZerosInit = 3


class FakeData:
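    """In-memory stand-in for a MindSpore dataset that yields (image, label) Tensor batches.

    With use_parallel=True it initializes the NCCL collective, and each rank keeps
    only its own slice of every global batch.
    """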
    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode

        if use_parallel:
            init(backend_name='nccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()

        self.total_batch_size = self.rank_batch_size * self.rank_size

        # The dataset size must split evenly into global batches.
        assert (self.size % self.total_batch_size) == 0

        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        _ = num_epochs, do_copy
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        # Seed deterministically per batch so standalone and parallel runs draw
        # identical data, then restore the global RNG state afterwards.
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size = total_size * i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        # Each rank keeps only its own slice of the global batch.
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration


class MultiHotNet(Cell):
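    """MultiFieldEmbeddingLookup followed by ReLU, with the ReLU shard strategy matched to the slice mode."""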
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        super().__init__()
        self.embedding = embedding(vocab_size=vocab_size,
                                   embedding_size=embedding_size, field_size=field_size,
                                   param_init=param_init, target=target, slice_mode=slice_mode,
                                   sparse=sparse, operator=operator)
        self.relu = P.ReLU()
        self.indices = Tensor(indices)
        self.field_ids = Tensor(field_ids)
        # Shard the ReLU over 8 devices along the axis that matches how the table is sliced.
        if slice_mode == "table_column_slice":
            self.relu.shard(((1, 1, 8),))
        elif slice_mode == "table_row_slice":
            self.relu.shard(((8, 1, 1),))
        elif slice_mode == "batch_slice":
            self.relu.shard(((8, 1, 1),))

    def construct(self, values, label):
        # label is accepted to match the (data, label) dataset tuple but unused in the forward pass.
        x = self.embedding(self.indices, values, self.field_ids)
        output = self.relu(x)
        return output


class ParallelMultiHotFactory:
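    """Trains MultiHotNet both standalone and under auto parallel, then compares the resulting checkpoints."""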
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.field_size = field_size
        self.param_init = param_init
        self.target = target
        self.slice_mode = slice_mode
        self.sparse = sparse
        self.operator = operator
        self.indices = indices
        self.field_ids = field_ids
        self.global_rank_id = None
        self.opt = None
        self.model = None
        self.standalone_ckpt = None
        self.parallel_ckpt = None
        self.loss_fn = None
        self._init_parallel()
        self._set_parallel_env()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __del__(self):
        self._release_parallel()

    def _set_parallel_env(self):
        self.global_rank_id = get_rank()

    def _init_parallel(self):
        self._init_parallel_flag = False
        init(backend_name='nccl')
        self._init_parallel_flag = True

    def _release_parallel(self):
        release()

    def _model_train_and_save_ckpt(self, net, dataset, epoch):
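        """Train net for the given epochs, checkpoint it per rank, and return the loaded checkpoint dict."""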
        self.opt = Adam(params=net.get_parameters())
        if self.target == 'CPU':
            self.opt.target = self.target
        if self.sparse:
            context.set_context(enable_sparse=True)
        self.model = Model(network=net,
                           loss_fn=self.loss_fn,
                           optimizer=self.opt)
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_path = './rank_{}_ckpt'.format(self.global_rank_id)
        ckpt_callback = ModelCheckpoint(prefix='parallel', directory=ckpt_path,
                                        config=ckpt_config)
        clean_all_ckpt_files(ckpt_path)
        self.model.train(epoch=epoch,
                         train_dataset=dataset,
                         callbacks=[ckpt_callback],
                         dataset_sink_mode=False)
        newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
        return load_checkpoint(newest_ckpt_file)

    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
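        """Train under AUTO_PARALLEL across device_num devices and keep the resulting checkpoint."""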
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num)
        parallel_mode_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                        field_size=self.field_size, param_init=self.param_init, target=self.target,
                                        slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                        indices=self.indices, field_ids=self.field_ids)
        self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net, epoch=epoch, dataset=dataset)

    def mindspore_standalone_impl(self, epoch, dataset):
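        """Train the same network in STAND_ALONE mode and keep the resulting checkpoint."""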
        context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE)
        stand_alone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                      field_size=self.field_size, param_init=self.param_init, target=self.target,
                                      slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                      indices=self.indices, field_ids=self.field_ids)
        self.standalone_ckpt = self._model_train_and_save_ckpt(net=stand_alone_net,
                                                               epoch=epoch, dataset=dataset)

    def checkpoint_cmp(self, inputs_np, label):
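        """Load both checkpoints into fresh networks and assert their forward outputs are close."""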
        standalone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                     field_size=self.field_size, param_init=self.param_init, target=self.target,
                                     slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                     indices=self.indices, field_ids=self.field_ids)
        parallel_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                   field_size=self.field_size, param_init=self.param_init, target=self.target,
                                   slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                   indices=self.indices, field_ids=self.field_ids)
        load_param_into_net(standalone_net, self.standalone_ckpt)
        load_param_into_net(parallel_net, self.parallel_ckpt)
        standalone_out = standalone_net(Tensor(inputs_np), Tensor(label))
        parallel_out = parallel_net(Tensor(inputs_np), Tensor(label))
        allclose_nparray(standalone_out.asnumpy(), parallel_out.asnumpy(), 0.001, 0.001)


def test_auto_parallel_multifieldembeddinglookup_device_table_column_slice_mean():
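    """Auto-parallel table_column_slice with MEAN must match standalone training within tolerance."""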
    inputs_np = 10 * np.random.randn(64, 64).astype(np.float32)
    label = 10 * np.random.randn(64, 64).astype(np.float32)
    indices = np.random.randint(0, 9, (64, 64), np.int32)
    field_ids = np.random.randint(0, 20, (64, 64), np.int32)
    fact = ParallelMultiHotFactory(vocab_size=32, embedding_size=64, field_size=64, param_init='one', target='DEVICE',
                                   slice_mode='table_column_slice', sparse=False, operator='MEAN',
                                   indices=indices, field_ids=field_ids)

    # stand-alone training
    standalone_dataset = FakeData(size=64, batch_size=64, image_size=(64,))
    fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2)

    # auto-parallel training
    parallel_dataset = FakeData(size=64, batch_size=8, image_size=(64,), use_parallel=True)
    fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, epoch=2, device_num=8)

    # compare the two checkpoints through a forward pass
    fact.checkpoint_cmp(inputs_np=inputs_np, label=label)