• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore.context as context
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore.common.initializer import initializer
21from mindspore.common.parameter import Parameter
22from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
23from mindspore.ops import operations as P
24from mindspore.ops.operations import _inner_ops as inner
25
# Every case in this module runs in graph mode on the GPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

# Initialise the communication service, then read this process's rank and
# the number of processes in the group.
init()
rank = get_rank()
size = get_group_size()

# Per-rank inputs: every element equals 0.01 * (rank + 1), so each test can
# recompute the expected all-reduce sum locally from ``size`` alone.
x = np.ones((3, 1, 3, 3), dtype=np.float32) * 0.01 * (rank + 1)
y = np.ones((3, 4, 6, 3), dtype=np.float32) * 0.01 * (rank + 1)
33
class Net(nn.Cell):
    """Cell that applies three independent "sum" AllReduce ops, one per parameter."""

    def __init__(self):
        super().__init__()
        # Three parameters, each initialised from the module-level array ``x``.
        self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1')
        self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2')
        self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3')

        # Reduction type passed to each AllReduce operator below.
        self.op0 = "sum"
        self.op1 = "sum"
        self.op2 = "sum"

        # One AllReduce operator per parameter, all on the NCCL world group.
        self.all_reduce1 = P.AllReduce(self.op0, group=NCCL_WORLD_COMM_GROUP)
        self.all_reduce2 = P.AllReduce(self.op1, group=NCCL_WORLD_COMM_GROUP)
        self.all_reduce3 = P.AllReduce(self.op2, group=NCCL_WORLD_COMM_GROUP)

    def construct(self):
        """Return the tuple of the three reduced parameters (x1, x2, x3 order)."""
        reduced1 = self.all_reduce1(self.x1)
        reduced2 = self.all_reduce2(self.x2)
        reduced3 = self.all_reduce3(self.x3)
        return (reduced1, reduced2, reduced3)
53
54
def test_AllReduce():
    """Check that each of the three outputs of ``Net`` equals the sum over ranks.

    Rank ``i`` contributes an array filled with ``0.01 * (i + 1)``, so the
    expected result of every "sum" AllReduce is the element-wise sum of those
    arrays for ``i`` in ``range(size)``. All three outputs reduce the same
    input and therefore share one expected value.
    """
    all_reduce = Net()
    output = all_reduce()

    # Rebuild the expected sum locally from the group size.
    expect = np.zeros([3, 1, 3, 3], dtype=np.float32)
    for i in range(size):
        expect += np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1)
    tolerance = np.full(expect.shape, 1.0e-5)

    for index in range(3):
        # Check the shape first so a mismatch is reported as such rather than
        # surfacing as a broadcasting error or a misleading value failure.
        assert output[index].shape == expect.shape
        # Use the absolute difference: a signed difference would let outputs
        # that are arbitrarily too small slip under the tolerance.
        diff = np.abs(output[index].asnumpy() - expect)
        assert np.all(diff < tolerance)
79
80
class Net2(nn.Cell):
    """Cell that chains three "sum" AllReduce ops, each fed by the previous one."""

    def __init__(self):
        super().__init__()
        # Single parameter initialised from the module-level array ``x``.
        self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1')

        # Reduction type passed to each AllReduce operator below.
        self.op0 = "sum"
        self.op1 = "sum"
        self.op2 = "sum"

        # Three AllReduce operators, all on the NCCL world group.
        self.all_reduce1 = P.AllReduce(self.op0, group=NCCL_WORLD_COMM_GROUP)
        self.all_reduce2 = P.AllReduce(self.op1, group=NCCL_WORLD_COMM_GROUP)
        self.all_reduce3 = P.AllReduce(self.op2, group=NCCL_WORLD_COMM_GROUP)

    def construct(self):
        """Return the outputs of all three stages of the reduction chain."""
        first = self.all_reduce1(self.x1)
        second = self.all_reduce2(first)
        third = self.all_reduce3(second)
        return (first, second, third)
99
100
def test_AllReduce2():
    """Check the three chained AllReduce outputs of ``Net2``.

    The first output is compared with the sum over ranks of arrays filled
    with ``0.01 * (i + 1)``; each following output is compared with the
    previous expectation multiplied by ``size``.
    """
    all_reduce = Net2()
    output = all_reduce()

    # Expected value of the first stage, accumulated rank by rank.
    expected = np.zeros([3, 1, 3, 3], dtype=np.float32)
    for rank_id in range(size):
        expected += np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank_id + 1)

    for stage in range(3):
        if stage > 0:
            # Each later stage is checked against the previous expectation
            # scaled by the group size.
            expected = expected * size
        deviation = np.abs(output[stage].asnumpy() - expected)
        tolerance = np.full(expected.shape, 1.0e-5)
        assert np.all(deviation < tolerance)
        assert output[stage].shape == expected.shape
125
126
class DynamicAllReduceNet(nn.Cell):
    """Cell that passes its input through ``GpuConvertToDynamicShape`` and then
    through a "sum" AllReduce on the NCCL world group."""

    def __init__(self):
        super().__init__()
        self.op = "sum"
        # NOTE(review): presumably marks the tensor as dynamically shaped so the
        # AllReduce kernel is exercised on that path -- confirm against the op.
        self.d = inner.GpuConvertToDynamicShape()
        self.all_reduce = P.AllReduce(self.op, group=NCCL_WORLD_COMM_GROUP)

    def construct(self, input_x):
        """Return the AllReduce result of the converted ``input_x``."""
        converted = self.d(input_x)
        reduced = self.all_reduce(converted)
        return reduced
138
139
def test_all_reduce_dynamic():
    """Run one ``DynamicAllReduceNet`` on two differently shaped inputs.

    For each input the output is compared with the sum over ranks of arrays
    filled with ``0.01 * (i + 1)`` of the matching shape.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input1 = Tensor(x)
    input2 = Tensor(y)
    net = DynamicAllReduceNet()

    # The same network instance is fed both inputs, one after the other.
    cases = (
        (input1, [3, 1, 3, 3]),
        (input2, [3, 4, 6, 3]),
    )
    for net_input, shape in cases:
        result = net(net_input)

        # Expected all-reduce sum, accumulated rank by rank.
        expected = np.zeros(shape, dtype=np.float32)
        for rank_id in range(size):
            expected += np.ones(shape).astype(np.float32) * 0.01 * (rank_id + 1)

        deviation = np.abs(result.asnumpy() - expected)
        tolerance = np.full(expected.shape, 1.0e-5)
        assert np.all(deviation < tolerance)
        assert result.shape == expected.shape
165