1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common_test.h" 17 #include "frontend/parallel/step_parallel.h" 18 #include "frontend/parallel/step_auto_parallel.h" 19 #include "frontend/parallel/auto_parallel/edge_costmodel.h" 20 #include "frontend/parallel/ops_info/operator_info.h" 21 #include "frontend/operator/ops.h" 22 #include "pipeline/jit/static_analysis/static_analysis.h" 23 24 namespace mindspore { 25 namespace parallel { 26 27 class TestStepAutoParallel : public UT::Common { 28 public: 29 TestStepAutoParallel() {} 30 void SetUp(); 31 void TearDown() {} 32 }; 33 34 void TestStepAutoParallel::SetUp() { 35 RankList dev_list; 36 37 for (int32_t i = 0; i < 20; i++) { 38 dev_list.push_back(i); 39 } 40 41 RankList stage_map; 42 stage_map.push_back(16); 43 stage_map.push_back(4); 44 45 int32_t local_dev = 0; 46 47 // create a new g_device_manager 48 g_device_manager = std::make_shared<DeviceManager>(); 49 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); 50 } 51 52 CNodePtr Create_Node(Shape x, Shape y, Shape out) { 53 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 54 ParameterPtr param1 = func_graph->add_parameter(); 55 ParameterPtr param2 = func_graph->add_parameter(); 56 param1->set_name("x"); 57 param2->set_name("y"); 58 BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x); 59 BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y); 60 BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out); 61 AbstractBasePtr abstract1 = abstract::FromValue(static_cast<int64_t>(1), false); 62 AbstractBasePtr abstract2 = abstract::FromValue(static_cast<int64_t>(1), false); 63 AbstractBasePtr abstract3 = abstract::FromValue(static_cast<int64_t>(1), false); 64 abstract1->set_shape(shape1); 65 abstract2->set_shape(shape2); 66 abstract3->set_shape(shape3); 67 param1->set_abstract(abstract1); 68 param2->set_abstract(abstract2); 69 std::vector<AnfNodePtr> inputs; 70 inputs.push_back(NewValueNode(prim::kPrimMatMul)); 71 inputs.push_back(param1); 72 inputs.push_back(param2); 73 CNodePtr node = func_graph->NewCNode(inputs); 74 PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 75 ValuePtr transpose_a = MakeValue(false); 76 ValuePtr transpose_b = MakeValue(false); 77 prim->set_attr("transpose_a", transpose_a); 78 prim->set_attr("transpose_b", transpose_b); 79 80 node->set_abstract(abstract3); 81 return node; 82 } 83 84 CNodePtr Create_two_nodes(Shape x, Shape y, Shape z, Shape w, Shape out) { 85 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 86 ParameterPtr paramX = func_graph->add_parameter(); 87 ParameterPtr paramY = func_graph->add_parameter(); 88 ParameterPtr paramW = func_graph->add_parameter(); 89 paramX->set_name("x"); 90 paramY->set_name("y"); 91 paramW->set_name("w"); 92 BaseShapePtr shapeX = std::make_shared<abstract::Shape>(x); 93 BaseShapePtr shapeY = std::make_shared<abstract::Shape>(y); 94 BaseShapePtr shapeZ = std::make_shared<abstract::Shape>(z); 95 BaseShapePtr shapeW = std::make_shared<abstract::Shape>(w); 96 BaseShapePtr shapeOut = std::make_shared<abstract::Shape>(out); 97 AbstractBasePtr abstractX = abstract::FromValue(static_cast<int64_t>(1), false); 98 AbstractBasePtr abstractY = abstract::FromValue(static_cast<int64_t>(1), false); 99 AbstractBasePtr abstractZ = abstract::FromValue(static_cast<int64_t>(1), false); 100 AbstractBasePtr abstractW = abstract::FromValue(static_cast<int64_t>(1), false); 101 AbstractBasePtr abstractOut = abstract::FromValue(static_cast<int64_t>(1), false); 102 abstractX->set_shape(shapeX); 103 abstractY->set_shape(shapeY); 104 abstractZ->set_shape(shapeZ); 105 abstractW->set_shape(shapeW); 106 abstractOut->set_shape(shapeOut); 107 paramX->set_abstract(abstractX); 108 paramY->set_abstract(abstractY); 109 paramW->set_abstract(abstractW); 110 111 std::vector<AnfNodePtr> MatMul_1_inputs; 112 MatMul_1_inputs.push_back(NewValueNode(prim::kPrimMatMul)); 113 MatMul_1_inputs.push_back(paramX); 114 MatMul_1_inputs.push_back(paramY); 115 CNodePtr MatMul_1_node = func_graph->NewCNode(MatMul_1_inputs); 116 PrimitivePtr prim = MatMul_1_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 117 ValuePtr transpose_a = MakeValue(false); 118 ValuePtr transpose_b = MakeValue(false); 119 prim->set_attr("transpose_a", transpose_a); 120 prim->set_attr("transpose_b", transpose_b); 121 MatMul_1_node->set_abstract(abstractZ); 122 123 std::vector<AnfNodePtr> MatMul_2_inputs; 124 MatMul_2_inputs.push_back(NewValueNode(prim::kPrimMatMul)); 125 MatMul_2_inputs.push_back(MatMul_1_node); 126 MatMul_2_inputs.push_back(paramW); 127 CNodePtr MatMul_2_node = func_graph->NewCNode(MatMul_2_inputs); 128 PrimitivePtr prim2 = MatMul_2_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 129 ValuePtr transpose_a_2 = MakeValue(false); 130 ValuePtr transpose_b_2 = MakeValue(false); 131 prim2->set_attr("transpose_a", transpose_a); 132 prim2->set_attr("transpose_b", transpose_b); 133 MatMul_2_node->set_abstract(abstractOut); 134 135 return MatMul_2_node; 136 } 137 138 TEST_F(TestStepAutoParallel, test_create_op_instance) { 139 Shape inputs_x_dims = {64, 32}; 140 Shape inputs_y_dims = {32, 64}; 141 Shape outputs_dims = {64, 64}; 142 CNodePtr node = Create_Node(inputs_x_dims, inputs_y_dims, outputs_dims); 143 bool result = node->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>(); 144 ASSERT_EQ(result, true); 145 // creat prim and attrs 146 PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 147 auto attrs = prim->attrs(); 148 149 // creat shape 150 Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims}; 151 Shapes outputs_shape = std::vector<Shape>{outputs_dims}; 152 std::vector<Shapes> shape = {inputs_shape, outputs_shape}; 153 StrategyPtr strategyPtr; 154 155 std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); 156 node->set_user_data<OperatorInfo>(matmul_info); 157 std::string name_expect = "MatMulInfo00"; 158 std::string name_test = matmul_info->name(); 159 ASSERT_EQ(name_expect, name_test); 160 } 161 162 TEST_F(TestStepAutoParallel, test_create_edge) { 163 Shape inputs_x_dims = {64, 32}; 164 Shape inputs_y_dims = {32, 64}; 165 Shape outputs_z_dims = {64, 64}; 166 Shape inputs_w_dims = {64, 128}; 167 Shape outputs_dim = {64, 128}; 168 CNodePtr node = Create_two_nodes(inputs_x_dims, inputs_y_dims, outputs_z_dims, inputs_w_dims, outputs_dim); 169 170 // u-->v 171 PrimitivePtr v_prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 172 auto v_attrs = v_prim->attrs(); 173 PrimitivePtr u_prim = node->input(1)->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 174 auto u_attrs = u_prim->attrs(); 175 176 // creat v node 177 Shapes v_inputs_shape = std::vector<Shape>{outputs_z_dims, inputs_w_dims}; 178 Shapes v_outputs_shape = std::vector<Shape>{outputs_dim}; 179 std::vector<Shapes> v_shape = {v_inputs_shape, v_outputs_shape}; 180 StrategyPtr v_strategyPtr; 181 std::shared_ptr<OperatorInfo> v_matmul_info = NewOperatorInstance(v_prim, v_attrs, v_shape); 182 183 // create u node 184 Shapes u_inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims}; 185 Shapes u_outputs_shape = std::vector<Shape>{outputs_z_dims}; 186 std::vector<Shapes> u_shape = {u_inputs_shape, u_outputs_shape}; 187 StrategyPtr u_strategyPtr; 188 std::shared_ptr<OperatorInfo> u_matmul_info = NewOperatorInstance(u_prim, u_attrs, u_shape); 189 190 std::string edge_name = u_prim->name() + "-" + v_prim->name(); 191 std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(edge_name, u_matmul_info, v_matmul_info, 0, 0, false); 192 std::string expected_name = "MatMul-MatMul"; 193 ASSERT_EQ(edge_ptr->edge_name(), expected_name); 194 } 195 196 } // namespace parallel 197 } // namespace mindspore 198