• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/parse/resolve.h"
18 
19 #include <string>
20 #include <memory>
21 #include <vector>
22 
23 #include "ir/param_info.h"
24 #include "pipeline/jit/parse/data_converter.h"
25 #include "pipeline/jit/parse/parse.h"
26 #include "pipeline/jit/parse/python_adapter.h"
27 #include "utils/any.h"
28 #include "frontend/operator/ops.h"
29 #include "frontend/optimizer/opt.h"
30 #include "frontend/optimizer/irpass.h"
31 #include "frontend/optimizer/irpass/symbol_resolver.h"
32 
33 namespace mindspore {
34 namespace parse {
ToAbstract()35 abstract::AbstractBasePtr ClassObject::ToAbstract() {
36   ClassPtr cls_ptr = ParseDataClass(obj());
37   auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
38   abs_scalar->set_type(std::make_shared<TypeType>());
39   abs_scalar->set_value(cls_ptr);
40 
41   AbstractBasePtrList args_spec_list = {abs_scalar};
42   auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeRecord);
43   return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
44 }
45 
IsSupportedCreateInstanceType(const py::object & obj)46 static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
47   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
48   auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
49   if (!py::isinstance<py::bool_>(res)) {
50     MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
51     return false;
52   }
53   return res.cast<bool>();
54 }
55 
ToAbstract()56 abstract::AbstractBasePtr ClassType::ToAbstract() {
57   auto abs_scalar =
58     std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
59 
60   // The fallback feature is enabled in default.
61   // Not support change the flag during the process is alive.
62   static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
63   static const auto use_fallback = (support_fallback == "1");
64   if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
65     return abs_scalar;
66   }
67   AbstractBasePtrList args_spec_list = {abs_scalar};
68 
69   auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
70   auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
71   ret_val->set_value_desc(ToString());
72   return ret_val;
73 }
74 
75 // call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace
Resolve()76 bool SymbolResolver::Resolve() {
77   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
78 
79   py::object obj = namespace_->obj();
80   std::string symbol = symbol_->symbol();
81   if (py::isinstance<py::none>(obj)) {
82     MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
83   }
84   result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol));
85   return true;
86 }
87 
88 namespace {
89 // If any mixed precision flag add a cast node after the parameter node.
90 // argument obj should be python Parameter object
91 // it will be converted to Parameter node here
ResolveParameterObj(const FuncGraphPtr & func_graph,const py::object & obj)92 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
93   MS_EXCEPTION_IF_NULL(func_graph);
94 
95   // Parameter object should not be none
96   if (py::isinstance<py::none>(obj)) {
97     MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
98   }
99 
100   if (!py::hasattr(obj, "name")) {
101     MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
102   }
103 
104   // Get the parameter name from parameter object
105   auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
106   if (py::isinstance<py::none>(name_attr)) {
107     MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
108   }
109 
110   auto param_name = py::cast<std::string>(name_attr);
111   auto top_func_graph = Parser::GetTopFuncGraph();
112   // If the parameter node has been created , return it.
113   AnfNodePtr para_node = nullptr;
114   for (auto const &param : top_func_graph->parameters()) {
115     auto param_node = dyn_cast<Parameter>(param);
116     if (param_node != nullptr && param_node->name() == param_name) {
117       para_node = param;
118       MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
119                     << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
120       break;
121     }
122   }
123   if (para_node == nullptr) {
124     auto node = top_func_graph->AddWeightParameter(param_name);
125     auto value = py::cast<tensor::MetaTensorPtr>(obj);
126     node->set_default_param(value);
127     // Set abstract for parameter
128     auto abs = value->ToAbstract();
129     node->set_abstract(abs);
130     para_node = node;
131     MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
132                   << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
133   }
134   func_graph->add_parameter_obj_node(para_node);
135 
136   return para_node;
137 }
138 
BroadenCNodeAbstract(const FuncGraphPtr & func_graph)139 void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
140   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
141   for (const AnfNodePtr &node : nodes) {
142     if (!node->isa<CNode>()) {
143       continue;
144     }
145     auto abstract = node->abstract();
146     if (abstract != nullptr) {
147       node->set_abstract(abstract->Broaden());
148     }
149   }
150 }
151 
ConvertLoadedGraph(const FuncGraphPtr & func_graph,const ValuePtr & value)152 void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
153   if (!value->isa<FuncGraph>()) {
154     return;
155   }
156   auto resolved_graph = value->cast<FuncGraphPtr>();
157   MS_EXCEPTION_IF_NULL(resolved_graph);
158   if (!resolved_graph->has_attr("is_load")) {
159     return;
160   }
161   auto top_graph = Parser::GetTopFuncGraph();
162   std::vector<AnfNodePtr> input_params;
163   for (auto const &param : resolved_graph->parameters()) {
164     auto param_ptr = dyn_cast<Parameter>(param);
165     MS_EXCEPTION_IF_NULL(param_ptr);
166     if (param_ptr->has_default()) {
167       param_ptr->set_func_graph(top_graph);
168       func_graph->add_parameter_obj_node(param_ptr);
169 
170       // Update top_graph
171       top_graph->add_parameter(param_ptr);
172       size_t hyper_param_count = top_graph->hyper_param_count();
173       top_graph->set_hyper_param_count(hyper_param_count + 1);
174     } else {
175       input_params.push_back(param_ptr);
176     }
177   }
178   resolved_graph->set_parameters(input_params);
179   BroadenCNodeAbstract(resolved_graph);
180 }
181 
ResolveObjectToNode(const FuncGraphPtr & func_graph,const py::object & obj,AnfNodePtr * const node)182 bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
183   AnfNodePtr output = nullptr;
184   if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
185     auto param = ResolveParameterObj(func_graph, obj);
186     if (param == nullptr) {
187       MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
188       return false;
189     }
190     MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
191     output = param;
192   } else if (py::hasattr(obj, "__parameter_tuple__")) {
193     auto tuple = obj.cast<py::tuple>();
194     std::vector<AnfNodePtr> args;
195     args.push_back(NewValueNode(prim::kPrimMakeTuple));
196     for (size_t it = 0; it < tuple.size(); ++it) {
197       AnfNodePtr out = nullptr;
198       bool success = ResolveObjectToNode(func_graph, tuple[it], &out);
199       if (!success) {
200         MS_LOG(ERROR) << "Resolve object to node failed";
201         return false;
202       }
203       args.push_back(out);
204     }
205     output = NewCNode(args, func_graph);
206   } else {
207     ValuePtr convert_result = nullptr;
208     bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve());
209     if (!converted) {
210       MS_LOG(ERROR) << "Convert data failed";
211       return false;
212     }
213     MS_EXCEPTION_IF_NULL(convert_result);
214     ConvertLoadedGraph(func_graph, convert_result);
215     output = NewValueNode(convert_result);
216     if (convert_result->isa<tensor::Tensor>()) {
217       output = GetMixedPrecisionCastHelp(func_graph, output);
218     }
219   }
220   *node = output;
221   return true;
222 }
223 
IsAllFuncInValueSequence(const std::vector<ValuePtr> & value_vec)224 bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
225   if (value_vec.empty()) {
226     return false;
227   }
228   for (auto &elem : value_vec) {
229     MS_EXCEPTION_IF_NULL(elem);
230     if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
231       const auto &vec = GetValue<ValuePtrList>(elem);
232       auto is_graph = IsAllFuncInValueSequence(vec);
233       if (!is_graph) {
234         return false;
235       }
236     } else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
237       return false;
238     }
239   }
240   return true;
241 }
242 
TransformToMakeTupleNodes(const FuncGraphManagerPtr & manager,const FuncGraphPtr & func_graph,const std::vector<ValuePtr> & value_vec)243 AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
244                                      const std::vector<ValuePtr> &value_vec) {
245   std::vector<AnfNodePtr> nodes;
246   nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
247   for (auto &elem : value_vec) {
248     MS_EXCEPTION_IF_NULL(elem);
249     AnfNodePtr node = nullptr;
250     if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
251       const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
252       node = TransformToMakeTupleNodes(manager, func_graph, vec);
253     } else if (elem->isa<FuncGraph>()) {
254       FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
255       manager->AddFuncGraph(new_fg);
256       node = NewValueNode(new_fg);
257     } else if (elem->isa<Primitive>()) {
258       node = NewValueNode(elem);
259     } else {
260       MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
261     }
262     nodes.emplace_back(node);
263   }
264   auto cnode = func_graph->NewCNode(nodes);
265   return cnode;
266 }
267 
268 // Transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
TransformVectorFuncValueNode(const FuncGraphManagerPtr & manager,const FuncGraphPtr & func_graph,const ValueNodePtr & value_node,AnfNodePtr * const transformed)269 bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
270                                   const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
271   MS_EXCEPTION_IF_NULL(value_node);
272   const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
273   if (!IsAllFuncInValueSequence(value_vec)) {
274     return false;
275   }
276 
277   // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
278   // So if has graph in list, try to replace the node with make tuple of graph value node.
279   // We do this because the graph manager won't investigate the graph inside valuetuple,
280   // change the vector of graph to be make_tuple of graph value node.
281   // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
282   // independent nodes.
283   auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
284   // Replace the ret ptr to be make tuple of graph value node
285   *transformed = node_tuple_graphs;
286 
287   return true;
288 }
289 
290 // Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
ResolveObjectAndAddToManager(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & node)291 AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
292                                         const AnfNodePtr &node) {
293   ScopeGuard scope_guard(node->scope());
294   AnfNodePtr resolved_node = nullptr;
295   bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
296   if (!success) {
297     MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo.";
298   }
299   if (IsValueNode<FuncGraph>(resolved_node)) {
300     auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
301     manager->AddFuncGraph(new_fg);
302   }
303 
304   // If the constant node is constant of vector of graph, add graph to manager.
305   if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
306     (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
307                                        &resolved_node);
308   }
309   return resolved_node;
310 }
311 }  // namespace
312 
// Resolves `symbol` in `name_space` (through SymbolResolver) and converts the resulting Python
// object into a node in `node`'s func graph, registering any resolved graphs with `manager`.
// Raises if `node` is null, `node` has no func graph, `manager` is null, or resolution fails.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
                         const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
  if (node->func_graph() == nullptr || manager == nullptr) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  }
  SymbolResolver symbol_resolver(name_space, symbol, node);
  // Check the result like ResolveCellwithAttr does, instead of silently ignoring it.
  if (!symbol_resolver.Resolve()) {
    MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo.";
  }
  py::object obj = symbol_resolver.result();
  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
  return resolved_node;
}
326 
// Resolves `symbol` in `name_space` and builds the node that reads attribute `attr` from it.
//  - If the resolved object is not a Cell instance, it is resolved into a node (registering
//    any graphs with `manager`) and a GetAttr CNode `getattr(resolved, attr)` is returned.
//  - If it is a Cell instance, a Resolve CNode is returned. That node looks up `attr` in the
//    cell's class-member namespace, so `attr` must be a string constant value node.
// Raises in any of these cases:
//  - `node` or `attr` is null;
//  - `node` has no func graph, or `manager` is null;
//  - symbol resolution fails;
//  - a cell attribute is not a string constant.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
                               const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(attr);
  TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
  if (node->func_graph() == nullptr || manager == nullptr) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  }
  SymbolResolver symbol_resolver(name_space, symbol, node);
  if (!symbol_resolver.Resolve()) {
    MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo.";
  }

  py::object obj = symbol_resolver.result();
  if (!data_converter::IsCellInstance(obj)) {
    AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
    AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
    AnfNodePtr res_node = node->func_graph()->NewCNode(inputs);
    TraceManager::ClearParseOrResolveDebugInfo();
    return res_node;
  }

  // The attribute name is needed as a string to build the member symbol. GetValueNode returns
  // nullptr when `attr` is not a StringImm value node, so fail clearly instead of dereferencing it.
  auto attr_str_imm = GetValueNode<StringImmPtr>(attr);
  if (attr_str_imm == nullptr) {
    MS_LOG(EXCEPTION) << "The attribute of a cell should be a string constant, but got: " << attr->DebugString();
  }

  const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
  const std::string module = "mindspore._extends.parse.parser";
  py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
  auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
  std::string attr_as_string = attr_str_imm->value();
  auto new_symbol = std::make_shared<Symbol>(attr_as_string);
  MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();

  AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
  AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
  TraceManager::ClearParseOrResolveDebugInfo();
  return resolved_node;
}
361 
362 namespace {
// Builds the pass-group map for the resolve optimizer. It has a single "resolve" group that
// runs `irpass.resolver_getattr_resolve_` (handles the resolve and getattr primitives).
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
  opt::OptPassGroupMap pass_groups({{"resolve", {irpass.resolver_getattr_resolve_}}});
  return pass_groups;
}
373 }  // namespace
374 
ResolveFuncGraph(const FuncGraphPtr & func_graph,const pipeline::ResourceBasePtr & res,bool use_profile)375 bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
376   if (func_graph == nullptr || res == nullptr) {
377     MS_LOG(ERROR) << "func_graph or resource is null";
378     return false;
379   }
380   opt::irpass::ResolveIRPassLib irpass;
381   opt::OptimizerPtr opt_resolve =
382     opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
383 
384   (void)parse::python_adapter::set_python_scoped();
385 
386   MS_EXCEPTION_IF_NULL(opt_resolve);
387   (void)opt_resolve->step(func_graph, use_profile);
388   return true;
389 }
390 
ResolveAll(const FuncGraphManagerPtr & manager)391 bool ResolveAll(const FuncGraphManagerPtr &manager) {
392   if (manager == nullptr) {
393     MS_LOG(ERROR) << "func graph manager is null";
394     return false;
395   }
396 
397   if (manager->roots().size() > 1) {
398     MS_LOG(WARNING)
399       << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs"
400          "called from root graph, so it's not necessary to pass all graphs as roots. "
401          "Please ensure your usage.";
402   }
403   // Should not use pipeline::Resource as Resource::Clean will clean some
404   // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
405   // fail as valid scope has been cleaned.
406   auto res = std::make_shared<pipeline::ResourceBase>();
407   res->set_manager(manager);
408 
409   auto roots = manager->roots();
410   for (auto &fg : roots) {
411     bool ret = ResolveFuncGraph(fg, res, false);
412     if (!ret) {
413       MS_EXCEPTION_IF_NULL(fg);
414       MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
415     }
416   }
417   return true;
418 }
419 }  // namespace parse
420 }  // namespace mindspore
421