/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/auto_parallel/edge_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/static_analysis/static_analysis.h"

namespace mindspore {
namespace parallel {

class TestStepAutoParallel : public UT::Common {
 public:
  TestStepAutoParallel() {}
  void SetUp();
  void TearDown() {}
};

void TestStepAutoParallel::SetUp() {
  RankList dev_list;

  for (int32_t i = 0; i < 20; i++) {
    dev_list.push_back(i);
  }

  RankList stage_map;
  stage_map.push_back(16);
  stage_map.push_back(4);

  int32_t local_dev = 0;

  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
}

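// Builds a single MatMul CNode out = MatMul(x, y). The abstract values are
// dummy scalars whose only purpose is to carry the given input/output shapes
// for the parallel passes.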
CNodePtr Create_Node(Shape x, Shape y, Shape out) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  param1->set_name("x");
  param2->set_name("y");
  BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
  BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
  BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
  AbstractBasePtr abstract1 = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstract2 = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
  abstract1->set_shape(shape1);
  abstract2->set_shape(shape2);
  abstract3->set_shape(shape3);
  param1->set_abstract(abstract1);
  param2->set_abstract(abstract2);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node = func_graph->NewCNode(inputs);
  PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);

  node->set_abstract(abstract3);
  return node;
}

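// Builds a two-operator chain z = MatMul(x, y), out = MatMul(z, w) and returns
// the second MatMul CNode, so a test can walk from it back to its producer.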
CNodePtr Create_two_nodes(Shape x, Shape y, Shape z, Shape w, Shape out) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr paramX = func_graph->add_parameter();
  ParameterPtr paramY = func_graph->add_parameter();
  ParameterPtr paramW = func_graph->add_parameter();
  paramX->set_name("x");
  paramY->set_name("y");
  paramW->set_name("w");
  BaseShapePtr shapeX = std::make_shared<abstract::Shape>(x);
  BaseShapePtr shapeY = std::make_shared<abstract::Shape>(y);
  BaseShapePtr shapeZ = std::make_shared<abstract::Shape>(z);
  BaseShapePtr shapeW = std::make_shared<abstract::Shape>(w);
  BaseShapePtr shapeOut = std::make_shared<abstract::Shape>(out);
  AbstractBasePtr abstractX = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstractY = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstractZ = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstractW = abstract::FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr abstractOut = abstract::FromValue(static_cast<int64_t>(1), false);
  abstractX->set_shape(shapeX);
  abstractY->set_shape(shapeY);
  abstractZ->set_shape(shapeZ);
  abstractW->set_shape(shapeW);
  abstractOut->set_shape(shapeOut);
  paramX->set_abstract(abstractX);
  paramY->set_abstract(abstractY);
  paramW->set_abstract(abstractW);

  std::vector<AnfNodePtr> MatMul_1_inputs;
  MatMul_1_inputs.push_back(NewValueNode(prim::kPrimMatMul));
  MatMul_1_inputs.push_back(paramX);
  MatMul_1_inputs.push_back(paramY);
  CNodePtr MatMul_1_node = func_graph->NewCNode(MatMul_1_inputs);
  PrimitivePtr prim = MatMul_1_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  MatMul_1_node->set_abstract(abstractZ);

  std::vector<AnfNodePtr> MatMul_2_inputs;
  MatMul_2_inputs.push_back(NewValueNode(prim::kPrimMatMul));
  MatMul_2_inputs.push_back(MatMul_1_node);
  MatMul_2_inputs.push_back(paramW);
  CNodePtr MatMul_2_node = func_graph->NewCNode(MatMul_2_inputs);
  PrimitivePtr prim2 = MatMul_2_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a_2 = MakeValue(false);
  ValuePtr transpose_b_2 = MakeValue(false);
  prim2->set_attr("transpose_a", transpose_a_2);
  prim2->set_attr("transpose_b", transpose_b_2);
  MatMul_2_node->set_abstract(abstractOut);

  return MatMul_2_node;
}

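// Verifies that NewOperatorInstance builds a MatMulInfo operator from the
// node's primitive, attributes and input/output shapes.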
TEST_F(TestStepAutoParallel, test_create_op_instance) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Create_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  bool result = node->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>();
  ASSERT_EQ(result, true);
  // create prim and attrs
  PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  auto attrs = prim->attrs();

  // create shape
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  StrategyPtr strategyPtr;

  std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
  node->set_user_data<OperatorInfo>(matmul_info);
  std::string name_expect = "MatMulInfo00";
  std::string name_test = matmul_info->name();
  ASSERT_EQ(name_expect, name_test);
}

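// Verifies that an Edge can be created between the two MatMul operator
// instances (u feeds v) and that it is named "<u_prim>-<v_prim>".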
TEST_F(TestStepAutoParallel, test_create_edge) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_z_dims = {64, 64};
  Shape inputs_w_dims = {64, 128};
  Shape outputs_dim = {64, 128};
  CNodePtr node = Create_two_nodes(inputs_x_dims, inputs_y_dims, outputs_z_dims, inputs_w_dims, outputs_dim);

  // u-->v
  PrimitivePtr v_prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  auto v_attrs = v_prim->attrs();
  PrimitivePtr u_prim = node->input(1)->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  auto u_attrs = u_prim->attrs();

  // create v node
  Shapes v_inputs_shape = std::vector<Shape>{outputs_z_dims, inputs_w_dims};
  Shapes v_outputs_shape = std::vector<Shape>{outputs_dim};
  std::vector<Shapes> v_shape = {v_inputs_shape, v_outputs_shape};
  StrategyPtr v_strategyPtr;
  std::shared_ptr<OperatorInfo> v_matmul_info = NewOperatorInstance(v_prim, v_attrs, v_shape);

  // create u node
  Shapes u_inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes u_outputs_shape = std::vector<Shape>{outputs_z_dims};
  std::vector<Shapes> u_shape = {u_inputs_shape, u_outputs_shape};
  StrategyPtr u_strategyPtr;
  std::shared_ptr<OperatorInfo> u_matmul_info = NewOperatorInstance(u_prim, u_attrs, u_shape);

  std::string edge_name = u_prim->name() + "-" + v_prim->name();
  std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(edge_name, u_matmul_info, v_matmul_info, 0, 0, false);
  std::string expected_name = "MatMul-MatMul";
  ASSERT_EQ(edge_ptr->edge_name(), expected_name);
}

}  // namespace parallel
}  // namespace mindspore