• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <string>
18 #include "common/common_test.h"
19 #include "common/py_func_graph_fetcher.h"
20 
21 #include "utils/log_adapter.h"
22 #include "pipeline/jit/parse/parse.h"
23 #include "debug/draw.h"
24 #include "frontend/optimizer/clean.h"
25 
26 namespace mindspore {
27 namespace opt {
28 using mindspore::abstract::AbstractAttribute;
29 using mindspore::abstract::AbstractClass;
30 using mindspore::abstract::AbstractError;
31 using mindspore::abstract::AbstractList;
32 using mindspore::abstract::AbstractScalar;
33 using mindspore::abstract::AbstractTensor;
34 using mindspore::abstract::AbstractTuple;
35 
36 class TestClean : public UT::Common {
37  public:
38   TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {}
39   virtual void SetUp();
40   virtual void TearDown();
41 
42  public:
43   UT::PyFuncGraphFetcher getPyFun;
44   FuncGraphPtr me_graph;
45 };
46 
47 void TestClean::SetUp() {
48   // build the func_graph.
49   me_graph = std::make_shared<FuncGraph>();
50   me_graph->debug_info()->set_name("next");
51 
52   // build the nodes
53   AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next"));
54   ParameterPtr parameter = std::make_shared<Parameter>(me_graph);
55   AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(static_cast<int64_t>(0));
56   AbstractBasePtr para_list = std::make_shared<AbstractList>(
57     AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
58   AbstractBasePtrList para_elem{para_scalar, para_list};
59   AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem);
60   parameter->set_abstract(para_tuple);
61 
62   AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64);
63   AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64);
64   AbstractBasePtr app_list = std::make_shared<AbstractList>(
65     AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
66   AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list});
67   AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner});
68   AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter});
69   cnode_57->set_abstract(app_tuple);
70 
71   AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter});
72   cnode_67->set_abstract(app_tuple);
73 
74   AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67});
75   cnode_66->set_abstract(app_float);
76 
77   AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn);
78   CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66});
79   cnode_55->set_abstract(app_tuple);
80 
81   me_graph->set_output(cnode_66);
82   me_graph->set_return(cnode_55);
83   me_graph->add_parameter(parameter);
84 }
85 
86 void TestClean::TearDown() {}
87 
88 TEST_F(TestClean, TestEraseClassGetAttr) {
89   FuncGraphPtr func_graph;
90 
91   func_graph = getPyFun("test_erase_class_fn");
92   ASSERT_TRUE(nullptr != func_graph);
93 
94   // save the func_graph to manager
95   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
96   int dataclass_count = 0;
97 
98   for (auto node : manager->all_nodes()) {
99     if (IsValueNode<parse::ClassObject>(node)) {
100       dataclass_count++;
101     }
102     if (!node->isa<CNode>()) {
103       continue;
104     }
105     auto input0 = node->cast<CNodePtr>()->input(0);
106     if (IsValueNode<parse::ClassObject>(input0)) {
107       std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kFloat64)},
108                                              {"y", std::make_shared<AbstractScalar>(kFloat64)}};
109       std::unordered_map<std::string, ValuePtr> methods;
110       AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
111       node->set_abstract(abs_ptr);
112     }
113   }
114 
115   ASSERT_EQ(dataclass_count, 1);
116 
117   SimplifyDataStructures(func_graph, manager);
118 
119   int tuple_getitem_count = 0;
120 
121   for (auto node : manager->all_nodes()) {
122     if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
123       tuple_getitem_count++;
124     }
125   }
126 
127   ASSERT_EQ(dataclass_count, 1);
128   ASSERT_EQ(tuple_getitem_count, 2);
129 }
130 
131 TEST_F(TestClean, TestEraseClassMakeRecord) {
132   // build the graph
133   auto func_graph = std::make_shared<FuncGraph>();
134   func_graph->debug_info()->set_name("test_make_record");
135 
136   auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
137   auto para1 = std::make_shared<Parameter>(func_graph);
138   auto para2 = std::make_shared<Parameter>(func_graph);
139 
140   para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
141   para2->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
142   std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
143                                          {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
144   std::unordered_map<std::string, ValuePtr> methods;
145   AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
146   auto cons_class = NewValueNode(abs_ptr->BuildValue());
147   cons_class->set_abstract(abs_ptr);
148 
149   std::vector<AnfNodePtr> inputs{cons_make_record, cons_class, para1, para2};
150   auto apply22 = func_graph->NewCNode(inputs);
151 
152   auto cons_return = NewValueNode(prim::kPrimReturn);
153   auto apply11 = func_graph->NewCNode({cons_return, apply22});
154   apply11->set_abstract(abs_ptr);
155 
156   func_graph->set_output(apply22);
157   func_graph->set_return(apply11);
158   func_graph->add_parameter(para1);
159   func_graph->add_parameter(para2);
160 
161   auto manager = Manage(func_graph);
162 
163   SimplifyDataStructures(func_graph, manager);
164 }
165 
166 TEST_F(TestClean, TestEraseClassPartial) {
167   // build the graph
168   auto func_graph = std::make_shared<FuncGraph>();
169   func_graph->debug_info()->set_name("test_partial");
170 
171   auto cons_partial = NewValueNode(prim::kPrimPartial);
172   auto para1 = std::make_shared<Parameter>(func_graph);
173   para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
174 
175   auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);
176 
177   std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
178                                          {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
179   std::unordered_map<std::string, ValuePtr> methods;
180   AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
181   auto cons_class = NewValueNode(abs_ptr->BuildValue());
182   cons_class->set_abstract(abs_ptr);
183 
184   std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1};
185   auto apply22 = func_graph->NewCNode(inputs);
186   std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class};
187   auto apply33 = func_graph->NewCNode(inputs_nopara);
188 
189   auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33});
190 
191   auto cons_return = NewValueNode(prim::kPrimReturn);
192   auto apply00 = func_graph->NewCNode({cons_return, apply11});
193   apply00->set_abstract(abs_ptr);
194 
195   func_graph->set_output(apply22);
196   func_graph->set_return(apply11);
197   func_graph->add_parameter(para1);
198 
199   auto manager = Manage(func_graph);
200   SimplifyDataStructures(func_graph, manager);
201 }
202 
203 TEST_F(TestClean, TestEraseTuple) {
204   ASSERT_TRUE(nullptr != me_graph);
205   std::shared_ptr<FuncGraphManager> manager = Manage(me_graph);
206 
207   int abstract_tuple_count = 0;
208 
209   for (auto node : manager->all_nodes()) {
210     auto dt = node->abstract();
211     if (dyn_cast<AbstractTuple>(dt) != nullptr) {
212       abstract_tuple_count++;
213     }
214   }
215   ASSERT_EQ(abstract_tuple_count, 4);
216 
217   // erase tuple in CNode57 and Parameter
218   EraseTuple(me_graph, manager);
219 
220   abstract_tuple_count = 0;
221   for (auto node : manager->all_nodes()) {
222     auto dt = node->abstract();
223     if (dyn_cast<AbstractTuple>(dt) != nullptr) {
224       abstract_tuple_count++;
225     }
226   }
227 
228   ASSERT_EQ(abstract_tuple_count, 3);
229 }
230 
231 }  // namespace opt
232 }  // namespace mindspore
233