# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np

import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore.nn import Adam
from mindspore.nn import MultiFieldEmbeddingLookup as embedding
from mindspore import Tensor
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import release
from mindspore.communication.management import get_rank
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode


context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


def _count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert True
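

# Worked example of the tolerance check above (values chosen for illustration
# only): with rtol=0.001 and atol=0.001, the pair (expected=1.000, actual=1.003)
# has error 0.003, which exceeds atol + |actual| * rtol ≈ 0.002, so it counts as
# a mismatch; _count_unequal_element then requires that fewer than rtol (0.1%)
# of all elements are such mismatches.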


def clean_all_ckpt_files(folder_path):
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
                os.remove(os.path.join(folder_path, file_name))


def find_newest_ckpt_file(folder_path):
    ckpt_files = map(lambda f: os.path.join(folder_path, f),
                     filter(lambda f: f.endswith('.ckpt'),
                            os.listdir(folder_path)))
    return max(ckpt_files, key=os.path.getctime)


class FakeDataInitMode:
    RandomInit = 0
    OnesInit = 1
    UniqueInit = 2
    ZerosInit = 3


class FakeData:
    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode

        if use_parallel is True:
            init(backend_name='nccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()

        self.total_batch_size = self.rank_batch_size * self.rank_size

        assert (self.size % self.total_batch_size) == 0

        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        _ = num_epochs
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size = total_size * i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
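

# A minimal usage sketch of FakeData (hypothetical values; the shapes follow
# from the constructor arguments and the num_classes=10 default):
#
#   ds = FakeData(size=64, batch_size=8, image_size=(64,))
#   data, label = ds[0]    # data: Tensor of shape (8, 64); label: one-hot (8, 10)
#   ds.get_dataset_size()  # 64 / 8 = 8 batches per epoch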


class MultiHotNet(Cell):
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        super().__init__()
        self.embedding = embedding(vocab_size=vocab_size,
                                   embedding_size=embedding_size, field_size=field_size,
                                   param_init=param_init, target=target, slice_mode=slice_mode,
                                   sparse=sparse, operator=operator)
        self.relu = P.ReLU()
        self.indices = Tensor(indices)
        self.field_ids = Tensor(field_ids)
        if slice_mode == "table_column_slice":
            self.relu.shard(((1, 1, 8),))
        elif slice_mode == "table_row_slice":
            self.relu.shard(((8, 1, 1),))
        elif slice_mode == "batch_slice":
            self.relu.shard(((8, 1, 1),))

    def construct(self, values, label):
        x = self.embedding(self.indices, values, self.field_ids)
        output = self.relu(x)
        return output


class ParallelMultiHotFactory:
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.field_size = field_size
        self.param_init = param_init
        self.target = target
        self.slice_mode = slice_mode
        self.sparse = sparse
        self.operator = operator
        self.indices = indices
        self.field_ids = field_ids
        self.global_rank_id = None
        self.opt = None
        self.model = None
        self.standalone_ckpt = None
        self.parallel_ckpt = None
        self.loss_fn = None
        self._init_parallel()
        self._set_parallel_env()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __del__(self):
        self._release_parallel()

    def _set_parallel_env(self):
        self.global_rank_id = get_rank()

    def _init_parallel(self):
        self._init_parallel_flag = False
        init(backend_name='nccl')
        self._init_parallel_flag = True

    def _release_parallel(self):
        release()

    def _model_train_and_save_ckpt(self, net, dataset, epoch):
        self.opt = Adam(params=net.get_parameters())
        if self.target == 'CPU':
            self.opt.target = self.target
        if self.sparse:
            context.set_context(enable_sparse=True)
        self.model = Model(network=net,
                           loss_fn=self.loss_fn,
                           optimizer=self.opt)
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_path = './rank_{}_ckpt'.format(self.global_rank_id)
        ckpt_callback = ModelCheckpoint(prefix='parallel', directory=ckpt_path,
                                        config=ckpt_config)
        clean_all_ckpt_files(ckpt_path)
        self.model.train(epoch=epoch,
                         train_dataset=dataset,
                         callbacks=[ckpt_callback],
                         dataset_sink_mode=False)
        newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
        return load_checkpoint(newest_ckpt_file)

    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num)
        parallel_mode_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                        field_size=self.field_size, param_init=self.param_init, target=self.target,
                                        slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                        indices=self.indices, field_ids=self.field_ids)
        self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net, epoch=epoch, dataset=dataset)

    def mindspore_standalone_impl(self, epoch, dataset):
        context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE)
        stand_alone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                      field_size=self.field_size, param_init=self.param_init, target=self.target,
                                      slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                      indices=self.indices, field_ids=self.field_ids)
        self.standalone_ckpt = self._model_train_and_save_ckpt(net=stand_alone_net,
                                                               epoch=epoch, dataset=dataset)

    def checkpoint_cmp(self, inputs_np, label):
        standalone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                     field_size=self.field_size, param_init=self.param_init, target=self.target,
                                     slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                     indices=self.indices, field_ids=self.field_ids)
        parallel_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                   field_size=self.field_size, param_init=self.param_init, target=self.target,
                                   slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                   indices=self.indices, field_ids=self.field_ids)
        load_param_into_net(standalone_net, self.standalone_ckpt)
        load_param_into_net(parallel_net, self.parallel_ckpt)
        standalone_out = standalone_net(Tensor(inputs_np), Tensor(label))
        parallel_out = parallel_net(Tensor(inputs_np), Tensor(label))
        allclose_nparray(standalone_out.asnumpy(), parallel_out.asnumpy(), 0.001, 0.001)
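

# Workflow exercised by the test below: train a standalone copy and an
# auto-parallel copy of MultiHotNet, checkpoint both, reload each checkpoint
# into a fresh network, and check that the forward outputs agree within
# tolerance.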


def test_auto_parallel_multifieldembeddinglookup_device_table_column_slice_mean():
    inputs_np = 10 * np.random.randn(64, 64).astype(np.float32)
    label = 10 * np.random.randn(64, 64).astype(np.float32)
    indices = np.random.randint(0, 9, (64, 64), np.int32)
    field_ids = np.random.randint(0, 20, (64, 64), np.int32)
    fact = ParallelMultiHotFactory(vocab_size=32, embedding_size=64, field_size=64, param_init='one', target='DEVICE',
                                   slice_mode='table_column_slice', sparse=False, operator='MEAN',
                                   indices=indices, field_ids=field_ids)

    # standalone
    standalone_dataset = FakeData(size=64, batch_size=64, image_size=(64,))
    fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2)

    # auto parallel
    parallel_dataset = FakeData(size=64, batch_size=8, image_size=(64,), use_parallel=True)
    fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, epoch=2, device_num=8)

    # compare
    fact.checkpoint_cmp(inputs_np=inputs_np, label=label)
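
# How this test is typically launched (a sketch; the launcher command and file
# name are assumptions, not part of this file): tests that call
# init(backend_name='nccl') are run under an MPI launcher across 8 GPU ranks,
# e.g.
#
#   mpirun -n 8 pytest -sv test_multifield_embedding.py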