# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
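    """Mock dataset that yields the same (predict, label) pair for a fixed number of steps."""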
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class DenseNet1(nn.Cell):
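    """Four-layer fully connected backbone; every layer maps 128 features to 128."""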
    def __init__(self, has_bias=True, activation='relu'):
        super(DenseNet1, self).__init__()
        self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(q)
        v = self.fc3(k)
        s = self.fc4(v)
        return s


class DenseNet2(nn.Cell):
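    """Eight-layer fully connected backbone; every layer maps 128 features to 128."""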
    def __init__(self, has_bias=True, activation='relu'):
        super(DenseNet2, self).__init__()
        self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc5 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc6 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc7 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
        self.fc8 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(q)
        v = self.fc3(k)
        s = self.fc4(v)
        t = self.fc5(s)
        u = self.fc6(t)
        w = self.fc7(u)
        z = self.fc8(w)
        return z


class SimpleDMLNet(nn.Cell):
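    """Feeds the same input through both backbones and returns the sum of their outputs."""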
    def __init__(self, net1, net2):
        super(SimpleDMLNet, self).__init__()
        self.backbone1 = net1
        self.backbone2 = net2

    def construct(self, x):
        x1 = self.backbone1(x)
        x2 = self.backbone2(x)
        return x1 + x2


def train_common(net):
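    """Train the net for two epochs on mock data with device_num=4 and return the allreduce fusion dict."""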
    batch_size = 32
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    device_num = 4
    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
    context.set_context(mode=context.GRAPH_MODE)

    predict = Tensor(np.ones([batch_size, 128]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)

    print(allreduce_fusion_dict)
    return allreduce_fusion_dict


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion_parameters():
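    """Each costmodel allreduce fusion parameter can be set, read back, and reset to its default value."""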
    cost_model_context.reset_cost_model_context()
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
    algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
    assert algorithm == 2
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
    assert algorithm == 1
    cost_model_context.reset_cost_model_context()
    algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
    assert algorithm == 0

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times')
    assert fusion_times == 2

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2)
    tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
    assert tail_percent == 0.2
    cost_model_context.reset_cost_model_context()
    tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
    assert tail_percent == 0.1

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2)
    tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
    assert tail_time == 0.2
    cost_model_context.reset_cost_model_context()
    tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
    assert tail_time == 0.1

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2)
    allreduce_inherent_time = cost_model_context.get_cost_model_context(
        'costmodel_allreduce_fusion_allreduce_inherent_time')
    assert allreduce_inherent_time == 0.2
    cost_model_context.reset_cost_model_context()
    allreduce_inherent_time = cost_model_context.get_cost_model_context(
        'costmodel_allreduce_fusion_allreduce_inherent_time')
    assert allreduce_inherent_time == 0.1

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2)
    allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
    assert allreduce_bandwidth == 0.2
    cost_model_context.reset_cost_model_context()
    allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
    assert allreduce_bandwidth == 0.1

    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2)
    computation_time_parameter = cost_model_context.get_cost_model_context(
        'costmodel_allreduce_fusion_computation_time_parameter')
    assert computation_time_parameter == 0.2
    cost_model_context.reset_cost_model_context()
    computation_time_parameter = cost_model_context.get_cost_model_context(
        'costmodel_allreduce_fusion_computation_time_parameter')
    assert computation_time_parameter == 0.1


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion1():
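    """With algorithm 1, fusion_times=2 and tail_percent=0.5, parameters split into fusion groups 1 and 2."""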
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
    allreduce_fusion_dict = train_common(net)
    expect_dict = {'backbone2.fc8.weight': 2,
                   'backbone2.fc7.weight': 2,
                   'backbone2.fc6.weight': 2,
                   'backbone1.fc4.weight': 2,
                   'backbone1.fc3.weight': 2,
                   'backbone1.fc2.weight': 2,
                   'backbone2.fc5.weight': 1,
                   'backbone2.fc4.weight': 1,
                   'backbone2.fc3.weight': 1,
                   'backbone2.fc2.weight': 1,
                   'backbone2.fc1.weight': 1,
                   'backbone1.fc1.weight': 1}
    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()


# After reset_cost_model_context is called, costmodel_allreduce_fusion_times falls back to its
# default of 0, so step_allreduce_fusion is bypassed and the expected fusion dict is empty.
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion2():
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
    cost_model_context.reset_cost_model_context()
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
    allreduce_fusion_dict = train_common(net)
    expect_dict = {}
    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion3():
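    """With algorithm 1 and fusion_times=3, weights and biases are expected to spread across groups 1-3."""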
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu'))
    allreduce_fusion_dict = train_common(net)
    expect_dict = {'backbone2.fc8.weight': 3,
                   'backbone2.fc7.weight': 3,
                   'backbone2.fc6.weight': 2,
                   'backbone2.fc5.weight': 2,
                   'backbone2.fc4.weight': 2,
                   'backbone2.fc3.weight': 1,
                   'backbone2.fc2.weight': 1,
                   'backbone2.fc1.weight': 1,
                   'backbone1.fc4.bias': 3,
                   'backbone1.fc4.weight': 3,
                   'backbone1.fc3.bias': 3,
                   'backbone1.fc3.weight': 2,
                   'backbone1.fc2.bias': 2,
                   'backbone1.fc2.weight': 2,
                   'backbone1.fc1.bias': 2,
                   'backbone1.fc1.weight': 2}
    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion4():
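    """Two eight-layer backbones with algorithm 1 and fusion_times=2; parameters split into groups 1 and 2."""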
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
    allreduce_fusion_dict = train_common(net)
    expect_dict = {'backbone2.fc8.weight': 2,
                   'backbone2.fc7.weight': 2,
                   'backbone2.fc6.weight': 2,
                   'backbone1.fc8.weight': 2,
                   'backbone1.fc7.weight': 2,
                   'backbone1.fc6.weight': 2,
                   'backbone2.fc5.weight': 1,
                   'backbone2.fc4.weight': 1,
                   'backbone2.fc3.weight': 1,
                   'backbone2.fc2.weight': 1,
                   'backbone2.fc1.weight': 1,
                   'backbone1.fc5.weight': 1,
                   'backbone1.fc4.weight': 1,
                   'backbone1.fc3.weight': 1,
                   'backbone1.fc2.weight': 1,
                   'backbone1.fc1.weight': 1}

    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()


@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion5():
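    """Algorithm 2 groups parameters according to the time and bandwidth cost-model settings below."""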
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
    allreduce_fusion_dict = train_common(net)

    expect_dict = {'backbone2.fc8.weight': 3,
                   'backbone2.fc7.weight': 3,
                   'backbone2.fc6.weight': 3,
                   'backbone2.fc5.weight': 3,
                   'backbone2.fc4.weight': 2,
                   'backbone2.fc3.weight': 2,
                   'backbone2.fc2.weight': 1,
                   'backbone2.fc1.weight': 1,
                   'backbone1.fc8.weight': 3,
                   'backbone1.fc7.weight': 3,
                   'backbone1.fc6.weight': 3,
                   'backbone1.fc5.weight': 3,
                   'backbone1.fc4.weight': 2,
                   'backbone1.fc3.weight': 2,
                   'backbone1.fc2.weight': 1,
                   'backbone1.fc1.weight': 1}

    assert allreduce_fusion_dict == expect_dict
    cost_model_context.reset_cost_model_context()