/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "common/common_test.h"
17 #include "frontend/parallel/step_parallel.h"
18 #include "frontend/parallel/step_auto_parallel.h"
19 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
20 #include "frontend/parallel/ops_info/operator_info.h"
21 #include "frontend/operator/ops.h"
22 #include "pipeline/jit/static_analysis/static_analysis.h"
23
24 namespace mindspore {
25 namespace parallel {
26
27 class TestStepAutoParallel : public UT::Common {
28 public:
TestStepAutoParallel()29 TestStepAutoParallel() {}
30 void SetUp();
TearDown()31 void TearDown() {}
32 };
33
SetUp()34 void TestStepAutoParallel::SetUp() {
35 RankList dev_list;
36
37 for (int32_t i = 0; i < 20; i++) {
38 dev_list.push_back(i);
39 }
40
41 RankList stage_map;
42 stage_map.push_back(16);
43 stage_map.push_back(4);
44
45 int32_t local_dev = 0;
46
47 // create a new g_device_manager
48 g_device_manager = std::make_shared<DeviceManager>();
49 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
50 }
51
Create_Node(Shape x,Shape y,Shape out)52 CNodePtr Create_Node(Shape x, Shape y, Shape out) {
53 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
54 ParameterPtr param1 = func_graph->add_parameter();
55 ParameterPtr param2 = func_graph->add_parameter();
56 param1->set_name("x");
57 param2->set_name("y");
58 BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
59 BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
60 BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
61 AbstractBasePtr abstract1 = abstract::FromValue(static_cast<int64_t>(1), false);
62 AbstractBasePtr abstract2 = abstract::FromValue(static_cast<int64_t>(1), false);
63 AbstractBasePtr abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
64 abstract1->set_shape(shape1);
65 abstract2->set_shape(shape2);
66 abstract3->set_shape(shape3);
67 param1->set_abstract(abstract1);
68 param2->set_abstract(abstract2);
69 std::vector<AnfNodePtr> inputs;
70 inputs.push_back(NewValueNode(prim::kPrimMatMul));
71 inputs.push_back(param1);
72 inputs.push_back(param2);
73 CNodePtr node = func_graph->NewCNode(inputs);
74 PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
75 ValuePtr transpose_a = MakeValue(false);
76 ValuePtr transpose_b = MakeValue(false);
77 prim->set_attr("transpose_a", transpose_a);
78 prim->set_attr("transpose_b", transpose_b);
79
80 node->set_abstract(abstract3);
81 return node;
82 }
83
Create_two_nodes(Shape x,Shape y,Shape z,Shape w,Shape out)84 CNodePtr Create_two_nodes(Shape x, Shape y, Shape z, Shape w, Shape out) {
85 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
86 ParameterPtr paramX = func_graph->add_parameter();
87 ParameterPtr paramY = func_graph->add_parameter();
88 ParameterPtr paramW = func_graph->add_parameter();
89 paramX->set_name("x");
90 paramY->set_name("y");
91 paramW->set_name("w");
92 BaseShapePtr shapeX = std::make_shared<abstract::Shape>(x);
93 BaseShapePtr shapeY = std::make_shared<abstract::Shape>(y);
94 BaseShapePtr shapeZ = std::make_shared<abstract::Shape>(z);
95 BaseShapePtr shapeW = std::make_shared<abstract::Shape>(w);
96 BaseShapePtr shapeOut = std::make_shared<abstract::Shape>(out);
97 AbstractBasePtr abstractX = abstract::FromValue(static_cast<int64_t>(1), false);
98 AbstractBasePtr abstractY = abstract::FromValue(static_cast<int64_t>(1), false);
99 AbstractBasePtr abstractZ = abstract::FromValue(static_cast<int64_t>(1), false);
100 AbstractBasePtr abstractW = abstract::FromValue(static_cast<int64_t>(1), false);
101 AbstractBasePtr abstractOut = abstract::FromValue(static_cast<int64_t>(1), false);
102 abstractX->set_shape(shapeX);
103 abstractY->set_shape(shapeY);
104 abstractZ->set_shape(shapeZ);
105 abstractW->set_shape(shapeW);
106 abstractOut->set_shape(shapeOut);
107 paramX->set_abstract(abstractX);
108 paramY->set_abstract(abstractY);
109 paramW->set_abstract(abstractW);
110
111 std::vector<AnfNodePtr> MatMul_1_inputs;
112 MatMul_1_inputs.push_back(NewValueNode(prim::kPrimMatMul));
113 MatMul_1_inputs.push_back(paramX);
114 MatMul_1_inputs.push_back(paramY);
115 CNodePtr MatMul_1_node = func_graph->NewCNode(MatMul_1_inputs);
116 PrimitivePtr prim = MatMul_1_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
117 ValuePtr transpose_a = MakeValue(false);
118 ValuePtr transpose_b = MakeValue(false);
119 prim->set_attr("transpose_a", transpose_a);
120 prim->set_attr("transpose_b", transpose_b);
121 MatMul_1_node->set_abstract(abstractZ);
122
123 std::vector<AnfNodePtr> MatMul_2_inputs;
124 MatMul_2_inputs.push_back(NewValueNode(prim::kPrimMatMul));
125 MatMul_2_inputs.push_back(MatMul_1_node);
126 MatMul_2_inputs.push_back(paramW);
127 CNodePtr MatMul_2_node = func_graph->NewCNode(MatMul_2_inputs);
128 PrimitivePtr prim2 = MatMul_2_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
129 ValuePtr transpose_a_2 = MakeValue(false);
130 ValuePtr transpose_b_2 = MakeValue(false);
131 prim2->set_attr("transpose_a", transpose_a);
132 prim2->set_attr("transpose_b", transpose_b);
133 MatMul_2_node->set_abstract(abstractOut);
134
135 return MatMul_2_node;
136 }
137
// Creates an OperatorInfo for a single MatMul node via NewOperatorInstance,
// attaches it to the node as user data, and checks the generated instance
// name.
TEST_F(TestStepAutoParallel, test_create_op_instance) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Create_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  bool result = node->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>();
  ASSERT_EQ(result, true);

  // Fetch the primitive and its attributes.
  PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  const auto &attrs = prim->attrs();

  // Build the {input shapes, output shapes} description.
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};

  std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
  node->set_user_data<OperatorInfo>(matmul_info);
  std::string name_expect = "MatMulInfo00";
  std::string name_test = matmul_info->name();
  ASSERT_EQ(name_expect, name_test);
}
161
// Builds the chain MatMul(MatMul(x, y), w), creates an OperatorInfo for each
// MatMul, connects them with an Edge, and checks the edge name.
TEST_F(TestStepAutoParallel, test_create_edge) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_z_dims = {64, 64};
  Shape inputs_w_dims = {64, 128};
  Shape outputs_dim = {64, 128};
  CNodePtr node = Create_two_nodes(inputs_x_dims, inputs_y_dims, outputs_z_dims, inputs_w_dims, outputs_dim);

  // u --> v: u is the inner MatMul (node->input(1)), v is the outer MatMul (node).
  PrimitivePtr v_prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  const auto &v_attrs = v_prim->attrs();
  PrimitivePtr u_prim = node->input(1)->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  const auto &u_attrs = u_prim->attrs();

  // Create the OperatorInfo for v: (z, w) -> out.
  Shapes v_inputs_shape = std::vector<Shape>{outputs_z_dims, inputs_w_dims};
  Shapes v_outputs_shape = std::vector<Shape>{outputs_dim};
  std::vector<Shapes> v_shape = {v_inputs_shape, v_outputs_shape};
  std::shared_ptr<OperatorInfo> v_matmul_info = NewOperatorInstance(v_prim, v_attrs, v_shape);

  // Create the OperatorInfo for u: (x, y) -> z.
  Shapes u_inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes u_outputs_shape = std::vector<Shape>{outputs_z_dims};
  std::vector<Shapes> u_shape = {u_inputs_shape, u_outputs_shape};
  std::shared_ptr<OperatorInfo> u_matmul_info = NewOperatorInstance(u_prim, u_attrs, u_shape);

  // NOTE(review): the trailing (0, 0, false) arguments presumably mean u's
  // output 0 feeds v's input 0, not combined -- confirm against Edge's
  // constructor.
  std::string edge_name = u_prim->name() + "-" + v_prim->name();
  std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(edge_name, u_matmul_info, v_matmul_info, 0, 0, false);
  std::string expected_name = "MatMul-MatMul";
  ASSERT_EQ(edge_ptr->edge_name(), expected_name);
}
195
196 } // namespace parallel
197 } // namespace mindspore
198