# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import NeighborExchange

# Shared fixtures: _x1 @ _x2 yields a [32, 32] tensor, which is then multiplied
# element-wise by the [32, 32] weight _w1 in the MatMul-based networks below.
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([32, 16]), dtype=ms.float32)
_x2 = Tensor(np.ones([16, 32]), dtype=ms.float32)


def compile_net(net):
    """
    Compile `net` as a training graph.

    Switches the context to graph mode, wraps `net` in a TrainOneStepCell
    driven by a Momentum optimizer, and compiles it with the module-level
    inputs `_x1` and `_x2`. Nothing is returned; errors surface as exceptions
    raised by the calls made here.
    """
    context.set_context(mode=context.GRAPH_MODE)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_train()
    _cell_graph_executor.compile(train_net, _x1, _x2)


def test_NeighborExchange_two_inputs_success():
    """
    Feature: NeighborExchange
    Description: two inputs and two outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class MatMulNet(nn.Cell):
        def __init__(self, weight1):
            super(MatMulNet, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 32], [32, 16]), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out, x1))
            return out[0]

    net = MatMulNet(_w1)
    compile_net(net)


def test_NeighborExchange_single_input_success():
    """
    Feature: NeighborExchange
    Description: one input and two outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class MatMulNet2(nn.Cell):
        def __init__(self, weight1):
            super(MatMulNet2, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 32],), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = MatMulNet2(_w1)
    compile_net(net)


def test_NeighborExchange_empty_send_success():
    """
    Feature: NeighborExchange
    Description: empty inputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[1], recv_shapes=([1],),
                                              send_shapes=(), recv_type=ms.float32)

        def construct(self, x1):
            # The operator is invoked without inputs; its result is deliberately unused.
            self.alltoallv()
            return x1

    net = Net()
    _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[], recv_shapes=(),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            # Nothing is received, so the result is deliberately unused.
            self.alltoallv((x1,))
            return x1

    net = Net()
    _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_empty_send_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: empty inputs and empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(),
                                              send_shapes=(), recv_type=ms.float32)

        def construct(self, x1):
            self.alltoallv()
            return x1

    net = Net()
    _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids is set with 2 ranks, but recv_shapes is set with only 1 shape
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32],),
                                              send_shapes=([32, 32],), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = Net(_w1)
    with pytest.raises(ValueError):
        compile_net(net)


def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 32]),
                                              send_shapes=([32, 32],), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = Net(_w1)
    with pytest.raises(ValueError):
        compile_net(net)


def test_NeighborExchange_send_shape_num_diff_with_input_num_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input
    Expectation: throw Exception
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 32]),
                                              send_shapes=([32, 32], [32, 32]), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = Net(_w1)
    with pytest.raises(Exception):
        compile_net(net)


def test_NeighborExchange_send_shape_diff_with_input_shape_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes is set as [16, 16], but input is [32, 32]
    Expectation: throw Exception
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([16, 16],), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = Net(_w1)
    with pytest.raises(Exception):
        compile_net(net)


def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids should be list, but a parenthesized int `(0)` is given
                 (without a trailing comma this is an int, not a tuple)
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # `(0)` is intentional: keep the redundant parentheses, this is the invalid value under test.
            self.alltoallv = NeighborExchange(send_rank_ids=(0), recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=(0,), recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids elements should be int, but a float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1.0], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids should be list, but a tuple wrapping a list is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=([1, 2],),
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_2_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=(1, 2,),
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids elements should be int, but a float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2.0],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes should be tuple(list), but a list is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # `([32, 16])` is intentional: without a trailing comma it is a plain list, not a tuple.
            self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16]), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_send_shape_list_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes should be tuple(list), but a list(list) is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=[[32, 16]], recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_check_recv_type_numpy_failed():
    """
    Feature: NeighborExchange
    Description: recv_type should be mindspore type, but a numpy type is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=np.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_NeighborExchange_attr_invalid_grpup_failed():
    """
    Feature: NeighborExchange
    Description: group should be str, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32,
                                              group=("str",))

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)