# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" test Communicate """
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.communication._comm_helper import Backend
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
from mindspore.nn import Dense
from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
from mindspore.ops.operations.array_ops import Gather
import mindspore


# pylint: disable=W0212
# W0212: protected-access

tag = 0

# These are compile-only unit tests: disable the distributed-environment
# checks so init("hccl") succeeds without a launched cluster, then restore
# the flag for anything that runs afterwards.
context.set_context(device_target="Ascend")
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True


class AllReduceNet(nn.Cell):
    """AllReduceNet definition"""

    def __init__(self, input_channel, out_channel, op):
        super(AllReduceNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reduce = AllReduce(op)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.reduce(x)
        return self.relu(x)


class BroadCastNet(nn.Cell):
    """BroadCastNet definition"""

    def __init__(self, input_channel, out_channel):
        super(BroadCastNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.broadcast = Broadcast(0)

    def construct(self, x):
        x, = self.broadcast((x,))
        x = self.dense(x)
        return x


class AllGatherNet(nn.Cell):
    """AllGatherNet definition"""

    def __init__(self, input_channel, out_channel):
        super(AllGatherNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        # Pick the world group that matches the initialized backend.
        if GlobalComm.BACKEND is Backend.HCCL:
            self.allgather = AllGather(group=HCCL_WORLD_COMM_GROUP)
        elif GlobalComm.BACKEND is Backend.NCCL:
            self.allgather = AllGather(group=NCCL_WORLD_COMM_GROUP)
        else:
            self.allgather = AllGather()

        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.allgather(x)
        return self.relu(x)


class ReduceScatterNet(nn.Cell):
    """ReduceScatterNet definition"""

    def __init__(self, input_channel, out_channel, op):
        super(ReduceScatterNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reducescatter = ReduceScatter(op)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.reducescatter(x)
        return self.relu(x)


class AlltoAllNet(nn.Cell):
    """AlltoAllNet definition"""

    def __init__(self, input_channel, out_channel):
        super(AlltoAllNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.alltoall = AlltoAll(1, 0, 1)
        self.relu = ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.alltoall(x)
        return self.relu(x)


class AllSwapNet(nn.Cell):
    """AllSwapNet definition"""

    def __init__(self, batch_size, input_channel, out_channel):
        super(AllSwapNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.allswap = AllSwap()
        self.relu = ReLU()
        # send_size/recv_size are element counts, so use integer division.
        part_slice = batch_size // 2
        self.send_size = Tensor([0, part_slice * out_channel, part_slice * out_channel], mindspore.int64)
        self.recv_size = Tensor([part_slice * out_channel, part_slice * out_channel, 0], mindspore.int64)
        self.gatherv2 = Gather()
        self.input = Tensor(np.ones([1]), mindspore.int32)

    def construct(self, x):
        x = self.allswap(x, self.send_size, self.recv_size)
        x = self.relu(x)
        x = self.gatherv2(x, self.input, 0)
        return x
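

# A hedged shape note for the collectives above (nothing here executes; the
# tests below only compile graphs): AllReduce keeps the input shape and
# reduces element-wise across ranks; AllGather stacks per-rank inputs along
# axis 0; ReduceScatter reduces and then splits the result across ranks along
# axis 0; AlltoAll(split_count, split_dim, concat_dim) splits the input into
# `split_count` blocks along `split_dim`, exchanges one block per rank, and
# concatenates along `concat_dim` (AlltoAllNet uses split_count=1, leaving
# the shape unchanged); AllSwap exchanges flattened element ranges whose
# per-peer lengths come from the int64 send_size/recv_size tensors built in
# AllSwapNet.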


def run_allreduce(op):
    """run_allreduce"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AllReduceNet(2, 1, op)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_allreduce():
    """test_allreduce"""
    context.set_context(mode=context.GRAPH_MODE)
    run_allreduce(ReduceOp.SUM)
    run_allreduce(ReduceOp.MAX)
    run_allreduce(ReduceOp.MIN)
    run_allreduce(ReduceOp.PROD)


def test_allgather():
    """test_allgather"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AllGatherNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_allswap():
    """test_allswap"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
    label_tensor = Tensor(np.ones((1, 20)), dtype=mindspore.float32)
    network = AllSwapNet(100, 20, 20)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)
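

# The tests above and run_reducescatter below all repeat the same
# wrap-with-loss/Momentum/compile boilerplate. A minimal sketch of a shared
# helper that could replace it; `_compile_train_graph` is hypothetical and
# not part of the original file.
def _compile_train_graph(network, input_tensor, label_tensor):
    """Wrap `network` with a loss and a Momentum optimizer, then compile."""
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)

# Usage sketch: _compile_train_graph(AllReduceNet(2, 1, ReduceOp.SUM),
#                                    input_tensor, label_tensor)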


def run_reducescatter(op):
    """run_reducescatter"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = ReduceScatterNet(2, 1, op)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)


def test_reducescatter():
    """test_reducescatter"""
    context.set_context(mode=context.GRAPH_MODE)
    run_reducescatter(ReduceOp.SUM)


def test_broadcast():
    """test_broadcast"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor_1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = BroadCastNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor_1, label_tensor)


def test_alltoall():
    """test_alltoall"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
    network = AlltoAllNet(2, 1)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    network = WithLossCell(network, loss_fn)
    network = TrainOneStepCell(network, optimizer)
    _cell_graph_executor.compile(network, input_tensor, label_tensor)
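

# A hedged alternative (not part of the original file): the four AllReduce
# cases in test_allreduce could be driven through a pytest parametrization
# instead of sequential calls. This assumes pytest is the test runner, which
# this file does not state.
import pytest


@pytest.mark.parametrize("op", [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD])
def test_allreduce_parametrized(op):
    """Hypothetical parametrized twin of test_allreduce."""
    context.set_context(mode=context.GRAPH_MODE)
    run_allreduce(op)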