1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <string> 18 #include "common/common_test.h" 19 #include "common/py_func_graph_fetcher.h" 20 21 #include "utils/log_adapter.h" 22 #include "pipeline/jit/parse/parse.h" 23 #include "debug/draw.h" 24 #include "frontend/optimizer/clean.h" 25 26 namespace mindspore { 27 namespace opt { 28 using mindspore::abstract::AbstractAttribute; 29 using mindspore::abstract::AbstractClass; 30 using mindspore::abstract::AbstractError; 31 using mindspore::abstract::AbstractList; 32 using mindspore::abstract::AbstractScalar; 33 using mindspore::abstract::AbstractTensor; 34 using mindspore::abstract::AbstractTuple; 35 36 class TestClean : public UT::Common { 37 public: 38 TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {} 39 virtual void SetUp(); 40 virtual void TearDown(); 41 42 public: 43 UT::PyFuncGraphFetcher getPyFun; 44 FuncGraphPtr me_graph; 45 }; 46 47 void TestClean::SetUp() { 48 // build the func_graph. 49 me_graph = std::make_shared<FuncGraph>(); 50 me_graph->debug_info()->set_name("next"); 51 52 // build the nodes 53 AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next")); 54 ParameterPtr parameter = std::make_shared<Parameter>(me_graph); 55 AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(static_cast<int64_t>(0)); 56 AbstractBasePtr para_list = std::make_shared<AbstractList>( 57 AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)})); 58 AbstractBasePtrList para_elem{para_scalar, para_list}; 59 AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem); 60 parameter->set_abstract(para_tuple); 61 62 AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64); 63 AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64); 64 AbstractBasePtr app_list = std::make_shared<AbstractList>( 65 AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)})); 66 AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list}); 67 AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner}); 68 AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter}); 69 cnode_57->set_abstract(app_tuple); 70 71 AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter}); 72 cnode_67->set_abstract(app_tuple); 73 74 AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67}); 75 cnode_66->set_abstract(app_float); 76 77 AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn); 78 CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66}); 79 cnode_55->set_abstract(app_tuple); 80 81 me_graph->set_output(cnode_66); 82 me_graph->set_return(cnode_55); 83 me_graph->add_parameter(parameter); 84 } 85 86 void TestClean::TearDown() {} 87 88 TEST_F(TestClean, TestEraseClassGetAttr) { 89 FuncGraphPtr func_graph; 90 91 func_graph = getPyFun("test_erase_class_fn"); 92 ASSERT_TRUE(nullptr != func_graph); 93 94 // save the func_graph to manager 95 std::shared_ptr<FuncGraphManager> manager = Manage(func_graph); 96 int dataclass_count = 0; 97 98 for (auto node : manager->all_nodes()) { 99 if (IsValueNode<parse::ClassObject>(node)) { 100 dataclass_count++; 101 } 102 if (!node->isa<CNode>()) { 103 continue; 104 } 105 auto input0 = node->cast<CNodePtr>()->input(0); 106 if (IsValueNode<parse::ClassObject>(input0)) { 107 std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kFloat64)}, 108 {"y", std::make_shared<AbstractScalar>(kFloat64)}}; 109 std::unordered_map<std::string, ValuePtr> methods; 110 AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods); 111 node->set_abstract(abs_ptr); 112 } 113 } 114 115 ASSERT_EQ(dataclass_count, 1); 116 117 SimplifyDataStructures(func_graph, manager); 118 119 int tuple_getitem_count = 0; 120 121 for (auto node : manager->all_nodes()) { 122 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { 123 tuple_getitem_count++; 124 } 125 } 126 127 ASSERT_EQ(dataclass_count, 1); 128 ASSERT_EQ(tuple_getitem_count, 2); 129 } 130 131 TEST_F(TestClean, TestEraseClassMakeRecord) { 132 // build the graph 133 auto func_graph = std::make_shared<FuncGraph>(); 134 func_graph->debug_info()->set_name("test_make_record"); 135 136 auto cons_make_record = NewValueNode(prim::kPrimMakeRecord); 137 auto para1 = std::make_shared<Parameter>(func_graph); 138 auto para2 = std::make_shared<Parameter>(func_graph); 139 140 para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64)); 141 para2->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64)); 142 std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}, 143 {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}}; 144 std::unordered_map<std::string, ValuePtr> methods; 145 AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods); 146 auto cons_class = NewValueNode(abs_ptr->BuildValue()); 147 cons_class->set_abstract(abs_ptr); 148 149 std::vector<AnfNodePtr> inputs{cons_make_record, cons_class, para1, para2}; 150 auto apply22 = func_graph->NewCNode(inputs); 151 152 auto cons_return = NewValueNode(prim::kPrimReturn); 153 auto apply11 = func_graph->NewCNode({cons_return, apply22}); 154 apply11->set_abstract(abs_ptr); 155 156 func_graph->set_output(apply22); 157 func_graph->set_return(apply11); 158 func_graph->add_parameter(para1); 159 func_graph->add_parameter(para2); 160 161 auto manager = Manage(func_graph); 162 163 SimplifyDataStructures(func_graph, manager); 164 } 165 166 TEST_F(TestClean, TestEraseClassPartial) { 167 // build the graph 168 auto func_graph = std::make_shared<FuncGraph>(); 169 func_graph->debug_info()->set_name("test_partial"); 170 171 auto cons_partial = NewValueNode(prim::kPrimPartial); 172 auto para1 = std::make_shared<Parameter>(func_graph); 173 para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64)); 174 175 auto cons_make_record = NewValueNode(prim::kPrimMakeRecord); 176 177 std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}, 178 {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}}; 179 std::unordered_map<std::string, ValuePtr> methods; 180 AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods); 181 auto cons_class = NewValueNode(abs_ptr->BuildValue()); 182 cons_class->set_abstract(abs_ptr); 183 184 std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1}; 185 auto apply22 = func_graph->NewCNode(inputs); 186 std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class}; 187 auto apply33 = func_graph->NewCNode(inputs_nopara); 188 189 auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33}); 190 191 auto cons_return = NewValueNode(prim::kPrimReturn); 192 auto apply00 = func_graph->NewCNode({cons_return, apply11}); 193 apply00->set_abstract(abs_ptr); 194 195 func_graph->set_output(apply22); 196 func_graph->set_return(apply11); 197 func_graph->add_parameter(para1); 198 199 auto manager = Manage(func_graph); 200 SimplifyDataStructures(func_graph, manager); 201 } 202 203 TEST_F(TestClean, TestEraseTuple) { 204 ASSERT_TRUE(nullptr != me_graph); 205 std::shared_ptr<FuncGraphManager> manager = Manage(me_graph); 206 207 int abstract_tuple_count = 0; 208 209 for (auto node : manager->all_nodes()) { 210 auto dt = node->abstract(); 211 if (dyn_cast<AbstractTuple>(dt) != nullptr) { 212 abstract_tuple_count++; 213 } 214 } 215 ASSERT_EQ(abstract_tuple_count, 4); 216 217 // erase tuple in CNode57 and Parameter 218 EraseTuple(me_graph, manager); 219 220 abstract_tuple_count = 0; 221 for (auto node : manager->all_nodes()) { 222 auto dt = node->abstract(); 223 if (dyn_cast<AbstractTuple>(dt) != nullptr) { 224 abstract_tuple_count++; 225 } 226 } 227 228 ASSERT_EQ(abstract_tuple_count, 3); 229 } 230 231 } // namespace opt 232 } // namespace mindspore 233