• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <algorithm>
17 
18 #include "common/common_test.h"
19 #include "common/py_func_graph_fetcher.h"
20 
21 #include "ir/manager.h"
22 #include "utils/log_adapter.h"
23 #include "ir/func_graph_cloner.h"
24 #include "pipeline/jit/parse/parse.h"
25 #include "ir/graph_utils.h"
26 #include "debug/draw.h"
27 #include "base/core_ops.h"
28 
29 namespace mindspore {
30 class TestCloner : public UT::Common {
31  public:
TestCloner()32   TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
33     one = NewValueNode(static_cast<int64_t>(1));
34     two = NewValueNode(static_cast<int64_t>(2));
35     three = NewValueNode(static_cast<int64_t>(3));
36   }
37 
GraphForInline()38   FuncGraphPtr GraphForInline() { return nullptr; }
39   void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params,
40                           FuncGraphPtr target);
41 
42  public:
43   UT::PyFuncGraphFetcher getPyFun;
44 
45   ValueNodePtr one;
46   ValueNodePtr two;
47   ValueNodePtr three;
48 };
49 
SuccessfulInlining(const std::shared_ptr<Cloner> cl,FuncGraphPtr orig,const std::vector<AnfNodePtr> & params,FuncGraphPtr target)50 void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
51                                     const std::vector<AnfNodePtr>& params, FuncGraphPtr target) {
52   auto g = (*cl)[orig];
53   ASSERT_TRUE(g != target);
54   ASSERT_TRUE(g == orig);
55 
56   auto new_root = (*cl)[orig->output()];
57   ASSERT_TRUE(new_root != orig->output());
58 
59   AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
60   AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));
61 
62   for (auto& p : params) {
63     ASSERT_TRUE(new_nodes.contains(p));
64   }
65 
66   for (auto& node : orig_nodes) {
67     if (node->func_graph() == orig) {
68       ASSERT_TRUE((*cl)[node]);
69     }
70   }
71   ASSERT_TRUE(target->output() == three);
72 }
73 
// Clones a simple graph twice and checks which nodes are shared with the original:
//  - with Cloner(gs, true) no node may be shared;
//  - with Cloner(gs) the only shared nodes may be value nodes holding the
//    ScalarAdd / ScalarMul / Return primitives.
TEST_F(TestCloner, test_clone_simple) {
  std::string py_code = "test_clone_simple";

  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);
  auto g2 = cl[g];
  ASSERT_TRUE(g2 != nullptr);

  AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return()));
  AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return()));

  auto common = d1 & d2;
  ASSERT_EQ(static_cast<size_t>(0), common.size());

  Cloner cl2(gs);
  auto g3 = cl2[g];
  ASSERT_TRUE(g3 != nullptr);

  std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")};
  AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
  common = d1 & d3;
  for (auto& x : common) {
    ASSERT_TRUE(x->isa<ValueNode>());
    auto prim = x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    // Check before dereferencing: a shared value node that does not hold a Primitive
    // must fail the test, not crash the whole test binary with a null dereference.
    ASSERT_TRUE(prim != nullptr);
    ASSERT_TRUE(std::find(results.begin(), results.end(), *prim) != results.end());
  }
}
102 
// Clones the inner graph "j" of test_clone_closure and checks which nodes the clone
// shares with the original: "x" and "y" are expected to be shared, while "z" and the
// graph "j" itself are expected to be new objects.
TEST_F(TestCloner, test_clone_closure) {
  std::string py_code = "test_clone_closure";

  // parse ast to graph
  FuncGraphPtr parsed_f = getPyFun(py_code);
  ASSERT_TRUE(parsed_f != nullptr);

  FuncGraphIndex idx(parsed_f);
  auto g = idx.GetFirstFuncGraph("j");
  // Fail cleanly rather than handing a null graph to the Cloner.
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);

  auto g_clone = cl[g];
  ASSERT_TRUE(g_clone != nullptr);
  FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch);

  const std::vector<std::string> shared_names = {"x", "y"};
  for (const auto& name : shared_names) {
    ASSERT_EQ(idx.GetFirstNode(name), idx2.GetFirstNode(name));
  }

  ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z"));
  ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j"));
}
126 
// Applies LiftingClone to test_clone_closure and checks that the child graph "j"
// of the result has exactly one more parameter than the lifted top-level graph.
TEST_F(TestCloner, test_clone_lifting) {
  std::string py_code = "test_clone_closure";

  // parse ast to graph
  FuncGraphPtr parsed_f = getPyFun(py_code);
  ASSERT_TRUE(parsed_f != nullptr);

  auto g_lifting = LiftingClone(parsed_f);
  ASSERT_TRUE(g_lifting != nullptr);

  FuncGraphIndex idx(g_lifting);
  auto g = idx.GetFirstFuncGraph("j");
  ASSERT_TRUE(g != nullptr);

  // Only sizes are read, so bind by const reference instead of copying the parameter lists
  // (if parameters() returns by value, the temporary's lifetime is extended, so this is safe either way).
  const auto& params = g_lifting->parameters();
  const auto& child_params = g->parameters();
  // ASSERT_EQ reports both sizes on failure, unlike ASSERT_TRUE(a + 1 == b).
  ASSERT_EQ(params.size() + 1, child_params.size());
}
142 
// Clones test_clone_scoping and checks which sub-graphs the clone shares with the original:
// "f", "g" and "i" are expected to be new graphs, while "h" is expected to be shared.
TEST_F(TestCloner, test_clone_scoping) {
  std::string py_code = "test_clone_scoping";

  // parse ast to graph
  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  // Fail cleanly rather than handing a null graph to the Cloner.
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);

  auto g2 = cl[g];
  ASSERT_TRUE(g2 != nullptr);

  FuncGraphIndex idx1(g);
  FuncGraphIndex idx2(g2);

  const std::vector<std::string> cloned_names = {"f", "g", "i"};
  for (const auto& name : cloned_names) {
    auto result1 = idx1.GetFirstFuncGraph(name);
    auto result2 = idx2.GetFirstFuncGraph(name);
    ASSERT_FALSE(result1 == result2);
  }

  ASSERT_TRUE(idx1.GetFirstFuncGraph("h") == idx2.GetFirstFuncGraph("h"));
}
169 
// Parses test_clone_total with resolving enabled, then checks:
//  - Cloner(gs, true, true, true): both "clone_total" and "clone_total_sub" are new graphs;
//  - Cloner(gs, true): "clone_total" is a new graph but "clone_total_sub" is shared.
TEST_F(TestCloner, test_clone_total) {
  std::string py_code = "test_clone_total";

  // parse ast to graph
  getPyFun.SetDoResolve();
  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  if (g == nullptr) {
    // Bail out without failing, but say so: otherwise this test passes while checking nothing.
    MS_LOG(WARNING) << "Failed to parse '" << py_code << "'; test_clone_total performed no checks.";
    return;
  }

  FuncGraphIndex idx0(g);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl1(gs, true, true, true);
  auto g2 = cl1[g];
  FuncGraphIndex idx1(g2);

  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total_sub") == idx1.GetFirstFuncGraph("clone_total_sub"));
  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx1.GetFirstFuncGraph("clone_total"));

  Cloner cl2(gs, true);
  FuncGraphIndex idx2(cl2[g]);

  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx2.GetFirstFuncGraph("clone_total"));
  ASSERT_TRUE(idx0.GetFirstFuncGraph("clone_total_sub") == idx2.GetFirstFuncGraph("clone_total_sub"));
}
196 
197 }  // namespace mindspore
198