# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
from mindspore.communication.management import get_rank
from mindspore import Tensor
from mindspore import Parameter
from mindspore import context
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.communication.management import get_group_size


class FakeDataInitMode:
    """Integer constants selecting how FakeData fills its image arrays."""
    RandomInit = 0  # standard-normal samples drawn from the per-batch seeded RNG
    OnesInit = 1    # all ones
    UniqueInit = 2  # 0, 1e-4, 2e-4, ... laid out row-major over the global batch shape
    ZerosInit = 3   # all zeros


class FakeData:
    """Deterministic synthetic (image, label) dataset for tests.

    Exposes ``get_dataset_size``, ``create_tuple_iterator`` and the iterator
    protocol so it can be handed to ``Model.train`` in non-sink mode (as done in
    ``NetWithSparseGatherV2.train_mindspore_impl`` below).

    Each batch is generated from a NumPy RNG seeded with
    ``batch_index + random_offset``, so the same index always yields the same data.
    With ``use_parallel=True`` every rank generates the full global batch of shape
    ``(rank_size, batch_size) + image_size`` and keeps only row ``rank_id``.

    Args:
        size: total number of samples; must be a multiple of the global batch size.
        batch_size: per-rank batch size.
        image_size: per-sample image shape, as a tuple (concatenated to the batch shape).
        num_class: labels are drawn from ``[0, num_class)``.
        random_offset: added to the batch index to form the RNG seed.
        use_parallel: initialise collective communication and shard by rank.
        fakedata_mode: one of the ``FakeDataInitMode`` constants.

    Raises:
        ValueError: if ``size`` is not a multiple of ``batch_size * rank_size``.
    """

    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224), num_class=10,
                 random_offset=0, use_parallel=False, fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_class = num_class
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode

        if use_parallel:
            # NCCL when the environment selects GPU, HCCL otherwise.
            if os.environ.get('CONTEXT_DEVICE_TARGET') == 'GPU':
                init(backend_name='nccl')
            else:
                init(backend_name='hccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()

        # One global batch holds rank_batch_size samples for every rank.
        self.total_batch_size = self.rank_batch_size * self.rank_size
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.size % self.total_batch_size != 0:
            raise ValueError(
                "size ({}) must be a multiple of the global batch size ({})".format(
                    self.size, self.total_batch_size))
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        """Return the number of global batches per epoch."""
        return self.size // self.total_batch_size

    def get_repeat_count(self):
        """Return the repeat count; this dataset is never repeated."""
        return 1

    # Misspelled original name, kept so existing callers keep working.
    get_reeat_count = get_repeat_count

    def set_image_data_type(self, data_type):
        """Set the NumPy dtype images are cast to before wrapping in a Tensor."""
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        """Set the NumPy dtype labels are cast to before wrapping in a Tensor."""
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        """Choose one-hot labels (True) or plain class indices (False)."""
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=False):
        """Return ``self`` as the iterator.

        ``num_epochs`` and ``do_copy`` are accepted for interface compatibility
        and ignored.
        """
        return self

    def _make_global_images(self):
        """Build the image array for the whole global batch per ``fakedata_mode``.

        Must be called with the global NumPy RNG already seeded; only the
        RandomInit path (also used for unrecognised modes) draws from it.
        """
        shape = self.total_batch_data_size
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            return np.ones(shape)
        if self.fakedata_mode == FakeDataInitMode.ZerosInit:
            return np.zeros(shape)
        if self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = int(np.prod(shape))
            return np.reshape(np.arange(total_size) * 0.0001, shape)
        return np.random.randn(*shape)

    def __getitem__(self, batch_index):
        """Return batch ``batch_index`` for this rank as ``(image Tensor, label Tensor)``.

        The global NumPy RNG is temporarily reseeded with
        ``batch_index + random_offset``. Its previous state is restored
        afterwards, even if generation fails, so callers' random streams are
        unaffected.

        Raises:
            IndexError: if ``batch_index`` is negative or past the last batch.
        """
        if batch_index < 0 or batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        rng_state = np.random.get_state()
        try:
            np.random.seed(batch_index + self.random_offset)
            # Images are drawn before labels; keep this order so seeded data is stable.
            img = self._make_global_images()
            target = np.random.randint(0, self.num_class, size=(self.rank_size, self.rank_batch_size))
        finally:
            np.random.set_state(rng_state)
        # Every rank builds the same global batch and keeps only its own row.
        img_ret = img[self.rank_id].astype(self.image_data_type)
        target = target[self.rank_id]
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_class))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        else:
            target_ret = target.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        """Return the total number of samples (not batches)."""
        return self.size

    def __iter__(self):
        """Restart iteration from the first batch and return ``self``."""
        self.batch_index = 0
        return self

    def reset(self):
        """Rewind iteration to the first batch."""
        self.batch_index = 0

    def __next__(self):
        """Return the next batch, or raise StopIteration after the last one."""
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration


class NetWithSparseGatherV2(nn.Cell):
    """Single-op network gathering along axis 0 from an 8x8 all-ones weight.

    Args:
        strategy: optional shard strategy passed to the gather op's ``shard``.
        sparse: use ``P.SparseGatherV2`` when True, ``P.Gather`` otherwise.
    """

    def __init__(self, strategy=None, sparse=True):
        super(NetWithSparseGatherV2, self).__init__()
        self.axis = 0
        self.sparse = sparse
        # Both variants start from identical weights so their trained outputs are comparable.
        self.weight = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="weight")
        self.gather = P.SparseGatherV2() if sparse else P.Gather()
        if strategy is not None:
            self.gather.shard(strategy)

    def construct(self, indices):
        x = self.gather(self.weight, indices, self.axis)
        return x

    def train_mindspore_impl(self, indices, epoch, batch_size, use_parallel=True):
        """Train this net and return its output on ``indices``.

        Runs ``epoch`` separate one-epoch ``Model.train`` calls (non-sink mode) with
        softmax cross-entropy and Adam. The optimizer ``target`` is set to "CPU".
        The data is an 8-sample FakeData whose "images" are int32 scalars
        (``image_size=()``) with one-hot labels over 8 classes.

        NOTE(review): FakeData's default RandomInit mode yields standard-normal
        samples; cast to int32 they can be negative. Confirm that negative
        gather indices are intended here.
        """
        ds = FakeData(size=8, batch_size=batch_size, num_class=8, image_size=(), use_parallel=use_parallel)
        ds.set_image_data_type(np.int32)
        net = self
        net.set_train()
        loss = nn.SoftmaxCrossEntropyWithLogits()
        optimizer = nn.Adam(net.trainable_params())
        optimizer.target = "CPU"
        model = Model(net, loss, optimizer)
        for _ in range(epoch):
            model.train(1, ds, dataset_sink_mode=False)
        output = net(indices)
        return output


def test_allreduce_sparsegatherv2_adam_auto_parallel():
    """Sparse and dense gather variants, trained identically on 8 Ascend devices, must agree.

    The comparison is done under AUTO_PARALLEL with ``enable_sparse=True``, and
    the outputs must match within rtol=atol=1e-3.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
    init(backend_name='hccl')
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8, gradients_mean=True)
    indices = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]).astype(np.int32))
    epoch = 3
    batch_size = 1
    context.set_context(enable_sparse=True)
    net = NetWithSparseGatherV2(sparse=True)
    output_sparse = net.train_mindspore_impl(indices, epoch, batch_size)
    net = NetWithSparseGatherV2(sparse=False)
    output = net.train_mindspore_impl(indices, epoch, batch_size)
    assert np.allclose(output.asnumpy(), output_sparse.asnumpy(), rtol=0.001, atol=0.001)