• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/optimizer/py_pass.h"
17 #include <unordered_set>
18 #include <deque>
19 #include <vector>
20 
21 #include "ir/func_graph.h"
22 #include "ir/manager.h"
23 #include "pybind_api/ir/primitive_py.h"
24 #include "ir/scalar.h"
25 #include "ir/graph_utils.h"
26 #include "pipeline/jit/parse/parse_base.h"
27 #include "pipeline/jit/resource.h"
28 #include "frontend/optimizer/py_pass_manager.h"
29 #include "utils/info.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace python_pass {
34 namespace internal {
// Python module and class used to construct a new mindspore Parameter object from C++.
constexpr char PARAMETER_MODULE[] = "mindspore.common.parameter";
constexpr char PARAMETER_CLASS[] = "Parameter";
// Method invoked on the top cell to attach a newly created Parameter as an attribute.
constexpr char SET_PARAM[] = "__setattr__";
38 AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
39                                 const FuncGraphPtr &top_graph);
40 AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
41                        const MatchResultPtr &res);
42 void ReflectParamBackToPython(const AnfNodePtr &param, const string &param_name, const tensor::TensorPtr &default_input,
43                               bool requires_grad, bool layerwise_parallel);
44 
IsTraversable(const AnfNodePtr & node)45 bool IsTraversable(const AnfNodePtr &node) {
46   if (node == nullptr) {
47     return false;
48   }
49   if (node->isa<CNode>() || node->isa<Parameter>()) {
50     return true;
51   }
52   if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
53     return true;
54   }
55   return false;
56 }
57 
BuildPrimitive(const PatternPtr & pattern)58 AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
59   // Build up AnfNode from primitive
60   auto prim_pattern = pattern->cast<PrimPtr>();
61   MS_EXCEPTION_IF_NULL(prim_pattern);
62   PrimitivePyPtr prim = prim_pattern->matched_primitive();
63   MS_EXCEPTION_IF_NULL(prim);
64   // Make value node out of primitives
65   return std::make_shared<ValueNode>(prim);
66 }
67 
BuildNewTensor(const PatternPtr & pattern)68 AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
69   // Build a ValueNode from TensorPtr
70   auto new_tensor_pattern = pattern->cast<NewTensorPtr>();
71   MS_EXCEPTION_IF_NULL(new_tensor_pattern);
72   auto input_tensor = new_tensor_pattern->input_tensor();
73   MS_EXCEPTION_IF_NULL(input_tensor);
74   return std::make_shared<ValueNode>(input_tensor);
75 }
76 
BuildPrimitiveValueNode(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & fg,const FuncGraphPtr & top_graph)77 AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
78                                    const FuncGraphPtr &top_graph) {
79   auto call_pattern = pattern->cast<CallPtr>();
80   MS_EXCEPTION_IF_NULL(call_pattern);
81   auto prim = call_pattern->prim_value();
82   if (prim != nullptr) {
83     return std::make_shared<ValueNode>(prim);
84   }
85   auto prim_pattern = call_pattern->prim_pattern();
86   MS_EXCEPTION_IF_NULL(prim_pattern);
87   return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
88 }
89 
BuildNewParameter(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & top_graph)90 AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
91   auto new_para_pattern = pattern->cast<NewParameterPtr>();
92   MS_EXCEPTION_IF_NULL(new_para_pattern);
93   if (!new_para_pattern->built()) {
94     static int64_t parameter_id = 0;
95     auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
96     auto para_node = std::make_shared<Parameter>(top_graph);
97     MS_EXCEPTION_IF_NULL(para_node);
98     para_node->set_name(para_name);
99     // Set function graph
100     para_node->set_func_graph(top_graph);
101     // Set Debug Info
102     auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
103     para_node->set_debug_info(debug_info);
104     // Set abstract
105     auto default_value = new_para_pattern->default_tensor();
106     MS_EXCEPTION_IF_NULL(default_value);
107     para_node->set_abstract(default_value->ToAbstract()->Broaden());
108     res->add_entry(pattern, para_node);
109     top_graph->add_parameter(para_node);
110     // Reflect back to Cell._params
111     internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
112                                        new_para_pattern->layerwise_parallel());
113     MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
114     new_para_pattern->set_built(true);
115     return para_node;
116   } else {
117     // Built, fetch the node
118     auto para_node = res->get_node(pattern);
119     MS_EXCEPTION_IF_NULL(para_node);
120     return para_node;
121   }
122 }
123 
BuildImmNode(const PatternPtr & pattern)124 AnfNodePtr BuildImmNode(const PatternPtr &pattern) {
125   auto imm_pattern = pattern->cast<ImmPtr>();
126   MS_EXCEPTION_IF_NULL(imm_pattern);
127   auto value = imm_pattern->value();
128   auto scalar_value_ptr = std::make_shared<Int64Imm>(value);
129   return std::make_shared<ValueNode>(scalar_value_ptr);
130 }
131 
ProcessSinglePattern(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph)132 AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
133                                 const FuncGraphPtr &top_graph) {
134   auto target_node = res->get_node(pattern);
135   if (target_node != nullptr) {
136     // If pattern is NewParameter, check whether it shouldn't last and is not built
137     auto new_para = pattern->cast<NewParameterPtr>();
138     if (new_para == nullptr || new_para->should_last() || new_para->built()) {
139       return target_node;
140     }
141   }
142   // Build up new node from pattern
143   if (pattern->isa<Prim>()) {
144     return BuildPrimitive(pattern);
145   } else if (pattern->isa<NewTensor>()) {
146     return BuildNewTensor(pattern);
147   } else if (pattern->isa<Call>()) {
148     return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
149   } else if (pattern->isa<NewParameter>()) {
150     // Add new parameter to top graph instead of current graph
151     return BuildNewParameter(pattern, res, top_graph);
152   } else if (pattern->isa<Imm>()) {
153     return BuildImmNode(pattern);
154   } else {
155     MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
156     return nullptr;
157   }
158   return nullptr;
159 }
160 
ProcessComplexPatternFirstInput(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph)161 AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
162                                            const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
163   if (pattern->isa<Call>()) {
164     return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
165   }
166   return nullptr;
167 }
168 
BuildTarget(const PatternPtr & pattern,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph,const MatchResultPtr & res)169 AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
170                        const MatchResultPtr &res) {
171   auto target_inputs = pattern->inputs();
172   if (target_inputs.size() == 0) {
173     auto new_anf_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
174     if (new_anf_node != nullptr) {
175       res->add_entry(pattern, new_anf_node);
176     }
177     return new_anf_node;
178   }
179   // Build up the AnfNode in a recursive manner
180   std::vector<AnfNodePtr> new_inputs;
181   auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
182   MS_EXCEPTION_IF_NULL(prim_value_node);
183   new_inputs.push_back(prim_value_node);
184   for (auto &iter : target_inputs) {
185     if (iter == pattern) {
186       MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
187     }
188     auto input_node = BuildTarget(iter, func_graph, top_graph, res);
189     if (input_node == nullptr) {
190       MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
191     }
192     new_inputs.push_back(input_node);
193   }
194   auto new_c_node = func_graph->NewCNode(new_inputs);
195   res->add_entry(pattern, new_c_node);
196   return new_c_node;
197 }
198 
ReflectParamBackToPython(const AnfNodePtr & param,const string & param_name,const tensor::TensorPtr & default_input,bool requires_grad,bool layerwise_parallel)199 void ReflectParamBackToPython(const AnfNodePtr &param, const string &param_name, const tensor::TensorPtr &default_input,
200                               bool requires_grad, bool layerwise_parallel) {
201   // 1. Get current cell object
202   auto ppm = opt::python_pass::PyPassManager::GetInstance();
203   auto resource = ppm->GetResource();
204   py::object top_cell = resource->source_input();
205   if (py::isinstance<py::none>(top_cell)) {
206     MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
207   }
208   // 2. Clone default_input tensor
209   MS_EXCEPTION_IF_NULL(default_input);
210   auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(),
211                                                          default_input->data_c(), (size_t)default_input->Size());
212   // 3. New a Parameter object with the above-specified args
213   py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
214   py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel);
215   // 4. Add the new python Parameter object to Cell's _params attributes
216   top_cell.attr(SET_PARAM)(param_name, new_parameter);
217   // 5. Set default_param for param_node
218   ValuePtr param_value = nullptr;
219   bool converted = parse::ConvertData(new_parameter, &param_value, false);
220   if (!converted) {
221     MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
222   }
223   MS_EXCEPTION_IF_NULL(param);
224   auto param_node = param->cast<ParameterPtr>();
225   MS_EXCEPTION_IF_NULL(param_node);
226   param_node->set_default_param(param_value);
227 }
228 
Reset(const PatternPtr & pattern)229 void Reset(const PatternPtr &pattern) {
230   if (pattern->isa<Prim>()) {
231     auto prim_pattern = pattern->cast<PrimPtr>();
232     prim_pattern->reset();
233   } else if (pattern->isa<NewParameter>()) {
234     auto new_param_pattern = pattern->cast<NewParameterPtr>();
235     new_param_pattern->reset();
236   } else if (pattern->isa<Call>()) {
237     auto call_with_pattern = pattern->cast<CallPtr>();
238     for (const auto &sub_pattern : call_with_pattern->inputs()) {
239       Reset(sub_pattern);
240     }
241   }
242 }
243 }  // namespace internal
244 
// Tries to rewrite a single node.
// If the source pattern does not match `node`, the source pattern is reset and nullptr is returned.
// On a match, the match result is merged into `res`, the destination pattern is built in func_graph
// (new parameters go to top_graph), and the destination pattern is reset. The built node is returned.
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
                           const MatchResultPtr &res) {
  auto match_res = src_pattern_->match(node);
  if (match_res == nullptr) {
    internal::Reset(src_pattern());
    return nullptr;
  }
  res->merge(match_res);
  auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
  internal::Reset(dst_pattern());
  return new_node;
}
257 
Run(const FuncGraphPtr & func_graph,const MatchResultPtr & res)258 bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
259   MS_EXCEPTION_IF_NULL(func_graph);
260   MS_EXCEPTION_IF_NULL(dst_pattern_);
261   if (src_pattern_ == nullptr) {
262     // Add NewParameter
263     auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
264     if (new_para_pattern == nullptr) {
265       MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
266     }
267     auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
268     auto para_node = std::make_shared<Parameter>(func_graph);
269     MS_EXCEPTION_IF_NULL(para_node);
270     para_node->set_name(para_name);
271     // Set function graph
272     para_node->set_func_graph(func_graph);
273     // Set Debug Info
274     auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
275     para_node->set_debug_info(debug_info);
276     // Set abstract
277     auto default_value = new_para_pattern->default_tensor();
278     MS_EXCEPTION_IF_NULL(default_value);
279     para_node->set_abstract(default_value->ToAbstract()->Broaden());
280     res->add_entry(dst_pattern_, para_node);
281     func_graph->add_parameter(para_node);
282     // Reflect back to Cell._params
283     internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
284                                        new_para_pattern->layerwise_parallel());
285     MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
286     return true;
287   }
288   FuncGraphManagerPtr manager = func_graph->manager();
289   MS_EXCEPTION_IF_NULL(manager);
290   auto func_graphs = manager->func_graphs();
291   bool changes = false;
292   for (auto &fg : func_graphs) {
293     manager->AddFuncGraph(fg);
294     auto graph_nodes_sorted = TopoSort(fg->output());
295     // Traverse once
296     for (auto &node : graph_nodes_sorted) {
297       AnfNodePtr new_node = Run(fg, func_graph, node, res);
298       if (new_node != nullptr && new_node != node) {
299         MS_LOG(WARNING) << "Matched";
300         (void)manager->Replace(node, new_node);
301         changes = true;
302       }
303     }
304   }
305   return changes;
306 }
307 }  // namespace python_pass
308 }  // namespace opt
309 }  // namespace mindspore
310