• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common_test.h"
17 #include "frontend/parallel/step_parallel.h"
18 #include "frontend/parallel/graph_util/generate_graph.h"
19 #include "common/py_func_graph_fetcher.h"
20 #include "debug/draw.h"
21 #include "frontend/operator/ops.h"
22 #include "pipeline/jit/static_analysis/static_analysis.h"
23 #include "utils/convert_utils_py.h"
24 
25 namespace mindspore {
26 namespace parallel {
27 extern size_t TOTAL_OPS;
28 class TestStepParallel : public UT::Common {
29  public:
TestStepParallel()30   TestStepParallel() {}
31   void SetUp();
TearDown()32   void TearDown() {}
33 };
34 
Init_Device_Manager()35 void Init_Device_Manager() {
36   RankList dev_list;
37 
38   for (int32_t i = 0; i < 20; i++) {
39     dev_list.push_back(i);
40   }
41 
42   RankList stage_map;
43   stage_map.push_back(16);
44   stage_map.push_back(4);
45 
46   int32_t local_dev = 0;
47 
48   // create a new g_device_manager
49   g_device_manager = std::make_shared<DeviceManager>();
50   g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
51 }
52 
SetUp()53 void TestStepParallel::SetUp() {
54   UT::InitPythonPath();
55   Init_Device_Manager();
56 }
57 
Make_Node(Shape x,Shape y,Shape out,int64_t condition=0)58 CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
59   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
60   ParameterPtr param1 = func_graph->add_parameter();
61   ParameterPtr param2 = func_graph->add_parameter();
62   param1->set_name("x");
63   param2->set_name("y");
64   BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
65   BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
66   BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
67   std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x);
68   std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y);
69   std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out);
70   AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
71   AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
72   AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);
73   switch (condition) {
74     case 0: {
75       abstract1->set_shape(shape1);
76       abstract2->set_shape(shape2);
77       abstract3->set_shape(shape3);
78       param1->set_abstract(abstract1);
79       param2->set_abstract(abstract2);
80       break;
81     }
82     case 1: {
83       // Don't set abstract of param1, expecting a exception raised.
84       param2->set_abstract(abstract2);
85       break;
86     }
87     case 2: {
88       abstract1->set_shape(shape1);
89       abstract2->set_shape(shape2);
90       param1->set_abstract(abstract1);
91       param2->set_abstract(abstract2);
92       abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
93       break;
94     }
95     case 3: {
96       std::vector<BaseShapePtr> shape_o = {std::make_shared<abstract::Shape>(x), std::make_shared<abstract::Shape>(y)};
97       BaseShapePtr shape4 = std::make_shared<abstract::TupleShape>(shape_o);
98       abstract1->set_shape(shape1);
99       abstract2->set_shape(shape2);
100       abstract3->set_shape(shape4);
101       param1->set_abstract(abstract1);
102       param2->set_abstract(abstract2);
103       break;
104     }
105     default:
106       MS_LOG(INFO) << "Do Nothing!";
107   }
108   std::vector<AnfNodePtr> inputs;
109   inputs.push_back(NewValueNode(prim::kPrimMatMul));
110   inputs.push_back(param1);
111   inputs.push_back(param2);
112   CNodePtr node = func_graph->NewCNode(inputs);
113   node->set_abstract(abstract3);
114   return node;
115 }
116 
Make_Manager(int64_t condition=0)117 FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
118   std::vector<int64_t> inputs_x = {64, 32};
119   std::vector<int64_t> inputs_y = {32, 64};
120   std::vector<int64_t> inputs_z = {64, 128};
121   std::vector<int64_t> outputs_1 = {64, 64};
122   std::vector<int64_t> outputs_2 = {64, 128};
123   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
124   ParameterPtr param1 = func_graph->add_parameter();
125   ParameterPtr param2 = func_graph->add_parameter();
126   ParameterPtr param3 = func_graph->add_parameter();
127   std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x);
128   std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y);
129   std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z);
130   std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1);
131   std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2);
132   AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true);
133   AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true);
134   AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true);
135   AbstractBasePtr abstract_out1 = abstract::FromValue(inputs_out1_dim, true);
136   AbstractBasePtr abstract_out2 = abstract::FromValue(inputs_out2_dim, true);
137   param1->set_abstract(abstract_x);
138   param2->set_abstract(abstract_y);
139   param3->set_abstract(abstract_z);
140   Dimensions v1 = {2, 2};
141   Dimensions v2 = {2, 4};
142   std::vector<ValuePtr> elements = {MakeValue(v1), MakeValue(v2)};
143   ValueTuplePtr var = std::make_shared<ValueTuple>(elements);
144   std::vector<AnfNodePtr> inputs;
145   inputs.push_back(NewValueNode(prim::kPrimMatMul));
146   inputs.push_back(param1);
147   inputs.push_back(param2);
148   CNodePtr node1 = func_graph->NewCNode(inputs);
149   node1->set_in_forward_flag(true);
150   node1->set_abstract(abstract_out1);
151   PrimitivePtr prim1 = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
152   ValuePtr transpose_a = MakeValue(false);
153   ValuePtr transpose_b = MakeValue(false);
154   prim1->AddAttr("transpose_a", transpose_a);
155   prim1->AddAttr("transpose_b", transpose_b);
156   prim1->AddAttr("instance_name", MakeValue("matmul1"));
157   prim1->AddAttr("strategy", var);
158   inputs.clear();
159   Dimensions v3 = {2, 2};
160   Dimensions v4 = {2, 4};
161   std::vector<ValuePtr> elements2 = {MakeValue(v3), MakeValue(v4)};
162   ValueTuplePtr var2 = std::make_shared<ValueTuple>(elements2);
163   inputs.push_back(NewValueNode(prim::kPrimMatMul));
164   inputs.push_back(node1);
165   inputs.push_back(param3);
166   CNodePtr node2 = func_graph->NewCNode(inputs);
167   node2->set_in_forward_flag(true);
168   node2->set_abstract(abstract_out2);
169   inputs.clear();
170   inputs.push_back(NewValueNode(prim::kPrimReturn));
171   inputs.push_back(node2);
172   CNodePtr cnode_return = func_graph->NewCNode(inputs);
173   cnode_return->set_in_forward_flag(true);
174   func_graph->set_return(cnode_return);
175   PrimitivePtr prim2 = node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
176   prim2->AddAttr("transpose_a", transpose_a);
177   prim2->AddAttr("transpose_b", transpose_b);
178   prim2->AddAttr("instance_name", MakeValue("matmul2"));
179   prim2->AddAttr("strategy", var2);
180   switch (condition) {
181     case 1: {
182       prim1->set_attr("strategy", MakeValue(static_cast<int64_t>(0)));
183       break;
184     }
185     case 2: {
186       std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))};
187       ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t);
188       prim1->set_attr("strategy", var_t);
189       break;
190     }
191     case 3: {
192       Dimensions vt1 = {2, 4};
193       Dimensions vt2 = {2, 4};
194       std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)};
195       ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2);
196       prim1->set_attr("strategy", var_t2);
197       break;
198     }
199   }
200   std::vector<FuncGraphPtr> func_graphs{func_graph};
201   FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(func_graphs, true);
202   manager->Init();
203   return manager;
204 }
205 
// AllReduce must resolve to the public operations module.
TEST_F(TestStepParallel, GetPythonPath1) {
  const std::string expected_path = "mindspore.ops.operations";
  OperatorName name = "AllReduce";
  ASSERT_EQ(parallel::GetOpPythonPath(name), expected_path);
}
212 
// Add must resolve to the public operations module.
TEST_F(TestStepParallel, GetPythonPath2) {
  const std::string expected_path = "mindspore.ops.operations";
  OperatorName name = "Add";
  ASSERT_EQ(parallel::GetOpPythonPath(name), expected_path);
}
219 
// A strategy attribute is a tuple holding one Dimensions value per operator input;
// ExtractStrategy must hand those dimensions back unchanged.
TEST_F(TestStepParallel, ExtractStrategy) {
  Dimensions first_input = {2, 2};
  Dimensions second_input = {4, 4};
  std::vector<ValuePtr> per_input = {MakeValue(first_input), MakeValue(second_input)};
  ValuePtr strategy_attr = std::make_shared<ValueTuple>(per_input);

  StrategyPtr extracted = ExtractStrategy(strategy_attr);

  Strategys expected = {first_input, second_input};
  ASSERT_EQ(expected, extracted->GetInputDim());
}
236 
// Condition 4 falls into Make_Node's default branch, so neither parameter carries an
// abstract and shape extraction has to throw.
TEST_F(TestStepParallel, ExtractShape) {
  const Shape x_shape = {64, 32};
  const Shape y_shape = {32, 64};
  const Shape out_shape = {64, 64};
  CNodePtr bare_node = Make_Node(x_shape, y_shape, out_shape, 4);
  EXPECT_THROW({ ExtractShape(bare_node); }, std::runtime_error);
}
244 
// With fully shaped inputs and output, ExtractShape returns {input shapes, output shapes}.
TEST_F(TestStepParallel, ExtractShape1) {
  const Shape x_shape = {64, 32};
  const Shape y_shape = {32, 64};
  const Shape out_shape = {64, 64};
  CNodePtr node = Make_Node(x_shape, y_shape, out_shape);

  const std::vector<Shapes> expected = {Shapes{x_shape, y_shape}, Shapes{out_shape}};
  ASSERT_EQ(ExtractShape(node), expected);
}
256 
// Condition 1 leaves parameter "x" without an abstract, so shape extraction has to throw.
TEST_F(TestStepParallel, ExtractShape2) {
  const Shape x_shape = {64, 32};
  const Shape y_shape = {32, 64};
  const Shape out_shape = {64, 64};
  CNodePtr half_typed_node = Make_Node(x_shape, y_shape, out_shape, 1);
  EXPECT_THROW({ ExtractShape(half_typed_node); }, std::runtime_error);
}
264 
// Condition 3 gives the node the tuple output shape (x, y), so the extracted output shapes
// mirror the input shapes.
TEST_F(TestStepParallel, ExtractShape3) {
  const Shape x_shape = {64, 32};
  const Shape y_shape = {32, 64};
  const Shape out_shape = {64, 64};
  CNodePtr tuple_node = Make_Node(x_shape, y_shape, out_shape, 3);

  const Shapes both_inputs = {x_shape, y_shape};
  const std::vector<Shapes> expected = {both_inputs, both_inputs};
  ASSERT_EQ(ExtractShape(tuple_node), expected);
}
275 
TEST_F(TestStepParallel,CreatOpInstance)276 TEST_F(TestStepParallel, CreatOpInstance) {
277   ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
278   ValuePtr attr1_value = MakeValue("0-1-2");
279   Attr attr0 = std::make_pair("op", attr0_value);
280   Attr attr1 = std::make_pair("group", attr1_value);
281   OperatorAttrs attrs = {attr0, attr1};
282   OperatorName op_name = "AllReduce";
283   OperatorParams operator_param;
284   OperatorArgs args = std::make_pair(attrs, operator_param);
285   auto op_instance = CreatOpInstance(args.first, op_name, "test");
286   ASSERT_TRUE(op_instance);
287   PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
288   ASSERT_TRUE(allreduce_ptr);
289   if (nullptr != allreduce_ptr) {
290     MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
291 
292     std::vector<py::object> arglist;
293     (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
294                          [](Attr attr) { return ValueToPyData(attr.second); });
295     py::object allreduce_pyobj = parse::python_adapter::CallPyFn(
296       "mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
297     py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");
298     std::unordered_map<std::string, ValuePtr> attributes{};
299     for (auto item : opAttr) {
300       if (!py::isinstance<py::str>(item.first)) {
301         MS_LOG(EXCEPTION) << "type error in py dict convert";
302       }
303       std::string name = py::cast<std::string>(item.first);
304       MS_LOG(INFO) << "Attr name: " << name;
305 
306       ValuePtr converted_ret;
307       if (name == "op") {
308         parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
309         ASSERT_EQ(converted_ret->ToString(), "sum");
310       } else {
311         if (name == "group") {
312           parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
313           ASSERT_EQ(converted_ret->ToString(), "0-1-2");
314         } else if (name == "fusion") {
315           parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
316           ASSERT_EQ(converted_ret->ToString(), "0");
317         } else if (name == "instance_name") {
318           parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
319           ASSERT_EQ(converted_ret->ToString(), "test");
320         } else if (name == "index") {
321           parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
322           ASSERT_EQ(converted_ret->ToString(), "0");
323         } else if (name == "no_elimilate") {
324           parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
325           ASSERT_EQ(converted_ret->ToString(), "true");
326         } else {
327           MS_LOG(EXCEPTION) << "Test failed";
328         }
329       }
330       attributes.emplace(name, converted_ret);
331     }
332   }
333 }
334 
// "ABC" is not a known operator name, so creating an instance of it must throw.
TEST_F(TestStepParallel, CreatOpInstance1) {
  const OperatorName unknown_op = "ABC";
  OperatorAttrs no_attrs;
  EXPECT_THROW({ CreatOpInstance(no_attrs, unknown_op, "test"); }, std::runtime_error);
}
342 
// Creates a MatMul OperatorInfo from a primitive, its attributes and input/output shapes, and
// checks its generated name.
TEST_F(TestStepParallel, OperatorInstance) {
  // Use a private MatMul primitive. NewValueNode(prim::kPrimMatMul)->value() is the global
  // kPrimMatMul itself, and setting attributes on it would leak into other tests.
  PrimitivePtr prim = std::make_shared<Primitive>(prim::kPrimMatMul->name());
  prim->set_attr("transpose_a", MakeValue(false));
  prim->set_attr("transpose_b", MakeValue(false));
  auto attrs = prim->attrs();

  Strategys strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);

  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};

  // The expected name below appears to embed the TOTAL_OPS counter, so reset it first.
  TOTAL_OPS = 0;
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  ASSERT_TRUE(matmul_info);
  matmul_info->Init(strategyPtr);
  ASSERT_EQ("MatMulInfo00", matmul_info->name());
}
365 
// With the default, well-formed strategies, extracting information over the whole graph must
// complete without throwing.
TEST_F(TestStepParallel, ExtractInformation) {
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphPtr graph = *manager->func_graphs().begin();
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearch(graph->get_return());
  ExtractInformation(nodes);
}
374 
// Condition 2 replaces the first MatMul's strategy with a one-element tuple holding a scalar;
// extraction must reject it.
TEST_F(TestStepParallel, ExtractInformation2) {
  FuncGraphManagerPtr manager = Make_Manager(2);
  FuncGraphPtr graph = *manager->func_graphs().begin();
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearch(graph->get_return());
  EXPECT_THROW({ ExtractInformation(nodes); }, std::runtime_error);
}
383 
// Condition 3 gives the first MatMul the strategy ((2, 4), (2, 4)); extraction must reject it.
TEST_F(TestStepParallel, ExtractInformation3) {
  FuncGraphManagerPtr manager = Make_Manager(3);
  FuncGraphPtr graph = *manager->func_graphs().begin();
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearch(graph->get_return());
  EXPECT_THROW({ ExtractInformation(nodes); }, std::runtime_error);
}
392 
// Inserts two AllReduce operators after every MatMul, then checks that every Return/MatMul
// whose first input is not a Parameter is fed by a chain of two AllReduce nodes.
TEST_F(TestStepParallel, ForwardCommunication1) {
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  Operator op = std::make_pair(op_name, args);
  OperatorVector op_list = {op, op};

  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);

  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim != nullptr && prim->name() == "MatMul") {
      ForwardCommunication(op_list, cnode);
    }
  }

  AnfNodeSet after_nodes = manager->all_nodes();
  for (auto &node : after_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto &inputs = cnode->inputs();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
    if (prim == nullptr || (prim->name() != "Return" && prim->name() != "MatMul")) {
      continue;
    }
    if (inputs[1]->isa<Parameter>()) {
      continue;
    }
    // Walk back two producers and require both to be AllReduce; assert every link first so a
    // malformed graph fails the test instead of dereferencing null.
    CNodePtr pre_node = inputs[1]->cast<CNodePtr>();
    ASSERT_TRUE(pre_node);
    CNodePtr pre_node2 = pre_node->input(1)->cast<CNodePtr>();
    ASSERT_TRUE(pre_node2);
    PrimitivePtr pre_prim = GetValueNode<PrimitivePtr>(pre_node->input(0));
    PrimitivePtr pre_prim2 = GetValueNode<PrimitivePtr>(pre_node2->input(0));
    ASSERT_TRUE(pre_prim);
    ASSERT_TRUE(pre_prim2);
    ASSERT_EQ("AllReduce", pre_prim->name());
    ASSERT_EQ("AllReduce", pre_prim2->name());
  }
}
440 
// After the manager is detached from the graph, ForwardCommunication on the first MatMul
// is expected to throw.
TEST_F(TestStepParallel, ForwardCommunication2) {
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphPtr graph = *manager->func_graphs().begin();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(graph->get_return());
  ExtractInformation(all_nodes);

  OperatorVector op_list;
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    node->func_graph()->set_manager(nullptr);
    if (GetValueNode<PrimitivePtr>(cnode->input(0))->name() != "MatMul") {
      continue;
    }
    EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
    break;
  }
}
463 
// Forwarding an operator with the unknown name "ABC" after the first MatMul must throw.
TEST_F(TestStepParallel, ForwardCommunication3) {
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim != nullptr && prim->name() == "MatMul") {
      OperatorAttrs attrs;
      OperatorParams operator_param;
      OperatorArgs args = std::make_pair(attrs, operator_param);
      Operator op = std::make_pair("ABC", args);
      OperatorVector op_list = {op};
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}
490 
// Attaches a MatMul OperatorInfo to a node as user data and checks that GetTensorInLayout,
// given a second MatMul node and that operator, reports the tensor shape [64, 64].
TEST_F(TestStepParallel, GetTensorInLayout) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);

  // Rebuild the MatMul on a private primitive. Make_Node wraps the global prim::kPrimMatMul,
  // and setting attributes on it would leak into other tests.
  PrimitivePtr prim = std::make_shared<Primitive>(prim::kPrimMatMul->name());
  prim->set_attr("transpose_a", MakeValue(false));
  prim->set_attr("transpose_b", MakeValue(false));
  std::vector<AnfNodePtr> inputs(node->inputs());
  inputs[0] = NewValueNode(prim);
  CNodePtr node1 = func_graph->NewCNode(inputs);
  auto attrs = prim->attrs();

  Strategys strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);

  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  ASSERT_TRUE(matmul_info);
  matmul_info->Init(strategyPtr);

  node->set_user_data<OperatorInfo>(matmul_info);
  OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
  ASSERT_TRUE(distribute_operator_pre);

  Shape array = {64, 64};
  TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
  Shape tensor_shape_test = tensorlayout.tensor_shape().array();
  ASSERT_EQ(array, tensor_shape_test);
}
523 
524 }  // namespace parallel
525 }  // namespace mindspore
526