• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/optimizer/visitor.h"
18 
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/backend/optimizer/pattern_engine.h"
23 #include "utils/any.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "utils/log_adapter.h"
27 
28 namespace mindspore {
CheckIfNeedExpand(const std::vector<BaseRef> & list)29 bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
30   return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
31 }
32 
ExpandList(const std::vector<BaseRef> & list)33 std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
34   std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
35   for (auto &item : list) {
36     if (utils::isa<Seq>(item)) {
37       const Seq &seq = utils::cast<Seq>(item);
38       new_list->insert(new_list->end(), seq.begin(), seq.end());
39     } else {
40       new_list->push_back(item);
41     }
42   }
43   return new_list;
44 }
45 
// If `x` holds an AnfNodePtr that is a VarNode, returns that node's `var_`. Otherwise
// returns `x` unchanged. The node's type name, and for a VarNode its var, are logged at
// DEBUG level.
static BaseRef GetVar(const BaseRef &x) {
  if (!utils::isa<AnfNodePtr>(x)) {
    return x;
  }
  auto node = utils::cast<AnfNodePtr>(x);
  MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  if (!node->isa<VarNode>()) {
    return x;
  }
  const VarNodePtr var_node = node->cast<VarNodePtr>();
  MS_LOG(DEBUG) << "IsVarNode " + var_node->var_->ToString();
  return var_node->var_;
}
57 
Visit(const VectorRef & v_any,VectorRef * const values_ref,BaseRef * const visit_out) const58 bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
59   std::vector<BaseRef> out;
60   for (const auto &element : v_any) {
61     out.push_back(element);
62     values_ref->push_back(GetVar(element));
63   }
64   if (visit_out != nullptr) {
65     *visit_out = ExpandList(out);
66   }
67   return true;
68 }
69 
Visit(const BaseRef & any,VectorRef * const values_ref,BaseRef * const visit_out) const70 bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
71   if (utils::isa<Seq>(any)) {
72     return Visit(utils::cast<Seq>(any), values_ref, visit_out);
73   } else if (utils::isa<AnfNodePtr>(any)) {
74     auto nodeptr = utils::cast<AnfNodePtr>(any);
75     AnfNodePtr output;
76     AnfNodePtr *p_output = &output;
77     if (visit_out == nullptr) {
78       p_output = nullptr;
79     }
80     Visit(nodeptr, values_ref, p_output);
81     if (visit_out != nullptr) {
82       *visit_out = output;
83     }
84     return true;
85   }
86   MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
87   return false;
88 }
89 
Visit(const AnfNodePtr & node,VectorRef * const values_ref,AnfNodePtr * output) const90 void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
91   if (node->isa<CNode>()) {
92     Visit(node->cast<CNodePtr>(), values_ref, output);
93     return;
94   }
95 
96   if (node->isa<ValueNode>()) {
97     Visit(node->cast<ValueNodePtr>(), values_ref, output);
98     return;
99   }
100 
101   if (output != nullptr) {
102     *output = node;
103   }
104 }
105 
Visit(const CNodePtr & cnode,VectorRef * const values_ref,AnfNodePtr * output) const106 void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
107   // if output is nullptr, it's not required to make the new CNode node.
108   if (output == nullptr) {
109     for (auto &inp : cnode->inputs()) {
110       auto var = GetVar(inp);
111       values_ref->push_back(var);
112     }
113     if (cnode->func_graph() != nullptr) {
114       values_ref->push_back(GetVar(cnode->func_graph()));
115     } else {
116       values_ref->push_back(GetVar(cnode->func_graph_as_var()));
117     }
118     return;
119   }
120 
121   std::vector<AnfNodePtr> new_inputs;
122   std::vector<BaseRef> after_cnode_fn;
123   std::shared_ptr<VectorRef> out;
124   for (auto &input : cnode->inputs()) {
125     after_cnode_fn.push_back(input);
126     values_ref->push_back(GetVar(input));
127   }
128   if (CheckIfNeedExpand(after_cnode_fn)) {
129     out = ExpandList(after_cnode_fn);
130   }
131 
132   std::vector<BaseRef> &outs = after_cnode_fn;
133   if (out != nullptr) {
134     outs = out->elements();
135   }
136 
137   for (auto &any_item : outs) {
138     if (!utils::isa<AnfNodePtr>(any_item)) {
139       MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
140     }
141     new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
142   }
143 
144   BaseRef any_fg;
145   AnfNodePtr new_cnode = nullptr;
146   if (cnode->func_graph() != nullptr) {
147     any_fg = cnode->func_graph();
148     values_ref->push_back(GetVar(any_fg));
149     if (!utils::isa<FuncGraphPtr>(any_fg)) {
150       MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
151     }
152     new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
153   } else {
154     any_fg = cnode->func_graph_as_var();
155     values_ref->push_back(GetVar(any_fg));
156     if (utils::isa<VarPtr>(any_fg)) {
157       new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
158     } else if (utils::isa<FuncGraphPtr>(any_fg)) {
159       new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
160     } else {
161       MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
162     }
163   }
164   new_cnode->set_abstract(cnode->abstract());
165   *output = new_cnode;
166 }
167 
Visit(const ValueNodePtr & vnode,VectorRef * const values_ref,AnfNodePtr * output) const168 void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
169   values_ref->push_back(GetVar(vnode->value()));
170   const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
171   if (utils::isa<ValuePtr>(value)) {
172     if (output != nullptr) {
173       auto ct = NewValueNode(utils::cast<ValuePtr>(value));
174       ct->set_abstract(vnode->abstract());
175       *output = ct;
176     }
177     return;
178   }
179   MS_LOG(INTERNAL_EXCEPTION) << "Visit result is not ValuePtr.";
180 }
181 }  // namespace mindspore
182