1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <string>
18 #include "common/common_test.h"
19 #include "common/py_func_graph_fetcher.h"
20
21 #include "utils/log_adapter.h"
22 #include "pipeline/jit/parse/parse.h"
23 #include "debug/draw.h"
24 #include "frontend/optimizer/clean.h"
25
26 namespace mindspore {
27 namespace opt {
28 using mindspore::abstract::AbstractAttribute;
29 using mindspore::abstract::AbstractClass;
30 using mindspore::abstract::AbstractError;
31 using mindspore::abstract::AbstractList;
32 using mindspore::abstract::AbstractScalar;
33 using mindspore::abstract::AbstractTensor;
34 using mindspore::abstract::AbstractTuple;
35
36 class TestClean : public UT::Common {
37 public:
TestClean()38 TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {}
39 virtual void SetUp();
40 virtual void TearDown();
41
42 public:
43 UT::PyFuncGraphFetcher getPyFun;
44 FuncGraphPtr me_graph;
45 };
46
SetUp()47 void TestClean::SetUp() {
48 // build the func_graph.
49 me_graph = std::make_shared<FuncGraph>();
50 me_graph->debug_info()->set_name("next");
51
52 // build the nodes
53 AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next"));
54 ParameterPtr parameter = std::make_shared<Parameter>(me_graph);
55 AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(static_cast<int64_t>(0));
56 AbstractBasePtr para_list = std::make_shared<AbstractList>(
57 AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
58 AbstractBasePtrList para_elem{para_scalar, para_list};
59 AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem);
60 parameter->set_abstract(para_tuple);
61
62 AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64);
63 AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64);
64 AbstractBasePtr app_list = std::make_shared<AbstractList>(
65 AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
66 AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list});
67 AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner});
68 AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter});
69 cnode_57->set_abstract(app_tuple);
70
71 AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter});
72 cnode_67->set_abstract(app_tuple);
73
74 AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67});
75 cnode_66->set_abstract(app_float);
76
77 AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn);
78 CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66});
79 cnode_55->set_abstract(app_tuple);
80
81 me_graph->set_output(cnode_66);
82 me_graph->set_return(cnode_55);
83 me_graph->add_parameter(parameter);
84 }
85
TearDown()86 void TestClean::TearDown() {}
87
// Fetches the graph "test_erase_class_fn", tags every call whose callee is a
// ClassObject value node with a "Point" AbstractClass (attributes x and y,
// both float64), runs SimplifyDataStructures, and then expects exactly two
// TupleGetItem CNodes among the manager's nodes.
TEST_F(TestClean, TestEraseClassGetAttr) {
  FuncGraphPtr graph = getPyFun("test_erase_class_fn");
  ASSERT_TRUE(nullptr != graph);

  // Register the graph with a manager so that all of its nodes can be walked.
  auto manager = Manage(graph);

  int dataclass_count = 0;
  for (const auto &node : manager->all_nodes()) {
    // Count value nodes that hold a ClassObject.
    if (IsValueNode<parse::ClassObject>(node)) {
      ++dataclass_count;
    }
    // Only CNodes have a callee (input 0) to inspect.
    if (node->isa<CNode>()) {
      auto callee = node->cast<CNodePtr>()->input(0);
      if (IsValueNode<parse::ClassObject>(callee)) {
        std::vector<AbstractAttribute> fields = {{"x", std::make_shared<AbstractScalar>(kFloat64)},
                                                 {"y", std::make_shared<AbstractScalar>(kFloat64)}};
        std::unordered_map<std::string, ValuePtr> no_methods;
        AbstractBasePtr point_abs = std::make_shared<AbstractClass>(Named("Point"), fields, no_methods);
        node->set_abstract(point_abs);
      }
    }
  }

  ASSERT_EQ(dataclass_count, 1);

  SimplifyDataStructures(graph, manager);

  int tuple_getitem_count = 0;
  for (const auto &node : manager->all_nodes()) {
    if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
      ++tuple_getitem_count;
    }
  }

  // dataclass_count is not recomputed after the pass, so this repeats the
  // earlier check on the same value.
  ASSERT_EQ(dataclass_count, 1);
  ASSERT_EQ(tuple_getitem_count, 2);
}
130
// Builds Return(MakeRecord(Point, p1, p2)) by hand and runs
// SimplifyDataStructures on it. The test contains no assertions; it passes if
// the pass completes.
TEST_F(TestClean, TestEraseClassMakeRecord) {
  // Graph under test.
  auto graph = std::make_shared<FuncGraph>();
  graph->debug_info()->set_name("test_make_record");

  auto make_record_prim = NewValueNode(prim::kPrimMakeRecord);

  // Two parameters, each tagged as an int64 scalar of unknown value.
  auto lhs_param = std::make_shared<Parameter>(graph);
  auto rhs_param = std::make_shared<Parameter>(graph);
  lhs_param->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  rhs_param->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));

  // "Point" class abstract with int64 attributes x and y and no methods.
  std::vector<AbstractAttribute> fields = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
                                           {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  std::unordered_map<std::string, ValuePtr> no_methods;
  AbstractBasePtr point_abs = std::make_shared<AbstractClass>(Named("Point"), fields, no_methods);

  // Value node holding the value built from the class abstract.
  auto class_node = NewValueNode(point_abs->BuildValue());
  class_node->set_abstract(point_abs);

  // MakeRecord(Point, lhs, rhs)
  std::vector<AnfNodePtr> record_inputs{make_record_prim, class_node, lhs_param, rhs_param};
  auto record_node = graph->NewCNode(record_inputs);

  // Return(record)
  auto return_prim = NewValueNode(prim::kPrimReturn);
  auto return_node = graph->NewCNode({return_prim, record_node});
  return_node->set_abstract(point_abs);

  graph->set_output(record_node);
  graph->set_return(return_node);
  graph->add_parameter(lhs_param);
  graph->add_parameter(rhs_param);

  auto manager = Manage(graph);
  SimplifyDataStructures(graph, manager);
}
165
// Builds a graph that adds two Partial(MakeRecord, Point, ...) applications
// and runs SimplifyDataStructures on it. The test contains no assertions; it
// passes if the pass completes.
//
// Graph layout:
//   apply22 = Partial(MakeRecord, Point, para1)
//   apply33 = Partial(MakeRecord, Point)
//   apply11 = ScalarAdd(apply22, apply33)
//   apply00 = Return(apply11)
TEST_F(TestClean, TestEraseClassPartial) {
  // build the graph
  auto func_graph = std::make_shared<FuncGraph>();
  func_graph->debug_info()->set_name("test_partial");

  auto cons_partial = NewValueNode(prim::kPrimPartial);
  auto para1 = std::make_shared<Parameter>(func_graph);
  para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));

  auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);

  // "Point" class abstract with int64 attributes x and y and no methods.
  std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
                                         {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  std::unordered_map<std::string, ValuePtr> methods;
  AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  auto cons_class = NewValueNode(abs_ptr->BuildValue());
  cons_class->set_abstract(abs_ptr);

  // Partial application with one bound argument, and one with none.
  std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1};
  auto apply22 = func_graph->NewCNode(inputs);
  std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class};
  auto apply33 = func_graph->NewCNode(inputs_nopara);

  auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33});

  auto cons_return = NewValueNode(prim::kPrimReturn);
  auto apply00 = func_graph->NewCNode({cons_return, apply11});
  apply00->set_abstract(abs_ptr);

  // The output is the value wrapped by the Return node, and the return node is
  // the Return CNode itself. Previously the ScalarAdd node was installed as
  // the return node and apply00 was never attached to the graph.
  func_graph->set_output(apply11);
  func_graph->set_return(apply00);
  func_graph->add_parameter(para1);

  auto manager = Manage(func_graph);
  SimplifyDataStructures(func_graph, manager);
}
202
// Runs EraseTuple on the graph built in SetUp() and checks that the number of
// managed nodes whose abstract is an AbstractTuple drops from 4 to 3.
TEST_F(TestClean, TestEraseTuple) {
  ASSERT_TRUE(nullptr != me_graph);
  auto manager = Manage(me_graph);

  // Counts the managed nodes whose abstract is an AbstractTuple.
  auto count_tuple_abstracts = [&manager]() {
    int found = 0;
    for (const auto &node : manager->all_nodes()) {
      auto abs = node->abstract();
      if (dyn_cast<AbstractTuple>(abs) != nullptr) {
        ++found;
      }
    }
    return found;
  };

  int abstract_tuple_count = count_tuple_abstracts();
  ASSERT_EQ(abstract_tuple_count, 4);

  // erase tuple in CNode57 and Parameter
  EraseTuple(me_graph, manager);

  abstract_tuple_count = count_tuple_abstracts();
  ASSERT_EQ(abstract_tuple_count, 3);
}
230
231 } // namespace opt
232 } // namespace mindspore
233