# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same (predict, label) pair `length` times.

    Args:
        predict: Object returned as the first element of every item.
        label: Object returned as the second element of every item.
        length (int): Number of items produced before `StopIteration`; also
            passed to `MindData` as `size`.
    """

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once `length` items have been handed out since the last reset.
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so the dataset can be consumed again."""
        self.index = 0


class DenseNet1(nn.Cell):
    """Chain of four 128 -> 128 `nn.Dense` layers (fc1..fc4).

    Args:
        has_bias (bool): Forwarded to every `nn.Dense` layer.
        activation: Forwarded to every `nn.Dense` layer.
    """

    def __init__(self, has_bias=True, activation='relu'):
        super(DenseNet1, self).__init__()
        self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)

    def construct(self, x):
        """Apply fc1, fc2, fc3 and fc4 in sequence and return the result."""
        q = self.fc1(x)
        k = self.fc2(q)
        v = self.fc3(k)
        s = self.fc4(v)
        return s


class DenseNet2(nn.Cell):
    """Chain of eight 128 -> 128 `nn.Dense` layers (fc1..fc8).

    Args:
        has_bias (bool): Forwarded to every `nn.Dense` layer.
        activation: Forwarded to every `nn.Dense` layer.
    """

    def __init__(self, has_bias=True, activation='relu'):
        super(DenseNet2, self).__init__()
        self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc5 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc6 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc7 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc8 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)

    def construct(self, x):
        """Apply fc1 through fc8 in sequence and return the result."""
        q = self.fc1(x)
        k = self.fc2(q)
        v = self.fc3(k)
        s = self.fc4(v)
        t = self.fc5(s)
        u = self.fc6(t)
        w = self.fc7(u)
        z = self.fc8(w)
        return z


class SimpleDMLNet(nn.Cell):
    """Feed the same input to two backbones and add their outputs.

    Args:
        net1: Cell stored as `backbone1`.
        net2: Cell stored as `backbone2`.
    """

    def __init__(self, net1, net2):
        super(SimpleDMLNet, self).__init__()
        self.backbone1 = net1
        self.backbone2 = net2

    def construct(self, x):
        """Return `backbone1(x) + backbone2(x)`."""
        x1 = self.backbone1(x)
        x2 = self.backbone2(x)
        return x1 + x2


