• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/ps/load_mindir.h"
17 
18 #include <string>
19 #include <set>
20 #include <memory>
21 #include <algorithm>
22 
23 #include "utils/log_adapter.h"
24 #include "abstract/abstract_value.h"
25 #include "pipeline/jit/ps/parse/parse_base.h"
26 #include "utils/check_convert_utils.h"
27 #include "load_mindir/infer_mindir.h"
28 
29 namespace mindspore {
30 namespace pipeline {
InferMindIR(const ResourcePtr & resource)31 bool InferMindIR(const ResourcePtr &resource) {
32   MS_EXCEPTION_IF_NULL(resource);
33   const auto &root = resource->func_graph();
34   InferFuncGraphLoaded(root);
35   return true;
36 }
37 
ArgsNeededToConvert(const PrimitivePtr & prim)38 std::vector<AnfNodePtr> ArgsNeededToConvert(const PrimitivePtr &prim) {
39   auto op_def = mindspore::ops::GetOpDef(prim->name());
40   std::vector<AnfNodePtr> prim_init_arg_nodes;
41   MS_EXCEPTION_IF_NULL(op_def);
42   // Get init args.
43   for (const auto &op_arg : op_def->args_) {
44     if (op_arg.as_init_arg_) {
45       auto arg_name = op_arg.arg_name_;
46       ValuePtr attr;
47       // "data_format" is renamed as "format" for some operator.
48       if (CheckAndConvertUtils::CheckPrimAttrConverted(prim->name()) && arg_name == "data_format" &&
49           prim->HasAttr("format")) {
50         attr = prim->GetAttr("format");
51       } else if (!prim->HasAttr(arg_name)) {
52         attr = parse::GetArgDefaultValue(prim->name(), arg_name);
53         if (attr == nullptr) {
54           MS_LOG(EXCEPTION) << "Cannot find attribute: " << arg_name << " from primitive :" << prim->name();
55         }
56       } else {
57         attr = prim->GetAttr(arg_name);
58       }
59       (void)prim_init_arg_nodes.emplace_back(NewValueNode(attr));
60     }
61   }
62   return prim_init_arg_nodes;
63 }
64 
ModifyOneCNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)65 void ModifyOneCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
66   MS_EXCEPTION_IF_NULL(cnode);
67   auto &inputs = cnode->inputs();
68   if (IsValueNode<Primitive>(inputs[0])) {
69     auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
70     if (mindspore::ops::IsPrimitiveFunction(prim->name())) {
71       // Append Primitive arguments to the inputs.
72       std::vector<AnfNodePtr> prim_init_arg_nodes = ArgsNeededToConvert(prim);
73       // Get call args.
74       AnfNodePtrList prim_call_arg_nodes(inputs.begin() + 1, inputs.end());
75       // Create new node.
76       auto new_prim = std::make_shared<Primitive>(*prim);
77       AnfNodePtrList input_nodes{NewValueNode(new_prim)};
78       (void)std::copy(prim_call_arg_nodes.cbegin(), prim_call_arg_nodes.cend(), std::back_inserter(input_nodes));
79       (void)std::copy(prim_init_arg_nodes.cbegin(), prim_init_arg_nodes.cend(), std::back_inserter(input_nodes));
80       auto new_cnode = func_graph->NewCNodeInOrder(input_nodes);
81       MS_LOG(DEBUG) << "Convert primitive args: " << prim->name() << ". node: " << cnode->DebugString()
82                     << ", new_node: " << new_cnode->DebugString();
83       auto manager = func_graph->manager();
84       if (manager == nullptr) {
85         manager = MakeManager();
86         manager->AddFuncGraph(func_graph, true);
87       }
88       (void)manager->Replace(cnode, new_cnode);
89     }
90   }
91 }
92 
ModifyOneFuncGraph(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * func_graph_set,std::set<FuncGraphPtr> * func_graph_modified)93 void ModifyOneFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *func_graph_set,
94                         std::set<FuncGraphPtr> *func_graph_modified) {
95   MS_LOG(DEBUG) << "Start modifying: " << func_graph->ToString();
96   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
97   for (const AnfNodePtr &node : nodes) {
98     MS_EXCEPTION_IF_NULL(node);
99     if (!node->isa<CNode>()) {
100       continue;
101     }
102     auto cnode = node->cast<CNodePtr>();
103     ModifyOneCNode(func_graph, cnode);
104     auto &inputs = cnode->inputs();
105     for (size_t i = 0; i < inputs.size(); ++i) {
106       if (IsValueNode<FuncGraph>(inputs[i])) {
107         FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(inputs[i]);
108         if ((*func_graph_set).find(fg) == (*func_graph_set).end() &&
109             (*func_graph_modified).find(fg) == (*func_graph_modified).end()) {
110           (void)(*func_graph_set).insert(fg);
111         }
112       }
113     }
114   }
115 }
116 
ModifyGraphs(const FuncGraphPtr & func_graph)117 void ModifyGraphs(const FuncGraphPtr &func_graph) {
118   std::set<FuncGraphPtr> func_graph_set{};
119   std::set<FuncGraphPtr> func_graph_modified{};
120   func_graph_set.insert(func_graph);
121   // Check every node in every graph to find nodes needed to convert.
122   while (!func_graph_set.empty()) {
123     FuncGraphPtr fg = *func_graph_set.cbegin();
124     if (!func_graph->has_flag("generated_from_mindir_with_prim_func")) {
125       ModifyOneFuncGraph(fg, &func_graph_set, &func_graph_modified);
126     }
127     (void)func_graph_set.erase(fg);
128     (void)func_graph_modified.insert(fg);
129   }
130 }
131 
ModifyGraphGeneratedByMindIR(const ResourcePtr & resource)132 bool ModifyGraphGeneratedByMindIR(const ResourcePtr &resource) {
133   MS_EXCEPTION_IF_NULL(resource);
134   const auto &func_graph = resource->func_graph();
135   ModifyGraphs(func_graph);
136   return true;
137 }
138 }  // namespace pipeline
139 }  // namespace mindspore
140