• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "common/common_test.h"
20 #include "common/py_func_graph_fetcher.h"
21 
22 #include "ir/manager.h"
23 #include "pipeline/jit/static_analysis/prim.h"
24 #include "pipeline/jit/static_analysis/program_specialize.h"
25 #include "pipeline/static_analysis/helper.h"
26 #include "utils/log_adapter.h"
27 #include "ir/graph_utils.h"
28 #include "utils/misc.h"
29 #include "debug/draw.h"
30 #include "base/core_ops.h"
31 
32 namespace mindspore {
33 namespace abstract {
34 class TestSpecializeGraph : public UT::Common {
35  public:
36   void SetUp();
37   void TearDown();
38   // f(x) call g(x)
39   FuncGraphPtr graph_f_;
40   FuncGraphPtr graph_g_;
41   // alpha(x) return beta(x) closure;
42   FuncGraphPtr graph_alpha_;
43   FuncGraphPtr graph_beta_;
44   std::shared_ptr<AnalysisEngine> engine_;
45   std::shared_ptr<ProgramSpecializer> special_;
46 };
47 
SetUp()48 void TestSpecializeGraph::SetUp() {
49   UT::InitPythonPath();
50   // init resource
51   engine_ = SetupAnalysisEngine();
52 
53   special_ = std::make_shared<ProgramSpecializer>(engine_);
54 
55   /*
56    * def g(y):
57    *   return y;
58    */
59   graph_g_ = std::make_shared<FuncGraph>();
60   ParameterPtr y = graph_g_->add_parameter();
61   auto prim_return = std::make_shared<Primitive>("Return");
62   std::vector<AnfNodePtr> inputs;
63   inputs.push_back(NewValueNode(prim_return));
64   inputs.push_back(y);
65   CNodePtr cnode_g_ret = graph_g_->NewCNode(inputs);
66   graph_g_->set_return(cnode_g_ret);
67 
68   /*
69    * def f(x):
70    *   return g(x)
71    */
72   graph_f_ = std::make_shared<FuncGraph>();
73   ParameterPtr x = graph_f_->add_parameter();
74   inputs.clear();
75   inputs.push_back(NewValueNode(graph_g_));
76   inputs.push_back(x);
77   CNodePtr cnode_f = graph_f_->NewCNode(inputs);
78   inputs.clear();
79   inputs.push_back(NewValueNode(prim_return));
80   inputs.push_back(cnode_f);
81   CNodePtr cnode_f_ret = graph_f_->NewCNode(inputs);
82   graph_f_->set_return(cnode_f_ret);
83 
84   /* build a closure func_graph */
85   /*
86    *def alpha(x, y):
87    *    def beta(x1):
88    *         return x1 + y
89    *    return beta(x)
90    */
91   graph_alpha_ = std::make_shared<FuncGraph>();
92   graph_beta_ = std::make_shared<FuncGraph>();
93   x = graph_alpha_->add_parameter();
94   y = graph_alpha_->add_parameter();
95 
96   // build func_graph beta
97   ParameterPtr x1 = graph_beta_->add_parameter();
98   inputs.clear();
99   inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd)));
100   inputs.push_back(x1);
101   inputs.push_back(y);
102   CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
103   inputs.clear();
104   inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
105   inputs.push_back(cnode_add);
106   CNodePtr cnode_return = graph_beta_->NewCNode(inputs);
107   graph_beta_->set_return(cnode_return);
108 
109   // build func_graph alpha
110   inputs.clear();
111   inputs.push_back(NewValueNode(graph_beta_));
112   inputs.push_back(x);
113   CNodePtr cnode_graph_beta_ = graph_alpha_->NewCNode(inputs);
114 
115   inputs.clear();
116   inputs.push_back(NewValueNode(prim_return));
117   inputs.push_back(cnode_graph_beta_);
118   cnode_return = graph_alpha_->NewCNode(inputs);
119   graph_alpha_->set_return(cnode_return);
120 }
121 
TearDown()122 void TestSpecializeGraph::TearDown() {}
123 
// Analyzes f(x) = g(x) with an int64 constant argument, specializes it, and
// checks that both the analysis context and the specialized graph exist.
TEST_F(TestSpecializeGraph, test_specialize) {
  MS_LOG(INFO) << "Begin TestSpecializeGraph call other graph.";
  MS_LOG(INFO) << graph_f_->get_return()->ToString();

  // NOTE(review): second FromValue argument is false here but true in the other
  // cases; presumably it is the "broaden" flag — confirm against helper.h.
  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), false));

  AnalysisResult result = engine_->Run(graph_f_, args_spec_list);
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_f_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
134 
// Analyzes the closure graph alpha(x, y) = beta(x) with two int64 constants,
// specializes it, and checks that the context and specialized graph exist.
TEST_F(TestSpecializeGraph, test_specialize1) {
  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), true));
  args_spec_list.push_back(FromValue(static_cast<int64_t>(2), true));

  AnalysisResult result = engine_->Run(graph_alpha_, args_spec_list);
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_alpha_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
144 
145 class TestSpecializeMetaFuncGraph : public UT::Common {
146  public:
147   void SetUp();
148   void TearDown();
149   FuncGraphPtr graph_;
150   std::shared_ptr<AnalysisEngine> engine_;
151   std::shared_ptr<ProgramSpecializer> special_;
152 };
153 
154 class MetaScalarAdd : public MetaFuncGraph {
155  public:
MetaScalarAdd(std::string name)156   explicit MetaScalarAdd(std::string name) : MetaFuncGraph(name) {}
157 
~MetaScalarAdd()158   ~MetaScalarAdd() {}
159   /*
160    * Generate a Graph for the given abstract arguments.
161    */
GenerateFromTypes(const TypePtrList & types)162   FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override {
163     FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
164     ParameterPtr x = graph_g->add_parameter();
165     ParameterPtr y = graph_g->add_parameter();
166     auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
167     std::vector<AnfNodePtr> inputs;
168     inputs.push_back(NewValueNode(prim_scalar_add));
169     inputs.push_back(x);
170     inputs.push_back(y);
171     CNodePtr cnode_add = graph_g->NewCNode(inputs);
172     auto prim_return = std::make_shared<Primitive>("Return");
173     inputs.clear();
174     inputs.push_back(NewValueNode(prim_return));
175     inputs.push_back(cnode_add);
176     CNodePtr cnode_return = graph_g->NewCNode(inputs);
177     graph_g->set_return(cnode_return);
178     return graph_g;
179   }
180 };
181 
SetUp()182 void TestSpecializeMetaFuncGraph::SetUp() {
183   UT::InitPythonPath();
184   // init resource
185   engine_ = SetupAnalysisEngine();
186   special_ = std::make_shared<ProgramSpecializer>(engine_);
187 
188   /*
189    * def f(x, y):
190    *   return mata_scalar_add(x, y)
191    */
192   graph_ = std::make_shared<FuncGraph>();
193   ParameterPtr x = graph_->add_parameter();
194   ParameterPtr y = graph_->add_parameter();
195   std::shared_ptr<MetaFuncGraph> meta_scalar_add = std::make_shared<MetaScalarAdd>("meta_scalar_add");
196   std::vector<AnfNodePtr> inputs;
197   inputs.push_back(NewValueNode(meta_scalar_add));
198   inputs.push_back(x);
199   inputs.push_back(y);
200   CNodePtr cnode_add = graph_->NewCNode(inputs);
201   auto prim_return = std::make_shared<Primitive>("Return");
202   inputs.clear();
203   inputs.push_back(NewValueNode(prim_return));
204   inputs.push_back(cnode_add);
205   CNodePtr cnode_return = graph_->NewCNode(inputs);
206   graph_->set_return(cnode_return);
207 }
208 
TearDown()209 void TestSpecializeMetaFuncGraph::TearDown() {}
210 
// Analyzes f(x, y) = meta_scalar_add(x, y) with two int64 constants,
// specializes it, and checks that the context and specialized graph exist.
TEST_F(TestSpecializeMetaFuncGraph, test_specialize) {
  // Logged through MS_LOG like the other tests in this file (was std::cout + std::endl).
  MS_LOG(INFO) << graph_->get_return()->ToString();

  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(FromValue(static_cast<int64_t>(1), true));
  args_spec_list.push_back(FromValue(static_cast<int64_t>(2), true));

  AnalysisResult result = engine_->Run(graph_, args_spec_list);
  ASSERT_TRUE(result.context != nullptr);
  FuncGraphPtr new_graph = special_->Run(graph_, result.context);
  ASSERT_TRUE(new_graph != nullptr);
}
221 
222 }  // namespace abstract
223 }  // namespace mindspore
224