1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17
18 #include "common/common_test.h"
19 #include "common/py_func_graph_fetcher.h"
20
21 #include "ir/manager.h"
22 #include "utils/log_adapter.h"
23 #include "ir/func_graph_cloner.h"
24 #include "pipeline/jit/parse/parse.h"
25 #include "ir/graph_utils.h"
26 #include "debug/draw.h"
27 #include "base/core_ops.h"
28
29 namespace mindspore {
30 class TestCloner : public UT::Common {
31 public:
TestCloner()32 TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
33 one = NewValueNode(static_cast<int64_t>(1));
34 two = NewValueNode(static_cast<int64_t>(2));
35 three = NewValueNode(static_cast<int64_t>(3));
36 }
37
GraphForInline()38 FuncGraphPtr GraphForInline() { return nullptr; }
39 void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params,
40 FuncGraphPtr target);
41
42 public:
43 UT::PyFuncGraphFetcher getPyFun;
44
45 ValueNodePtr one;
46 ValueNodePtr two;
47 ValueNodePtr three;
48 };
49
SuccessfulInlining(const std::shared_ptr<Cloner> cl,FuncGraphPtr orig,const std::vector<AnfNodePtr> & params,FuncGraphPtr target)50 void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
51 const std::vector<AnfNodePtr>& params, FuncGraphPtr target) {
52 auto g = (*cl)[orig];
53 ASSERT_TRUE(g != target);
54 ASSERT_TRUE(g == orig);
55
56 auto new_root = (*cl)[orig->output()];
57 ASSERT_TRUE(new_root != orig->output());
58
59 AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
60 AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));
61
62 for (auto& p : params) {
63 ASSERT_TRUE(new_nodes.contains(p));
64 }
65
66 for (auto& node : orig_nodes) {
67 if (node->func_graph() == orig) {
68 ASSERT_TRUE((*cl)[node]);
69 }
70 }
71 ASSERT_TRUE(target->output() == three);
72 }
73
// Clones the graph of `test_clone_simple` twice and checks node sharing:
// with `Cloner(gs, true)` no node may be shared with the source graph. With
// `Cloner(gs)` only ValueNodes holding ScalarAdd, ScalarMul or Return
// primitives may be shared. (The second Cloner argument presumably controls
// value-node cloning; confirm against the Cloner declaration.)
TEST_F(TestCloner, test_clone_simple) {
  std::string py_code = "test_clone_simple";

  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);
  auto g2 = cl[g];
  ASSERT_TRUE(g2 != nullptr);

  AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return()));
  AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return()));

  auto common = d1 & d2;
  ASSERT_EQ(static_cast<size_t>(0), common.size());

  Cloner cl2(gs);
  auto g3 = cl2[g];
  ASSERT_TRUE(g3 != nullptr);

  std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")};
  AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
  common = d1 & d3;
  for (const auto& x : common) {
    ASSERT_TRUE(x->isa<ValueNode>());
    // cast<PrimitivePtr>() yields null for non-Primitive values; a shared
    // constant of another kind must fail the test, not dereference null.
    auto prim = x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    ASSERT_TRUE(prim != nullptr);
    ASSERT_TRUE(std::find(results.begin(), results.end(), *prim) != results.end());
  }
}
102
// Clones only the inner graph "j" of `test_clone_closure`. Nodes named "x" and
// "y" must stay the same objects, while "z" and "j" itself must be new.
// (Presumably x and y are free variables that j captures from its enclosing
// graph; confirm against gtest_input/ir/clone_test.)
TEST_F(TestCloner, test_clone_closure) {
  std::string py_code = "test_clone_closure";

  // parse ast to graph
  FuncGraphPtr parsed_f = getPyFun(py_code);
  ASSERT_TRUE(parsed_f != nullptr);

  FuncGraphIndex idx(parsed_f);
  auto g = idx.GetFirstFuncGraph("j");
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);

  auto g_clone = cl[g];
  ASSERT_TRUE(g_clone != nullptr);
  FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch);

  std::string name_list = "xy";
  for (const char name : name_list) {
    ASSERT_EQ(idx.GetFirstNode(std::string(1, name)), idx2.GetFirstNode(std::string(1, name)));
  }

  ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z"));
  ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j"));
}
126
// Lifting-clones the graph of `test_clone_closure` and checks that the inner
// graph "j" ends up with exactly one more parameter than the lifted top-level
// graph.
TEST_F(TestCloner, test_clone_lifting) {
  std::string py_code = "test_clone_closure";

  // parse ast to graph
  FuncGraphPtr parsed_f = getPyFun(py_code);
  ASSERT_TRUE(parsed_f != nullptr);

  auto g_lifting = LiftingClone(parsed_f);
  ASSERT_TRUE(g_lifting != nullptr);

  FuncGraphIndex idx(g_lifting);
  auto g = idx.GetFirstFuncGraph("j");
  ASSERT_TRUE(g != nullptr);

  // Bind by reference: only the sizes are compared, no copy is needed.
  const auto& params = g_lifting->parameters();
  const auto& child_params = g->parameters();
  ASSERT_EQ(params.size() + 1, child_params.size());
}
142
// Clones the graph of `test_clone_scoping`. Graphs named "f", "g" and "i" must
// be new objects in the clone, while "h" must be shared with the original.
// (Presumably h does not depend on anything in the cloned scope; confirm
// against gtest_input/ir/clone_test.)
TEST_F(TestCloner, test_clone_scoping) {
  std::string py_code = "test_clone_scoping";

  // parse ast to graph
  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);

  auto g2 = cl[g];
  ASSERT_TRUE(g2 != nullptr);

  FuncGraphIndex idx1(g);
  FuncGraphIndex idx2(g2);

  std::string name_list = "fgi";
  for (const char name : name_list) {
    auto result1 = idx1.GetFirstFuncGraph(std::string(1, name));
    auto result2 = idx2.GetFirstFuncGraph(std::string(1, name));
    ASSERT_FALSE(result1 == result2);
  }

  name_list = "h";
  for (const char name : name_list) {
    ASSERT_TRUE(idx1.GetFirstFuncGraph(std::string(1, name)) == idx2.GetFirstFuncGraph(std::string(1, name)));
  }
}
169
// Parses `test_clone_total` with symbol resolution enabled, then compares two
// clones of it against the original:
//  - Cloner(roots, true, true, true): both "clone_total" and
//    "clone_total_sub" must be new graph objects;
//  - Cloner(roots, true): "clone_total" must be new, but "clone_total_sub"
//    must be the original object.
// If parsing yields no graph, the test returns early without checking anything.
TEST_F(TestCloner, test_clone_total) {
  std::string fn_name = "test_clone_total";
  const std::string total_name = "clone_total";
  const std::string total_sub_name = "clone_total_sub";

  // parse ast to graph, resolving symbols
  getPyFun.SetDoResolve();
  FuncGraphPtr source = getPyFun.CallAndParseRet(fn_name);
  if (!source) {
    return;
  }

  FuncGraphIndex source_index(source);
  std::vector<FuncGraphPtr> roots{source};

  Cloner full_cloner(roots, true, true, true);
  FuncGraphIndex full_index(full_cloner[source]);
  ASSERT_FALSE(source_index.GetFirstFuncGraph(total_sub_name) == full_index.GetFirstFuncGraph(total_sub_name));
  ASSERT_FALSE(source_index.GetFirstFuncGraph(total_name) == full_index.GetFirstFuncGraph(total_name));

  Cloner partial_cloner(roots, true);
  FuncGraphIndex partial_index(partial_cloner[source]);
  ASSERT_FALSE(source_index.GetFirstFuncGraph(total_name) == partial_index.GetFirstFuncGraph(total_name));
  ASSERT_TRUE(source_index.GetFirstFuncGraph(total_sub_name) == partial_index.GetFirstFuncGraph(total_sub_name));
}
196
197 } // namespace mindspore
198