1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common_test.h" 17 #include "frontend/parallel/step_parallel.h" 18 #include "frontend/parallel/graph_util/generate_graph.h" 19 #include "common/py_func_graph_fetcher.h" 20 #include "debug/draw.h" 21 #include "frontend/operator/ops.h" 22 #include "pipeline/jit/static_analysis/static_analysis.h" 23 #include "utils/convert_utils_py.h" 24 25 namespace mindspore { 26 namespace parallel { 27 extern size_t TOTAL_OPS; 28 class TestStepParallel : public UT::Common { 29 public: 30 TestStepParallel() {} 31 void SetUp(); 32 void TearDown() {} 33 }; 34 35 void Init_Device_Manager() { 36 RankList dev_list; 37 38 for (int32_t i = 0; i < 20; i++) { 39 dev_list.push_back(i); 40 } 41 42 RankList stage_map; 43 stage_map.push_back(16); 44 stage_map.push_back(4); 45 46 int32_t local_dev = 0; 47 48 // create a new g_device_manager 49 g_device_manager = std::make_shared<DeviceManager>(); 50 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); 51 } 52 53 void TestStepParallel::SetUp() { 54 UT::InitPythonPath(); 55 Init_Device_Manager(); 56 } 57 58 CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) { 59 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 60 ParameterPtr param1 = func_graph->add_parameter(); 61 ParameterPtr param2 = func_graph->add_parameter(); 62 param1->set_name("x"); 63 param2->set_name("y"); 64 BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x); 65 BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y); 66 BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out); 67 std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x); 68 std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y); 69 std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out); 70 AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true); 71 AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true); 72 AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true); 73 switch (condition) { 74 case 0: { 75 abstract1->set_shape(shape1); 76 abstract2->set_shape(shape2); 77 abstract3->set_shape(shape3); 78 param1->set_abstract(abstract1); 79 param2->set_abstract(abstract2); 80 break; 81 } 82 case 1: { 83 // Don't set abstract of param1, expecting a exception raised. 84 param2->set_abstract(abstract2); 85 break; 86 } 87 case 2: { 88 abstract1->set_shape(shape1); 89 abstract2->set_shape(shape2); 90 param1->set_abstract(abstract1); 91 param2->set_abstract(abstract2); 92 abstract3 = abstract::FromValue(static_cast<int64_t>(1), false); 93 break; 94 } 95 case 3: { 96 std::vector<BaseShapePtr> shape_o = {std::make_shared<abstract::Shape>(x), std::make_shared<abstract::Shape>(y)}; 97 BaseShapePtr shape4 = std::make_shared<abstract::TupleShape>(shape_o); 98 abstract1->set_shape(shape1); 99 abstract2->set_shape(shape2); 100 abstract3->set_shape(shape4); 101 param1->set_abstract(abstract1); 102 param2->set_abstract(abstract2); 103 break; 104 } 105 default: 106 MS_LOG(INFO) << "Do Nothing!"; 107 } 108 std::vector<AnfNodePtr> inputs; 109 inputs.push_back(NewValueNode(prim::kPrimMatMul)); 110 inputs.push_back(param1); 111 inputs.push_back(param2); 112 CNodePtr node = func_graph->NewCNode(inputs); 113 node->set_abstract(abstract3); 114 return node; 115 } 116 117 FuncGraphManagerPtr Make_Manager(int64_t condition = 0) { 118 std::vector<int64_t> inputs_x = {64, 32}; 119 std::vector<int64_t> inputs_y = {32, 64}; 120 std::vector<int64_t> inputs_z = {64, 128}; 121 std::vector<int64_t> outputs_1 = {64, 64}; 122 std::vector<int64_t> outputs_2 = {64, 128}; 123 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 124 ParameterPtr param1 = func_graph->add_parameter(); 125 ParameterPtr param2 = func_graph->add_parameter(); 126 ParameterPtr param3 = func_graph->add_parameter(); 127 std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x); 128 std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y); 129 std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z); 130 std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1); 131 std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2); 132 AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true); 133 AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true); 134 AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true); 135 AbstractBasePtr abstract_out1 = abstract::FromValue(inputs_out1_dim, true); 136 AbstractBasePtr abstract_out2 = abstract::FromValue(inputs_out2_dim, true); 137 param1->set_abstract(abstract_x); 138 param2->set_abstract(abstract_y); 139 param3->set_abstract(abstract_z); 140 Dimensions v1 = {2, 2}; 141 Dimensions v2 = {2, 4}; 142 std::vector<ValuePtr> elements = {MakeValue(v1), MakeValue(v2)}; 143 ValueTuplePtr var = std::make_shared<ValueTuple>(elements); 144 std::vector<AnfNodePtr> inputs; 145 inputs.push_back(NewValueNode(prim::kPrimMatMul)); 146 inputs.push_back(param1); 147 inputs.push_back(param2); 148 CNodePtr node1 = func_graph->NewCNode(inputs); 149 node1->set_in_forward_flag(true); 150 node1->set_abstract(abstract_out1); 151 PrimitivePtr prim1 = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 152 ValuePtr transpose_a = MakeValue(false); 153 ValuePtr transpose_b = MakeValue(false); 154 prim1->AddAttr("transpose_a", transpose_a); 155 prim1->AddAttr("transpose_b", transpose_b); 156 prim1->AddAttr("instance_name", MakeValue("matmul1")); 157 prim1->AddAttr("strategy", var); 158 inputs.clear(); 159 Dimensions v3 = {2, 2}; 160 Dimensions v4 = {2, 4}; 161 std::vector<ValuePtr> elements2 = {MakeValue(v3), MakeValue(v4)}; 162 ValueTuplePtr var2 = std::make_shared<ValueTuple>(elements2); 163 inputs.push_back(NewValueNode(prim::kPrimMatMul)); 164 inputs.push_back(node1); 165 inputs.push_back(param3); 166 CNodePtr node2 = func_graph->NewCNode(inputs); 167 node2->set_in_forward_flag(true); 168 node2->set_abstract(abstract_out2); 169 inputs.clear(); 170 inputs.push_back(NewValueNode(prim::kPrimReturn)); 171 inputs.push_back(node2); 172 CNodePtr cnode_return = func_graph->NewCNode(inputs); 173 cnode_return->set_in_forward_flag(true); 174 func_graph->set_return(cnode_return); 175 PrimitivePtr prim2 = node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 176 prim2->AddAttr("transpose_a", transpose_a); 177 prim2->AddAttr("transpose_b", transpose_b); 178 prim2->AddAttr("instance_name", MakeValue("matmul2")); 179 prim2->AddAttr("strategy", var2); 180 switch (condition) { 181 case 1: { 182 prim1->set_attr("strategy", MakeValue(static_cast<int64_t>(0))); 183 break; 184 } 185 case 2: { 186 std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))}; 187 ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t); 188 prim1->set_attr("strategy", var_t); 189 break; 190 } 191 case 3: { 192 Dimensions vt1 = {2, 4}; 193 Dimensions vt2 = {2, 4}; 194 std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)}; 195 ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2); 196 prim1->set_attr("strategy", var_t2); 197 break; 198 } 199 } 200 std::vector<FuncGraphPtr> func_graphs{func_graph}; 201 FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(func_graphs, true); 202 manager->Init(); 203 return manager; 204 } 205 206 TEST_F(TestStepParallel, GetPythonPath1) { 207 OperatorName operator_name = "AllReduce"; 208 const std::string expect = "mindspore.ops.operations"; 209 auto temp = parallel::GetOpPythonPath(operator_name); 210 ASSERT_EQ(temp, expect); 211 } 212 213 TEST_F(TestStepParallel, GetPythonPath2) { 214 OperatorName operator_name = "Add"; 215 const std::string expect = "mindspore.ops.operations"; 216 auto temp = parallel::GetOpPythonPath(operator_name); 217 ASSERT_EQ(temp, expect); 218 } 219 220 TEST_F(TestStepParallel, ExtractStrategy) { 221 Dimensions v1 = {2, 2}; 222 Dimensions v2 = {4, 4}; 223 std::unordered_map<std::string, ValuePtr> attrs; 224 // stage 225 ValuePtr val1 = MakeValue(v1); 226 ValuePtr val2 = MakeValue(v2); 227 std::vector<ValuePtr> elements = {val1, val2}; 228 ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements); 229 attrs["strategy"] = strategy_tuple; 230 Strategys strategy_expect = {v1, v2}; 231 StrategyPtr strategy = ExtractStrategy(attrs["strategy"]); 232 Strategys strategy_test = strategy->GetInputDim(); 233 234 ASSERT_EQ(strategy_expect, strategy_test); 235 } 236 237 TEST_F(TestStepParallel, ExtractShape) { 238 Shape inputs_x_dims = {64, 32}; 239 Shape inputs_y_dims = {32, 64}; 240 Shape outputs_dims = {64, 64}; 241 CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 4); 242 EXPECT_THROW({ ExtractShape(node); }, std::runtime_error); 243 } 244 245 TEST_F(TestStepParallel, ExtractShape1) { 246 Shape inputs_x_dims = {64, 32}; 247 Shape inputs_y_dims = {32, 64}; 248 Shape outputs_dims = {64, 64}; 249 CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims); 250 std::vector<Shapes> shape_test = ExtractShape(node); 251 Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims}; 252 Shapes outputs_shape = std::vector<Shape>{outputs_dims}; 253 std::vector<Shapes> shape_expect = {inputs_shape, outputs_shape}; 254 ASSERT_EQ(shape_test, shape_expect); 255 } 256 257 TEST_F(TestStepParallel, ExtractShape2) { 258 Shape inputs_x_dims = {64, 32}; 259 Shape inputs_y_dims = {32, 64}; 260 Shape outputs_dims = {64, 64}; 261 CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 1); 262 EXPECT_THROW({ ExtractShape(node); }, std::runtime_error); 263 } 264 265 TEST_F(TestStepParallel, ExtractShape3) { 266 Shape inputs_x_dims = {64, 32}; 267 Shape inputs_y_dims = {32, 64}; 268 Shape outputs_dims = {64, 64}; 269 CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 3); 270 Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims}; 271 std::vector<Shapes> shape_expect = {inputs_shape, inputs_shape}; 272 std::vector<Shapes> shape_test = ExtractShape(node); 273 ASSERT_EQ(shape_test, shape_expect); 274 } 275 276 TEST_F(TestStepParallel, CreatOpInstance) { 277 ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM); 278 ValuePtr attr1_value = MakeValue("0-1-2"); 279 Attr attr0 = std::make_pair("op", attr0_value); 280 Attr attr1 = std::make_pair("group", attr1_value); 281 OperatorAttrs attrs = {attr0, attr1}; 282 OperatorName op_name = "AllReduce"; 283 OperatorParams operator_param; 284 OperatorArgs args = std::make_pair(attrs, operator_param); 285 auto op_instance = CreatOpInstance(args.first, op_name, "test"); 286 ASSERT_TRUE(op_instance); 287 PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance); 288 ASSERT_TRUE(allreduce_ptr); 289 if (nullptr != allreduce_ptr) { 290 MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name(); 291 292 std::vector<py::object> arglist; 293 (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist), 294 [](Attr attr) { return ValueToPyData(attr.second); }); 295 py::object allreduce_pyobj = parse::python_adapter::CallPyFn( 296 "mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist); 297 py::dict opAttr = py::getattr(allreduce_pyobj, "attrs"); 298 std::unordered_map<std::string, ValuePtr> attributes{}; 299 for (auto item : opAttr) { 300 if (!py::isinstance<py::str>(item.first)) { 301 MS_LOG(EXCEPTION) << "type error in py dict convert"; 302 } 303 std::string name = py::cast<std::string>(item.first); 304 MS_LOG(INFO) << "Attr name: " << name; 305 306 ValuePtr converted_ret; 307 if (name == "op") { 308 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 309 ASSERT_EQ(converted_ret->ToString(), "sum"); 310 } else { 311 if (name == "group") { 312 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 313 ASSERT_EQ(converted_ret->ToString(), "0-1-2"); 314 } else if (name == "fusion") { 315 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 316 ASSERT_EQ(converted_ret->ToString(), "0"); 317 } else if (name == "instance_name") { 318 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 319 ASSERT_EQ(converted_ret->ToString(), "test"); 320 } else if (name == "index") { 321 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 322 ASSERT_EQ(converted_ret->ToString(), "0"); 323 } else if (name == "no_elimilate") { 324 parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); 325 ASSERT_EQ(converted_ret->ToString(), "true"); 326 } else { 327 MS_LOG(EXCEPTION) << "Test failed"; 328 } 329 } 330 attributes.emplace(name, converted_ret); 331 } 332 } 333 } 334 335 TEST_F(TestStepParallel, CreatOpInstance1) { 336 OperatorAttrs attrs; 337 OperatorName op_name = "ABC"; 338 OperatorParams operator_param; 339 OperatorArgs args = std::make_pair(attrs, operator_param); 340 EXPECT_THROW({ CreatOpInstance(args.first, op_name, "test"); }, std::runtime_error); 341 } 342 343 TEST_F(TestStepParallel, OperatorInstance) { 344 // create attrs and prim 345 PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>(); 346 ValuePtr transpose_a = MakeValue(false); 347 ValuePtr transpose_b = MakeValue(false); 348 prim->set_attr("transpose_a", transpose_a); 349 prim->set_attr("transpose_b", transpose_b); 350 auto attrs = prim->attrs(); 351 // create strategy 352 Strategys strategy = {{2, 2}, {2, 4}}; 353 StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); 354 // create shape 355 Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}}; 356 Shapes outputs_shape = std::vector<Shape>{{64, 64}}; 357 std::vector<Shapes> shape = {inputs_shape, outputs_shape}; 358 TOTAL_OPS = 0; 359 OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); 360 matmul_info->Init(strategyPtr); 361 std::string name_expect = "MatMulInfo00"; 362 std::string name_test = matmul_info->name(); 363 ASSERT_EQ(name_expect, name_test); 364 } 365 366 TEST_F(TestStepParallel, ExtractInformation) { 367 FuncGraphManagerPtr manager = Make_Manager(); 368 FuncGraphSet graphs = manager->func_graphs(); 369 FuncGraphPtr graph = *graphs.begin(); 370 auto ret = graph->get_return(); 371 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 372 ExtractInformation(all_nodes); 373 } 374 375 TEST_F(TestStepParallel, ExtractInformation2) { 376 FuncGraphManagerPtr manager = Make_Manager(2); 377 FuncGraphSet graphs = manager->func_graphs(); 378 FuncGraphPtr graph = *graphs.begin(); 379 auto ret = graph->get_return(); 380 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 381 EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error); 382 } 383 384 TEST_F(TestStepParallel, ExtractInformation3) { 385 FuncGraphManagerPtr manager = Make_Manager(3); 386 FuncGraphSet graphs = manager->func_graphs(); 387 FuncGraphPtr graph = *graphs.begin(); 388 auto ret = graph->get_return(); 389 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 390 EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error); 391 } 392 393 TEST_F(TestStepParallel, ForwardCommunication1) { 394 ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM); 395 ValuePtr attr1_value = MakeValue("0-1-2"); 396 Attr attr0 = std::make_pair("op", attr0_value); 397 Attr attr1 = std::make_pair("group", attr1_value); 398 OperatorAttrs attrs = {attr0, attr1}; 399 OperatorName op_name = "AllReduce"; 400 OperatorParams operator_param; 401 OperatorArgs args = std::make_pair(attrs, operator_param); 402 Operator op = std::make_pair(op_name, args); 403 OperatorVector op_list = {op, op}; 404 FuncGraphManagerPtr manager = Make_Manager(); 405 FuncGraphSet graphs = manager->func_graphs(); 406 FuncGraphPtr graph = *graphs.begin(); 407 auto ret = graph->get_return(); 408 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 409 ExtractInformation(all_nodes); 410 for (auto &node : all_nodes) { 411 if (!node->isa<CNode>()) { 412 continue; 413 } 414 auto cnode = node->cast<CNodePtr>(); 415 FuncGraphPtr func_graph = node->func_graph(); 416 PrimitivePtr prim = cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 417 if (prim->name() == "MatMul") { 418 ForwardCommunication(op_list, cnode); 419 } 420 } 421 AnfNodeSet after_nodes = manager->all_nodes(); 422 for (auto &node : after_nodes) { 423 if (!node->isa<CNode>()) { 424 continue; 425 } 426 auto &inputs = node->cast<CNodePtr>()->inputs(); 427 PrimitivePtr prim = inputs[0]->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 428 if (prim->name() == "Return" || prim->name() == "MatMul") { 429 if (!inputs[1]->isa<Parameter>()) { 430 CNodePtr pre_node = inputs[1]->cast<CNodePtr>(); 431 PrimitivePtr pre_prim = pre_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 432 CNodePtr pre_node2 = pre_node->input(1)->cast<CNodePtr>(); 433 PrimitivePtr pre_prim2 = pre_node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 434 ASSERT_EQ("AllReduce", pre_prim->name()); 435 ASSERT_EQ("AllReduce", pre_prim2->name()); 436 } 437 } 438 } 439 } 440 441 TEST_F(TestStepParallel, ForwardCommunication2) { 442 OperatorVector op_list; 443 FuncGraphManagerPtr manager = Make_Manager(); 444 FuncGraphSet graphs = manager->func_graphs(); 445 FuncGraphPtr graph = *graphs.begin(); 446 auto ret = graph->get_return(); 447 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 448 ExtractInformation(all_nodes); 449 for (auto &node : all_nodes) { 450 if (!node->isa<CNode>()) { 451 continue; 452 } 453 auto cnode = node->cast<CNodePtr>(); 454 FuncGraphPtr func_graph = node->func_graph(); 455 func_graph->set_manager(nullptr); 456 PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); 457 if (prim->name() == "MatMul") { 458 EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error); 459 break; 460 } 461 } 462 } 463 464 TEST_F(TestStepParallel, ForwardCommunication3) { 465 OperatorVector op_list; 466 FuncGraphManagerPtr manager = Make_Manager(); 467 FuncGraphSet graphs = manager->func_graphs(); 468 FuncGraphPtr graph = *graphs.begin(); 469 auto ret = graph->get_return(); 470 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 471 ExtractInformation(all_nodes); 472 for (auto &node : all_nodes) { 473 if (!node->isa<CNode>()) { 474 continue; 475 } 476 auto cnode = node->cast<CNodePtr>(); 477 FuncGraphPtr func_graph = node->func_graph(); 478 PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); 479 if (prim->name() == "MatMul") { 480 OperatorAttrs attrs; 481 OperatorParams operator_param; 482 OperatorArgs args = std::make_pair(attrs, operator_param); 483 Operator op = std::make_pair("ABC", args); 484 OperatorVector op_list = {op}; 485 EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error); 486 break; 487 } 488 } 489 } 490 491 TEST_F(TestStepParallel, GetTensorInLayout) { 492 // create attrs and prim 493 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 494 Shape inputs_x_dims = {64, 32}; 495 Shape inputs_y_dims = {32, 64}; 496 Shape outputs_dims = {64, 64}; 497 CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims); 498 std::vector<AnfNodePtr> inputs(node->inputs()); 499 CNodePtr node1 = func_graph->NewCNode(inputs); 500 PrimitivePtr prim = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 501 ValuePtr transpose_a = MakeValue(false); 502 ValuePtr transpose_b = MakeValue(false); 503 prim->set_attr("transpose_a", transpose_a); 504 prim->set_attr("transpose_b", transpose_b); 505 auto attrs = prim->attrs(); 506 // create strategy 507 Strategys strategy = {{2, 2}, {2, 4}}; 508 StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); 509 // create shape 510 Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}}; 511 Shapes outputs_shape = std::vector<Shape>{{64, 64}}; 512 std::vector<Shapes> shape = {inputs_shape, outputs_shape}; 513 OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); 514 matmul_info->Init(strategyPtr); 515 node->set_user_data<OperatorInfo>(matmul_info); 516 OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>(); 517 TensorLayout tensorlayout_e; 518 Shape array = {64, 64}; 519 TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); 520 Shape tensor_shape_test = tensorlayout.tensor_shape().array(); 521 ASSERT_EQ(array, tensor_shape_test); 522 } 523 524 } // namespace parallel 525 } // namespace mindspore 526