# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15""" test Communicate """
import numpy as np

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.communication._comm_helper import Backend
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
from mindspore.nn import Dense
from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations.array_ops import Gather
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast, AllSwap


# pylint: disable=W0212
# W0212: protected-access

tag = 0

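# init() normally validates the distributed launch environment variables;
# CHECK_ENVS is switched off around init("hccl") because these tests only
# compile graphs and are not run under a real multi-device launcher.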
context.set_context(device_target="Ascend")
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True


class AllReduceNet(nn.Cell):
    """AllReduceNet definition"""

    def __init__(self, input_channel, out_channel, op):
        super(AllReduceNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reduce = AllReduce(op)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.reduce(x)
        return self.relu(x)


class BroadCastNet(nn.Cell):
    """BroadCastNet definition"""

    def __init__(self, input_channel, out_channel):
        super(BroadCastNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.broadcast = Broadcast(0)

    def construct(self, x):
        x, = self.broadcast((x,))
        x = self.dense(x)
        return x


class AllGatherNet(nn.Cell):
    """AllGatherNet definition"""

    def __init__(self, input_channel, out_channel):
        super(AllGatherNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
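        # Pick the world communication group that matches the active backend
        # (HCCL on Ascend, NCCL on GPU); otherwise fall back to the default group.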
        if GlobalComm.BACKEND is Backend.HCCL:
            self.allgather = AllGather(group=HCCL_WORLD_COMM_GROUP)
        elif GlobalComm.BACKEND is Backend.NCCL:
            self.allgather = AllGather(group=NCCL_WORLD_COMM_GROUP)
        else:
            self.allgather = AllGather()

        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.allgather(x)
        return self.relu(x)


class ReduceScatterNet(nn.Cell):
    """ReduceScatterNet definition"""

    def __init__(self, input_channel, out_channel, op):
        super(ReduceScatterNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reducescatter = ReduceScatter(op)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.reducescatter(x)
        return self.relu(x)


class AlltoAllNet(nn.Cell):
    """AlltoAllNet definition"""

    def __init__(self, input_channel, out_channel):
        super(AlltoAllNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
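        # AlltoAll(split_count=1, split_dim=0, concat_dim=1): split the input
        # along axis 0, exchange the chunks across ranks, and concatenate what
        # is received along axis 1.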
        self.alltoall = AlltoAll(1, 0, 1)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.alltoall(x)
        return self.relu(x)


class AllSwapNet(nn.Cell):
    """AllSwapNet definition"""

    def __init__(self, batch_size, input_channel, out_channel):
        super(AllSwapNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.allswap = AllSwap()
        self.relu = ReLU()
        # Integer division keeps the per-rank element counts integral for the
        # int64 send/recv size tensors below.
        part_slice = batch_size // 2
        self.send_size = Tensor([0, part_slice * out_channel, part_slice * out_channel], mindspore.int64)
        self.recv_size = Tensor([part_slice * out_channel, part_slice * out_channel, 0], mindspore.int64)
        self.gatherv2 = Gather()
        self.input = Tensor(np.ones([1]), mindspore.int32)

    def construct(self, x):
        x = self.allswap(x, self.send_size, self.recv_size)
        x = self.relu(x)
        x = self.gatherv2(x, self.input, 0)
        return x


def run_allreduce(op):
    """run_allreduce"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AllReduceNet(2, 1, op)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_allreduce():
    """test_allreduce"""
    context.set_context(mode=context.GRAPH_MODE)
    run_allreduce(ReduceOp.SUM)
    run_allreduce(ReduceOp.MAX)
    run_allreduce(ReduceOp.MIN)
    run_allreduce(ReduceOp.PROD)


def test_allgather():
    """test_allgather"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AllGatherNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_allswap():
    """test_allswap"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
    label_tensor = Tensor(np.ones((1, 20)), dtype=mindspore.float32)
    network = AllSwapNet(100, 20, 20)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def run_reducescatter(op):
    """run_reducescatter"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = ReduceScatterNet(2, 1, op)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_reducescatter():
    """test_reducescatter"""
    context.set_context(mode=context.GRAPH_MODE)
    run_reducescatter(ReduceOp.SUM)


def test_broadcast():
    """test_broadcast"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor_1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = BroadCastNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor_1, label_tensor)


def test_alltoall():
    """test_alltoall"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AlltoAllNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


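# A minimal sketch for invoking these compile-only checks directly; in practice
# they are collected by pytest, and the init("hccl") call above still assumes
# an Ascend environment is available.
if __name__ == "__main__":
    test_allreduce()
    test_allgather()
    test_allswap()
    test_reducescatter()
    test_broadcast()
    test_alltoall()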