def train_common(net):
    """Train `net` briefly and return its all-reduce fusion information.

    Enables all-reduce fusion on the auto-parallel context, configures
    4 devices in graph mode, trains for 2 epochs on a two-item mock dataset
    of all-ones tensors, then queries the graph executor for the fusion
    result of the train network.

    Args:
        net: Network to wrap in a `Model` together with the loss and optimizer.

    Returns:
        The value of `_cell_graph_executor._get_allreduce_fusion` for the
        model's train network; the tests in this file compare it with dicts
        keyed by parameter name.
    """
    batch_size = 32
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    device_num = 4
    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
    context.set_context(mode=context.GRAPH_MODE)

    predict = Tensor(np.ones([batch_size, 128]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)

    print(allreduce_fusion_dict)
    return allreduce_fusion_dict


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion_parameters():
    """Check set/get/reset round-trips of the all-reduce fusion cost-model keys."""
    get_param = cost_model_context.get_cost_model_context
    # Keys that are set to 0.2 and expected to read back 0.1 after a reset.
    float_params = (
        'costmodel_allreduce_fusion_tail_percent',
        'costmodel_allreduce_fusion_tail_time',
        'costmodel_allreduce_fusion_allreduce_inherent_time',
        'costmodel_allreduce_fusion_allreduce_bandwidth',
        'costmodel_allreduce_fusion_computation_time_parameter',
    )
    try:
        cost_model_context.reset_cost_model_context()
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
        algorithm = get_param('costmodel_allreduce_fusion_algorithm')
        assert algorithm == 2
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        algorithm = get_param('costmodel_allreduce_fusion_algorithm')
        assert algorithm == 1
        cost_model_context.reset_cost_model_context()
        algorithm = get_param('costmodel_allreduce_fusion_algorithm')
        assert algorithm == 0

        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        fusion_times = get_param('costmodel_allreduce_fusion_times')
        assert fusion_times == 2

        for name in float_params:
            cost_model_context.set_cost_model_context(**{name: 0.2})
            assert get_param(name) == 0.2, name
            cost_model_context.reset_cost_model_context()
            assert get_param(name) == 0.1, name
    finally:
        # Never leak non-default cost-model settings into other tests,
        # even when an assertion above fails.
        cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion1():
    """Algorithm 1, fusion times 2, tail percent 0.5 on DenseNet1 + DenseNet2 without bias."""
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
        net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
        allreduce_fusion_dict = train_common(net)
        expect_dict = {'backbone2.fc8.weight': 2,
                       'backbone2.fc7.weight': 2,
                       'backbone2.fc6.weight': 2,
                       'backbone1.fc4.weight': 2,
                       'backbone1.fc3.weight': 2,
                       'backbone1.fc2.weight': 2,
                       'backbone2.fc5.weight': 1,
                       'backbone2.fc4.weight': 1,
                       'backbone2.fc3.weight': 1,
                       'backbone2.fc2.weight': 1,
                       'backbone2.fc1.weight': 1,
                       'backbone1.fc1.weight': 1}
        assert allreduce_fusion_dict == expect_dict
    finally:
        # Restore defaults even if training or the assertion fails.
        cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion2():
    """Expect an empty fusion dict after the cost-model context is reset.

    reset_cost_model_context is called, the default value of
    costmodel_allreduce_fusion_times is 0, step_allreduce_fusion is bypassed.
    """
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
        cost_model_context.reset_cost_model_context()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
        net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
        allreduce_fusion_dict = train_common(net)
        expect_dict = {}
        assert allreduce_fusion_dict == expect_dict
    finally:
        # Restore defaults even if training or the assertion fails.
        cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion3():
    """Algorithm 1, fusion times 3, tail percent 0.3333333 with biased DenseNet1 and relu."""
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
        net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu'))
        allreduce_fusion_dict = train_common(net)
        expect_dict = {'backbone2.fc8.weight': 3,
                       'backbone2.fc7.weight': 3,
                       'backbone2.fc6.weight': 2,
                       'backbone2.fc5.weight': 2,
                       'backbone2.fc4.weight': 2,
                       'backbone2.fc3.weight': 1,
                       'backbone2.fc2.weight': 1,
                       'backbone2.fc1.weight': 1,
                       'backbone1.fc4.bias': 3,
                       'backbone1.fc4.weight': 3,
                       'backbone1.fc3.bias': 3,
                       'backbone1.fc3.weight': 2,
                       'backbone1.fc2.bias': 2,
                       'backbone1.fc2.weight': 2,
                       'backbone1.fc1.bias': 2,
                       'backbone1.fc1.weight': 2}
        assert allreduce_fusion_dict == expect_dict
    finally:
        # Restore defaults even if training or the assertion fails.
        cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion4():
    """Algorithm 1, fusion times 2, tail percent 0.5 on two DenseNet2 backbones without bias."""
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
        net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
        allreduce_fusion_dict = train_common(net)
        expect_dict = {'backbone2.fc8.weight': 2,
                       'backbone2.fc7.weight': 2,
                       'backbone2.fc6.weight': 2,
                       'backbone1.fc8.weight': 2,
                       'backbone1.fc7.weight': 2,
                       'backbone1.fc6.weight': 2,
                       'backbone2.fc5.weight': 1,
                       'backbone2.fc4.weight': 1,
                       'backbone2.fc3.weight': 1,
                       'backbone2.fc2.weight': 1,
                       'backbone2.fc1.weight': 1,
                       'backbone1.fc5.weight': 1,
                       'backbone1.fc4.weight': 1,
                       'backbone1.fc3.weight': 1,
                       'backbone1.fc2.weight': 1,
                       'backbone1.fc1.weight': 1}

        assert allreduce_fusion_dict == expect_dict
    finally:
        # Restore defaults even if training or the assertion fails.
        cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion5():
    """Algorithm 2 with explicit tail time, inherent time, bandwidth and computation-time settings."""
    try:
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
        net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
        allreduce_fusion_dict = train_common(net)

        expect_dict = {'backbone2.fc8.weight': 3,
                       'backbone2.fc7.weight': 3,
                       'backbone2.fc6.weight': 3,
                       'backbone2.fc5.weight': 3,
                       'backbone2.fc4.weight': 2,
                       'backbone2.fc3.weight': 2,
                       'backbone2.fc2.weight': 1,
                       'backbone2.fc1.weight': 1,
                       'backbone1.fc8.weight': 3,
                       'backbone1.fc7.weight': 3,
                       'backbone1.fc6.weight': 3,
                       'backbone1.fc5.weight': 3,
                       'backbone1.fc4.weight': 2,
                       'backbone1.fc3.weight': 2,
                       'backbone1.fc2.weight': 1,
                       'backbone1.fc1.weight': 1}

        assert allreduce_fusion_dict == expect_dict
    finally:
        # Restore defaults even if training or the assertion fails.
        cost_model_context.reset_cost_model_context()