/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/draw.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "utils/convert_utils_py.h"

namespace mindspore {
namespace parallel {
extern size_t TOTAL_OPS;
class TestStepParallel : public UT::Common {
 public:
  TestStepParallel() {}
  void SetUp();
  void TearDown() {}
};

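// Creates the global g_device_manager shared by all tests below: a 20-device
// HCCL rank list split into two stages (16 devices and 4 devices).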
void Init_Device_Manager() {
  RankList dev_list;

  for (int32_t i = 0; i < 20; i++) {
    dev_list.push_back(i);
  }

  // Two stages: 16 devices in the first, 4 in the second.
  RankList stage_map;
  stage_map.push_back(16);
  stage_map.push_back(4);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
}

void TestStepParallel::SetUp() {
  UT::InitPythonPath();
  Init_Device_Manager();
}

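// Builds a MatMul CNode over two fresh parameters with the given input and
// output shapes. `condition` controls how the abstracts are populated:
//   0 - shapes set on both inputs and the output (the valid case)
//   1 - param1's abstract left unset, so shape extraction should fail
//   2 - output abstract replaced with a scalar value instead of a tensor
//   3 - output abstract set to a tuple of the two input shapes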
CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  param1->set_name("x");
  param2->set_name("y");
  BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
  BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
  BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
  std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x);
  std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y);
  std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out);
  AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
  AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
  AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);
  switch (condition) {
    case 0: {
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape3);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    case 1: {
      // Don't set the abstract of param1, expecting an exception to be raised.
      param2->set_abstract(abstract2);
      break;
    }
    case 2: {
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      // Replace the output abstract with a scalar value instead of a tensor.
      abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
      break;
    }
    case 3: {
      std::vector<BaseShapePtr> shape_o = {std::make_shared<abstract::Shape>(x), std::make_shared<abstract::Shape>(y)};
      BaseShapePtr shape4 = std::make_shared<abstract::TupleShape>(shape_o);
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape4);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    default:
      MS_LOG(INFO) << "Do Nothing!";
  }
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node = func_graph->NewCNode(inputs);
  node->set_abstract(abstract3);
  return node;
}

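// Builds a managed graph of two chained MatMuls (param1 x param2, then the
// result x param3), each carrying transpose flags and a parallel "strategy"
// attribute. A non-zero `condition` overwrites matmul1's strategy with an
// invalid value; see the switch below.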
FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
  std::vector<int64_t> inputs_x = {64, 32};
  std::vector<int64_t> inputs_y = {32, 64};
  std::vector<int64_t> inputs_z = {64, 128};
  std::vector<int64_t> outputs_1 = {64, 64};
  std::vector<int64_t> outputs_2 = {64, 128};
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  ParameterPtr param3 = func_graph->add_parameter();
  std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x);
  std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y);
  std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z);
  std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1);
  std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2);
  AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true);
  AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true);
  AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true);
  AbstractBasePtr abstract_out1 = abstract::FromValue(inputs_out1_dim, true);
  AbstractBasePtr abstract_out2 = abstract::FromValue(inputs_out2_dim, true);
  param1->set_abstract(abstract_x);
  param2->set_abstract(abstract_y);
  param3->set_abstract(abstract_z);
  Dimensions v1 = {2, 2};
  Dimensions v2 = {2, 4};
  std::vector<ValuePtr> elements = {MakeValue(v1), MakeValue(v2)};
  ValueTuplePtr var = std::make_shared<ValueTuple>(elements);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node1 = func_graph->NewCNode(inputs);
  node1->set_in_forward_flag(true);
  node1->set_abstract(abstract_out1);
  PrimitivePtr prim1 = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim1->AddAttr("transpose_a", transpose_a);
  prim1->AddAttr("transpose_b", transpose_b);
  prim1->AddAttr("instance_name", MakeValue("matmul1"));
  prim1->AddAttr("strategy", var);
  inputs.clear();
  Dimensions v3 = {2, 2};
  Dimensions v4 = {2, 4};
  std::vector<ValuePtr> elements2 = {MakeValue(v3), MakeValue(v4)};
  ValueTuplePtr var2 = std::make_shared<ValueTuple>(elements2);
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(node1);
  inputs.push_back(param3);
  CNodePtr node2 = func_graph->NewCNode(inputs);
  node2->set_in_forward_flag(true);
  node2->set_abstract(abstract_out2);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(node2);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  cnode_return->set_in_forward_flag(true);
  func_graph->set_return(cnode_return);
  PrimitivePtr prim2 = node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  prim2->AddAttr("transpose_a", transpose_a);
  prim2->AddAttr("transpose_b", transpose_b);
  prim2->AddAttr("instance_name", MakeValue("matmul2"));
  prim2->AddAttr("strategy", var2);
  switch (condition) {
    case 1: {
      // Overwrite the strategy with a scalar instead of a tuple.
      prim1->set_attr("strategy", MakeValue(static_cast<int64_t>(0)));
      break;
    }
    case 2: {
      // Overwrite the strategy with a tuple whose element is not a dimension tuple.
      std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))};
      ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t);
      prim1->set_attr("strategy", var_t);
      break;
    }
    case 3: {
      // Overwrite the strategy with one that does not match MatMul's shapes.
      Dimensions vt1 = {2, 4};
      Dimensions vt2 = {2, 4};
      std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)};
      ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2);
      prim1->set_attr("strategy", var_t2);
      break;
    }
  }
  std::vector<FuncGraphPtr> func_graphs{func_graph};
  FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(func_graphs, true);
  manager->Init();
  return manager;
}

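// GetOpPythonPath should resolve the AllReduce operator to the standard
// operations module.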
TEST_F(TestStepParallel, GetPythonPath1) {
  OperatorName operator_name = "AllReduce";
  const std::string expect = "mindspore.ops.operations";
  auto temp = parallel::GetOpPythonPath(operator_name);
  ASSERT_EQ(temp, expect);
}

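// The same lookup should work for an arithmetic operator such as Add.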
TEST_F(TestStepParallel, GetPythonPath2) {
  OperatorName operator_name = "Add";
  const std::string expect = "mindspore.ops.operations";
  auto temp = parallel::GetOpPythonPath(operator_name);
  ASSERT_EQ(temp, expect);
}

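// ExtractStrategy should unpack a tuple-of-tuples "strategy" value into the
// per-input Dimensions it was built from.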
TEST_F(TestStepParallel, ExtractStrategy) {
  Dimensions v1 = {2, 2};
  Dimensions v2 = {4, 4};
  std::unordered_map<std::string, ValuePtr> attrs;
  // stage
  ValuePtr val1 = MakeValue(v1);
  ValuePtr val2 = MakeValue(v2);
  std::vector<ValuePtr> elements = {val1, val2};
  ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
  attrs["strategy"] = strategy_tuple;
  Strategys strategy_expect = {v1, v2};
  StrategyPtr strategy = ExtractStrategy(attrs["strategy"]);
  Strategys strategy_test = strategy->GetInputDim();

  ASSERT_EQ(strategy_expect, strategy_test);
}

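// Condition 4 hits Make_Node's default branch, so neither parameter gets an
// abstract and ExtractShape should throw.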
TEST_F(TestStepParallel, ExtractShape) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 4);
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}

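// The valid case: input and output shapes round-trip through ExtractShape.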
TEST_F(TestStepParallel, ExtractShape1) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  std::vector<Shapes> shape_test = ExtractShape(node);
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  std::vector<Shapes> shape_expect = {inputs_shape, outputs_shape};
  ASSERT_EQ(shape_test, shape_expect);
}

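// Condition 1 leaves param1 without an abstract, so ExtractShape should throw.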
TEST_F(TestStepParallel, ExtractShape2) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 1);
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}

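// Condition 3 gives the node a tuple output shape made of the two input
// shapes, so ExtractShape reports the input shapes for the outputs as well.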
TEST_F(TestStepParallel, ExtractShape3) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 3);
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  std::vector<Shapes> shape_expect = {inputs_shape, inputs_shape};
  std::vector<Shapes> shape_test = ExtractShape(node);
  ASSERT_EQ(shape_test, shape_expect);
}

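// CreatOpInstance should build a Python AllReduce primitive from C++ attrs;
// the test fetches the op back via _get_python_op and checks every attribute
// it carries.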
TEST_F(TestStepParallel, CreatOpInstance) {
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  auto op_instance = CreatOpInstance(args.first, op_name, "test");
  ASSERT_TRUE(op_instance);
  PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
  ASSERT_TRUE(allreduce_ptr);
  if (nullptr != allreduce_ptr) {
    MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();

    std::vector<py::object> arglist;
    (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
                         [](Attr attr) { return ValueToPyData(attr.second); });
    py::object allreduce_pyobj = parse::python_adapter::CallPyFn(
      "mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
    py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");
    std::unordered_map<std::string, ValuePtr> attributes{};
    for (auto item : opAttr) {
      if (!py::isinstance<py::str>(item.first)) {
        MS_LOG(EXCEPTION) << "type error in py dict convert";
      }
      std::string name = py::cast<std::string>(item.first);
      MS_LOG(INFO) << "Attr name: " << name;

      ValuePtr converted_ret;
      if (name == "op") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "sum");
      } else if (name == "group") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "0-1-2");
      } else if (name == "fusion") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "0");
      } else if (name == "instance_name") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "test");
      } else if (name == "index") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "0");
      } else if (name == "no_elimilate") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "true");
      } else {
        MS_LOG(EXCEPTION) << "Test failed";
      }
      attributes.emplace(name, converted_ret);
    }
  }
}

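// An unknown operator name should make CreatOpInstance throw.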
TEST_F(TestStepParallel, CreatOpInstance1) {
  OperatorAttrs attrs;
  OperatorName op_name = "ABC";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  EXPECT_THROW({ CreatOpInstance(args.first, op_name, "test"); }, std::runtime_error);
}

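// OperatorInstance should create a MatMulInfo whose generated name encodes
// the reset global op counter ("MatMulInfo00").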
TEST_F(TestStepParallel, OperatorInstance) {
  // create attrs and prim
  PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  auto attrs = prim->attrs();
  // create strategy
  Strategys strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  // create shape
  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  // reset the global op counter so the generated name is deterministic
  TOTAL_OPS = 0;
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr);
  std::string name_expect = "MatMulInfo00";
  std::string name_test = matmul_info->name();
  ASSERT_EQ(name_expect, name_test);
}

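// With valid strategies, ExtractInformation should process the whole graph
// without throwing.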
TEST_F(TestStepParallel, ExtractInformation) {
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
}

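// Condition 2 installs a strategy whose element is not a dimension tuple, so
// ExtractInformation should throw.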
TEST_F(TestStepParallel, ExtractInformation2) {
  FuncGraphManagerPtr manager = Make_Manager(2);
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
}

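// Condition 3 installs a strategy that is inconsistent with MatMul's shapes,
// so ExtractInformation should also throw.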
TEST_F(TestStepParallel, ExtractInformation3) {
  FuncGraphManagerPtr manager = Make_Manager(3);
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
}

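// Inserts two AllReduce ops after each MatMul, then checks that the nodes
// feeding Return and the second MatMul are now AllReduce CNodes.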
TEST_F(TestStepParallel, ForwardCommunication1) {
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  Operator op = std::make_pair(op_name, args);
  OperatorVector op_list = {op, op};
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    PrimitivePtr prim = cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "MatMul") {
      ForwardCommunication(op_list, cnode);
    }
  }
  AnfNodeSet after_nodes = manager->all_nodes();
  for (auto &node : after_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    PrimitivePtr prim = inputs[0]->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "Return" || prim->name() == "MatMul") {
      if (!inputs[1]->isa<Parameter>()) {
        CNodePtr pre_node = inputs[1]->cast<CNodePtr>();
        PrimitivePtr pre_prim = pre_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        CNodePtr pre_node2 = pre_node->input(1)->cast<CNodePtr>();
        PrimitivePtr pre_prim2 = pre_node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        ASSERT_EQ("AllReduce", pre_prim->name());
        ASSERT_EQ("AllReduce", pre_prim2->name());
      }
    }
  }
}

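// ForwardCommunication requires a graph manager; detaching it should trigger
// an exception.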
TEST_F(TestStepParallel, ForwardCommunication2) {
  OperatorVector op_list;
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    func_graph->set_manager(nullptr);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() == "MatMul") {
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}

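// An operator list with an unknown op name ("ABC") should make
// ForwardCommunication throw.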
TEST_F(TestStepParallel, ForwardCommunication3) {
  OperatorVector op_list;
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() == "MatMul") {
      OperatorAttrs attrs;
      OperatorParams operator_param;
      OperatorArgs args = std::make_pair(attrs, operator_param);
      Operator op = std::make_pair("ABC", args);
      OperatorVector op_list = {op};
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}

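// GetTensorInLayout should return the layout of the MatMul output produced
// by the node's operator info, whose tensor shape is {64, 64}.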
TEST_F(TestStepParallel, GetTensorInLayout) {
  // create attrs and prim
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  std::vector<AnfNodePtr> inputs(node->inputs());
  CNodePtr node1 = func_graph->NewCNode(inputs);
  PrimitivePtr prim = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  auto attrs = prim->attrs();
  // create strategy
  Strategys strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  // create shape
  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr);
  node->set_user_data<OperatorInfo>(matmul_info);
  OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
  TensorLayout tensorlayout_e;
  Shape array = {64, 64};
  TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
  Shape tensor_shape_test = tensorlayout.tensor_shape().array();
  ASSERT_EQ(array, tensor_shape_test);
}

}  // namespace parallel
}  // namespace mindspore