1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <memory>
18
19 #include "common/common_test.h"
20 #include "common/py_func_graph_fetcher.h"
21
22 #include "ir/manager.h"
23 #include "pipeline/jit/static_analysis/prim.h"
24 #include "pipeline/jit/static_analysis/program_specialize.h"
25 #include "pipeline/static_analysis/helper.h"
26 #include "utils/log_adapter.h"
27 #include "ir/graph_utils.h"
28 #include "utils/misc.h"
29 #include "debug/draw.h"
30 #include "base/core_ops.h"
31
32 namespace mindspore {
33 namespace abstract {
34 class TestSpecializeGraph : public UT::Common {
35 public:
36 void SetUp();
37 void TearDown();
38 // f(x) call g(x)
39 FuncGraphPtr graph_f_;
40 FuncGraphPtr graph_g_;
41 // alpha(x) return beta(x) closure;
42 FuncGraphPtr graph_alpha_;
43 FuncGraphPtr graph_beta_;
44 std::shared_ptr<AnalysisEngine> engine_;
45 std::shared_ptr<ProgramSpecializer> special_;
46 };
47
SetUp()48 void TestSpecializeGraph::SetUp() {
49 UT::InitPythonPath();
50 // init resource
51 engine_ = SetupAnalysisEngine();
52
53 special_ = std::make_shared<ProgramSpecializer>(engine_);
54
55 /*
56 * def g(y):
57 * return y;
58 */
59 graph_g_ = std::make_shared<FuncGraph>();
60 ParameterPtr y = graph_g_->add_parameter();
61 auto prim_return = std::make_shared<Primitive>("Return");
62 std::vector<AnfNodePtr> inputs;
63 inputs.push_back(NewValueNode(prim_return));
64 inputs.push_back(y);
65 CNodePtr cnode_g_ret = graph_g_->NewCNode(inputs);
66 graph_g_->set_return(cnode_g_ret);
67
68 /*
69 * def f(x):
70 * return g(x)
71 */
72 graph_f_ = std::make_shared<FuncGraph>();
73 ParameterPtr x = graph_f_->add_parameter();
74 inputs.clear();
75 inputs.push_back(NewValueNode(graph_g_));
76 inputs.push_back(x);
77 CNodePtr cnode_f = graph_f_->NewCNode(inputs);
78 inputs.clear();
79 inputs.push_back(NewValueNode(prim_return));
80 inputs.push_back(cnode_f);
81 CNodePtr cnode_f_ret = graph_f_->NewCNode(inputs);
82 graph_f_->set_return(cnode_f_ret);
83
84 /* build a closure func_graph */
85 /*
86 *def alpha(x, y):
87 * def beta(x1):
88 * return x1 + y
89 * return beta(x)
90 */
91 graph_alpha_ = std::make_shared<FuncGraph>();
92 graph_beta_ = std::make_shared<FuncGraph>();
93 x = graph_alpha_->add_parameter();
94 y = graph_alpha_->add_parameter();
95
96 // build func_graph beta
97 ParameterPtr x1 = graph_beta_->add_parameter();
98 inputs.clear();
99 inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd)));
100 inputs.push_back(x1);
101 inputs.push_back(y);
102 CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
103 inputs.clear();
104 inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
105 inputs.push_back(cnode_add);
106 CNodePtr cnode_return = graph_beta_->NewCNode(inputs);
107 graph_beta_->set_return(cnode_return);
108
109 // build func_graph alpha
110 inputs.clear();
111 inputs.push_back(NewValueNode(graph_beta_));
112 inputs.push_back(x);
113 CNodePtr cnode_graph_beta_ = graph_alpha_->NewCNode(inputs);
114
115 inputs.clear();
116 inputs.push_back(NewValueNode(prim_return));
117 inputs.push_back(cnode_graph_beta_);
118 cnode_return = graph_alpha_->NewCNode(inputs);
119 graph_alpha_->set_return(cnode_return);
120 }
121
TearDown()122 void TestSpecializeGraph::TearDown() {}
123
// Analyze and specialize f(x) = g(x), i.e. a graph that calls another graph,
// for a single int64 argument.
TEST_F(TestSpecializeGraph, test_specialize) {
  MS_LOG(INFO) << "Begin TestSpecializeGraph call other graph.";
  MS_LOG(INFO) << graph_f_->get_return()->ToString();

  // NOTE(review): unlike the other tests this passes `false` as FromValue's second
  // argument (presumably the "broaden" flag) — confirm this difference is intended.
  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), false));

  AnalysisResult result = engine_->Run(graph_f_, args_spec_list);
  // Fail with a clear message here rather than inside the specializer.
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_f_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
134
// Analyze and specialize the closure graph alpha(x, y) = beta(x), where beta
// captures alpha's `y`, for two int64 arguments.
TEST_F(TestSpecializeGraph, test_specialize1) {
  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), true));
  args_spec_list.push_back(FromValue(static_cast<int64_t>(2), true));

  AnalysisResult result = engine_->Run(graph_alpha_, args_spec_list);
  // Fail with a clear message here rather than inside the specializer.
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_alpha_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
144
145 class TestSpecializeMetaFuncGraph : public UT::Common {
146 public:
147 void SetUp();
148 void TearDown();
149 FuncGraphPtr graph_;
150 std::shared_ptr<AnalysisEngine> engine_;
151 std::shared_ptr<ProgramSpecializer> special_;
152 };
153
154 class MetaScalarAdd : public MetaFuncGraph {
155 public:
MetaScalarAdd(std::string name)156 explicit MetaScalarAdd(std::string name) : MetaFuncGraph(name) {}
157
~MetaScalarAdd()158 ~MetaScalarAdd() {}
159 /*
160 * Generate a Graph for the given abstract arguments.
161 */
GenerateFromTypes(const TypePtrList & types)162 FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override {
163 FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
164 ParameterPtr x = graph_g->add_parameter();
165 ParameterPtr y = graph_g->add_parameter();
166 auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
167 std::vector<AnfNodePtr> inputs;
168 inputs.push_back(NewValueNode(prim_scalar_add));
169 inputs.push_back(x);
170 inputs.push_back(y);
171 CNodePtr cnode_add = graph_g->NewCNode(inputs);
172 auto prim_return = std::make_shared<Primitive>("Return");
173 inputs.clear();
174 inputs.push_back(NewValueNode(prim_return));
175 inputs.push_back(cnode_add);
176 CNodePtr cnode_return = graph_g->NewCNode(inputs);
177 graph_g->set_return(cnode_return);
178 return graph_g;
179 }
180 };
181
SetUp()182 void TestSpecializeMetaFuncGraph::SetUp() {
183 UT::InitPythonPath();
184 // init resource
185 engine_ = SetupAnalysisEngine();
186 special_ = std::make_shared<ProgramSpecializer>(engine_);
187
188 /*
189 * def f(x, y):
190 * return mata_scalar_add(x, y)
191 */
192 graph_ = std::make_shared<FuncGraph>();
193 ParameterPtr x = graph_->add_parameter();
194 ParameterPtr y = graph_->add_parameter();
195 std::shared_ptr<MetaFuncGraph> meta_scalar_add = std::make_shared<MetaScalarAdd>("meta_scalar_add");
196 std::vector<AnfNodePtr> inputs;
197 inputs.push_back(NewValueNode(meta_scalar_add));
198 inputs.push_back(x);
199 inputs.push_back(y);
200 CNodePtr cnode_add = graph_->NewCNode(inputs);
201 auto prim_return = std::make_shared<Primitive>("Return");
202 inputs.clear();
203 inputs.push_back(NewValueNode(prim_return));
204 inputs.push_back(cnode_add);
205 CNodePtr cnode_return = graph_->NewCNode(inputs);
206 graph_->set_return(cnode_return);
207 }
208
TearDown()209 void TestSpecializeMetaFuncGraph::TearDown() {}
210
// Analyze and specialize f(x, y) = meta_scalar_add(x, y) for two int64 arguments,
// which exercises expansion of a MetaFuncGraph call.
TEST_F(TestSpecializeMetaFuncGraph, test_specialize) {
  // Log through MS_LOG like the other tests instead of writing to std::cout.
  MS_LOG(INFO) << graph_->get_return()->ToString();

  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), true));
  args_spec_list.push_back(FromValue(static_cast<int64_t>(2), true));

  AnalysisResult result = engine_->Run(graph_, args_spec_list);
  // Fail with a clear message here rather than inside the specializer.
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
221
222 } // namespace abstract
223 } // namespace mindspore
224