1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/optimizer/py_pass.h"
17 #include <unordered_set>
18 #include <deque>
19 #include <vector>
20
21 #include "ir/func_graph.h"
22 #include "ir/manager.h"
23 #include "pybind_api/ir/primitive_py.h"
24 #include "ir/scalar.h"
25 #include "ir/graph_utils.h"
26 #include "pipeline/jit/parse/parse_base.h"
27 #include "pipeline/jit/resource.h"
28 #include "frontend/optimizer/py_pass_manager.h"
29 #include "utils/info.h"
30
31 namespace mindspore {
32 namespace opt {
33 namespace python_pass {
34 namespace internal {
35 const char PARAMETER_MODULE[] = "mindspore.common.parameter";
36 const char PARAMETER_CLASS[] = "Parameter";
37 const char SET_PARAM[] = "__setattr__";
38 AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
39 const FuncGraphPtr &top_graph);
40 AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
41 const MatchResultPtr &res);
42 void ReflectParamBackToPython(const AnfNodePtr ¶m, const string ¶m_name, const tensor::TensorPtr &default_input,
43 bool requires_grad, bool layerwise_parallel);
44
IsTraversable(const AnfNodePtr & node)45 bool IsTraversable(const AnfNodePtr &node) {
46 if (node == nullptr) {
47 return false;
48 }
49 if (node->isa<CNode>() || node->isa<Parameter>()) {
50 return true;
51 }
52 if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
53 return true;
54 }
55 return false;
56 }
57
BuildPrimitive(const PatternPtr & pattern)58 AnfNodePtr BuildPrimitive(const PatternPtr &pattern) {
59 // Build up AnfNode from primitive
60 auto prim_pattern = pattern->cast<PrimPtr>();
61 MS_EXCEPTION_IF_NULL(prim_pattern);
62 PrimitivePyPtr prim = prim_pattern->matched_primitive();
63 MS_EXCEPTION_IF_NULL(prim);
64 // Make value node out of primitives
65 return std::make_shared<ValueNode>(prim);
66 }
67
BuildNewTensor(const PatternPtr & pattern)68 AnfNodePtr BuildNewTensor(const PatternPtr &pattern) {
69 // Build a ValueNode from TensorPtr
70 auto new_tensor_pattern = pattern->cast<NewTensorPtr>();
71 MS_EXCEPTION_IF_NULL(new_tensor_pattern);
72 auto input_tensor = new_tensor_pattern->input_tensor();
73 MS_EXCEPTION_IF_NULL(input_tensor);
74 return std::make_shared<ValueNode>(input_tensor);
75 }
76
BuildPrimitiveValueNode(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & fg,const FuncGraphPtr & top_graph)77 AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
78 const FuncGraphPtr &top_graph) {
79 auto call_pattern = pattern->cast<CallPtr>();
80 MS_EXCEPTION_IF_NULL(call_pattern);
81 auto prim = call_pattern->prim_value();
82 if (prim != nullptr) {
83 return std::make_shared<ValueNode>(prim);
84 }
85 auto prim_pattern = call_pattern->prim_pattern();
86 MS_EXCEPTION_IF_NULL(prim_pattern);
87 return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
88 }
89
BuildNewParameter(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & top_graph)90 AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
91 auto new_para_pattern = pattern->cast<NewParameterPtr>();
92 MS_EXCEPTION_IF_NULL(new_para_pattern);
93 if (!new_para_pattern->built()) {
94 static int64_t parameter_id = 0;
95 auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
96 auto para_node = std::make_shared<Parameter>(top_graph);
97 MS_EXCEPTION_IF_NULL(para_node);
98 para_node->set_name(para_name);
99 // Set function graph
100 para_node->set_func_graph(top_graph);
101 // Set Debug Info
102 auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
103 para_node->set_debug_info(debug_info);
104 // Set abstract
105 auto default_value = new_para_pattern->default_tensor();
106 MS_EXCEPTION_IF_NULL(default_value);
107 para_node->set_abstract(default_value->ToAbstract()->Broaden());
108 res->add_entry(pattern, para_node);
109 top_graph->add_parameter(para_node);
110 // Reflect back to Cell._params
111 internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
112 new_para_pattern->layerwise_parallel());
113 MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
114 new_para_pattern->set_built(true);
115 return para_node;
116 } else {
117 // Built, fetch the node
118 auto para_node = res->get_node(pattern);
119 MS_EXCEPTION_IF_NULL(para_node);
120 return para_node;
121 }
122 }
123
BuildImmNode(const PatternPtr & pattern)124 AnfNodePtr BuildImmNode(const PatternPtr &pattern) {
125 auto imm_pattern = pattern->cast<ImmPtr>();
126 MS_EXCEPTION_IF_NULL(imm_pattern);
127 auto value = imm_pattern->value();
128 auto scalar_value_ptr = std::make_shared<Int64Imm>(value);
129 return std::make_shared<ValueNode>(scalar_value_ptr);
130 }
131
ProcessSinglePattern(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph)132 AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
133 const FuncGraphPtr &top_graph) {
134 auto target_node = res->get_node(pattern);
135 if (target_node != nullptr) {
136 // If pattern is NewParameter, check whether it shouldn't last and is not built
137 auto new_para = pattern->cast<NewParameterPtr>();
138 if (new_para == nullptr || new_para->should_last() || new_para->built()) {
139 return target_node;
140 }
141 }
142 // Build up new node from pattern
143 if (pattern->isa<Prim>()) {
144 return BuildPrimitive(pattern);
145 } else if (pattern->isa<NewTensor>()) {
146 return BuildNewTensor(pattern);
147 } else if (pattern->isa<Call>()) {
148 return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
149 } else if (pattern->isa<NewParameter>()) {
150 // Add new parameter to top graph instead of current graph
151 return BuildNewParameter(pattern, res, top_graph);
152 } else if (pattern->isa<Imm>()) {
153 return BuildImmNode(pattern);
154 } else {
155 MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
156 return nullptr;
157 }
158 return nullptr;
159 }
160
ProcessComplexPatternFirstInput(const PatternPtr & pattern,const MatchResultPtr & res,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph)161 AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
162 const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
163 if (pattern->isa<Call>()) {
164 return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
165 }
166 return nullptr;
167 }
168
BuildTarget(const PatternPtr & pattern,const FuncGraphPtr & func_graph,const FuncGraphPtr & top_graph,const MatchResultPtr & res)169 AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
170 const MatchResultPtr &res) {
171 auto target_inputs = pattern->inputs();
172 if (target_inputs.size() == 0) {
173 auto new_anf_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
174 if (new_anf_node != nullptr) {
175 res->add_entry(pattern, new_anf_node);
176 }
177 return new_anf_node;
178 }
179 // Build up the AnfNode in a recursive manner
180 std::vector<AnfNodePtr> new_inputs;
181 auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
182 MS_EXCEPTION_IF_NULL(prim_value_node);
183 new_inputs.push_back(prim_value_node);
184 for (auto &iter : target_inputs) {
185 if (iter == pattern) {
186 MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
187 }
188 auto input_node = BuildTarget(iter, func_graph, top_graph, res);
189 if (input_node == nullptr) {
190 MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
191 }
192 new_inputs.push_back(input_node);
193 }
194 auto new_c_node = func_graph->NewCNode(new_inputs);
195 res->add_entry(pattern, new_c_node);
196 return new_c_node;
197 }
198
ReflectParamBackToPython(const AnfNodePtr & param,const string & param_name,const tensor::TensorPtr & default_input,bool requires_grad,bool layerwise_parallel)199 void ReflectParamBackToPython(const AnfNodePtr ¶m, const string ¶m_name, const tensor::TensorPtr &default_input,
200 bool requires_grad, bool layerwise_parallel) {
201 // 1. Get current cell object
202 auto ppm = opt::python_pass::PyPassManager::GetInstance();
203 auto resource = ppm->GetResource();
204 py::object top_cell = resource->source_input();
205 if (py::isinstance<py::none>(top_cell)) {
206 MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
207 }
208 // 2. Clone default_input tensor
209 MS_EXCEPTION_IF_NULL(default_input);
210 auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(),
211 default_input->data_c(), (size_t)default_input->Size());
212 // 3. New a Parameter object with the above-specified args
213 py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
214 py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel);
215 // 4. Add the new python Parameter object to Cell's _params attributes
216 top_cell.attr(SET_PARAM)(param_name, new_parameter);
217 // 5. Set default_param for param_node
218 ValuePtr param_value = nullptr;
219 bool converted = parse::ConvertData(new_parameter, ¶m_value, false);
220 if (!converted) {
221 MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
222 }
223 MS_EXCEPTION_IF_NULL(param);
224 auto param_node = param->cast<ParameterPtr>();
225 MS_EXCEPTION_IF_NULL(param_node);
226 param_node->set_default_param(param_value);
227 }
228
Reset(const PatternPtr & pattern)229 void Reset(const PatternPtr &pattern) {
230 if (pattern->isa<Prim>()) {
231 auto prim_pattern = pattern->cast<PrimPtr>();
232 prim_pattern->reset();
233 } else if (pattern->isa<NewParameter>()) {
234 auto new_param_pattern = pattern->cast<NewParameterPtr>();
235 new_param_pattern->reset();
236 } else if (pattern->isa<Call>()) {
237 auto call_with_pattern = pattern->cast<CallPtr>();
238 for (const auto &sub_pattern : call_with_pattern->inputs()) {
239 Reset(sub_pattern);
240 }
241 }
242 }
243 } // namespace internal
244
// Tries to match the source pattern at `node`.
//
// On a match:
//   - merges the match into `res`;
//   - builds the destination pattern in `func_graph` (new parameters go to `top_graph`);
//   - resets the destination pattern and returns the built node.
// Without a match, resets the source pattern and returns nullptr.
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
                           const MatchResultPtr &res) {
  auto match_res = src_pattern_->match(node);
  if (match_res == nullptr) {
    internal::Reset(src_pattern());
    return nullptr;
  }
  res->merge(match_res);
  auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
  internal::Reset(dst_pattern());
  return new_node;
}
257
Run(const FuncGraphPtr & func_graph,const MatchResultPtr & res)258 bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
259 MS_EXCEPTION_IF_NULL(func_graph);
260 MS_EXCEPTION_IF_NULL(dst_pattern_);
261 if (src_pattern_ == nullptr) {
262 // Add NewParameter
263 auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
264 if (new_para_pattern == nullptr) {
265 MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
266 }
267 auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
268 auto para_node = std::make_shared<Parameter>(func_graph);
269 MS_EXCEPTION_IF_NULL(para_node);
270 para_node->set_name(para_name);
271 // Set function graph
272 para_node->set_func_graph(func_graph);
273 // Set Debug Info
274 auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
275 para_node->set_debug_info(debug_info);
276 // Set abstract
277 auto default_value = new_para_pattern->default_tensor();
278 MS_EXCEPTION_IF_NULL(default_value);
279 para_node->set_abstract(default_value->ToAbstract()->Broaden());
280 res->add_entry(dst_pattern_, para_node);
281 func_graph->add_parameter(para_node);
282 // Reflect back to Cell._params
283 internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
284 new_para_pattern->layerwise_parallel());
285 MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
286 return true;
287 }
288 FuncGraphManagerPtr manager = func_graph->manager();
289 MS_EXCEPTION_IF_NULL(manager);
290 auto func_graphs = manager->func_graphs();
291 bool changes = false;
292 for (auto &fg : func_graphs) {
293 manager->AddFuncGraph(fg);
294 auto graph_nodes_sorted = TopoSort(fg->output());
295 // Traverse once
296 for (auto &node : graph_nodes_sorted) {
297 AnfNodePtr new_node = Run(fg, func_graph, node, res);
298 if (new_node != nullptr && new_node != node) {
299 MS_LOG(WARNING) << "Matched";
300 (void)manager->Replace(node, new_node);
301 changes = true;
302 }
303 }
304 }
305 return changes;
306 }
307 } // namespace python_pass
308 } // namespace opt
309 } // namespace mindspore
310