1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/optimizer/visitor.h"
18
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/backend/optimizer/pattern_engine.h"
23 #include "utils/any.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore {
CheckIfNeedExpand(const std::vector<BaseRef> & list)29 bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
30 return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
31 }
32
ExpandList(const std::vector<BaseRef> & list)33 std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
34 std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
35 for (auto &item : list) {
36 if (utils::isa<Seq>(item)) {
37 const Seq &seq = utils::cast<Seq>(item);
38 new_list->insert(new_list->end(), seq.begin(), seq.end());
39 } else {
40 new_list->push_back(item);
41 }
42 }
43 return new_list;
44 }
45
// Unwraps a pattern variable: when `x` holds an AnfNodePtr whose node is a
// VarNode, the node's `var_` is returned; in every other case `x` itself is
// returned unchanged.
static BaseRef GetVar(const BaseRef &x) {
  if (!utils::isa<AnfNodePtr>(x)) {
    return x;
  }

  const auto anf_node = utils::cast<AnfNodePtr>(x);
  MS_LOG(DEBUG) << "TypeString [" + anf_node->type_name() + "]";
  if (!anf_node->isa<VarNode>()) {
    return x;
  }

  const auto var_node = anf_node->cast<VarNodePtr>();
  MS_LOG(DEBUG) << "IsVarNode " + var_node->var_->ToString();
  return var_node->var_;
}
57
Visit(const VectorRef & v_any,VectorRef * const values_ref,BaseRef * const visit_out) const58 bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
59 std::vector<BaseRef> out;
60 for (const auto &element : v_any) {
61 out.push_back(element);
62 values_ref->push_back(GetVar(element));
63 }
64 if (visit_out != nullptr) {
65 *visit_out = ExpandList(out);
66 }
67 return true;
68 }
69
Visit(const BaseRef & any,VectorRef * const values_ref,BaseRef * const visit_out) const70 bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
71 if (utils::isa<Seq>(any)) {
72 return Visit(utils::cast<Seq>(any), values_ref, visit_out);
73 } else if (utils::isa<AnfNodePtr>(any)) {
74 auto nodeptr = utils::cast<AnfNodePtr>(any);
75 AnfNodePtr output;
76 AnfNodePtr *p_output = &output;
77 if (visit_out == nullptr) {
78 p_output = nullptr;
79 }
80 Visit(nodeptr, values_ref, p_output);
81 if (visit_out != nullptr) {
82 *visit_out = output;
83 }
84 return true;
85 }
86 MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
87 return false;
88 }
89
Visit(const AnfNodePtr & node,VectorRef * const values_ref,AnfNodePtr * output) const90 void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
91 if (node->isa<CNode>()) {
92 Visit(node->cast<CNodePtr>(), values_ref, output);
93 return;
94 }
95
96 if (node->isa<ValueNode>()) {
97 Visit(node->cast<ValueNodePtr>(), values_ref, output);
98 return;
99 }
100
101 if (output != nullptr) {
102 *output = node;
103 }
104 }
105
Visit(const CNodePtr & cnode,VectorRef * const values_ref,AnfNodePtr * output) const106 void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
107 // if output is nullptr, it's not required to make the new CNode node.
108 if (output == nullptr) {
109 for (auto &inp : cnode->inputs()) {
110 auto var = GetVar(inp);
111 values_ref->push_back(var);
112 }
113 if (cnode->func_graph() != nullptr) {
114 values_ref->push_back(GetVar(cnode->func_graph()));
115 } else {
116 values_ref->push_back(GetVar(cnode->func_graph_as_var()));
117 }
118 return;
119 }
120
121 std::vector<AnfNodePtr> new_inputs;
122 std::vector<BaseRef> after_cnode_fn;
123 std::shared_ptr<VectorRef> out;
124 for (auto &input : cnode->inputs()) {
125 after_cnode_fn.push_back(input);
126 values_ref->push_back(GetVar(input));
127 }
128 if (CheckIfNeedExpand(after_cnode_fn)) {
129 out = ExpandList(after_cnode_fn);
130 }
131
132 std::vector<BaseRef> &outs = after_cnode_fn;
133 if (out != nullptr) {
134 outs = out->elements();
135 }
136
137 for (auto &any_item : outs) {
138 if (!utils::isa<AnfNodePtr>(any_item)) {
139 MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
140 }
141 new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
142 }
143
144 BaseRef any_fg;
145 AnfNodePtr new_cnode = nullptr;
146 if (cnode->func_graph() != nullptr) {
147 any_fg = cnode->func_graph();
148 values_ref->push_back(GetVar(any_fg));
149 if (!utils::isa<FuncGraphPtr>(any_fg)) {
150 MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
151 }
152 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
153 } else {
154 any_fg = cnode->func_graph_as_var();
155 values_ref->push_back(GetVar(any_fg));
156 if (utils::isa<VarPtr>(any_fg)) {
157 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
158 } else if (utils::isa<FuncGraphPtr>(any_fg)) {
159 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
160 } else {
161 MS_LOG(INTERNAL_EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
162 }
163 }
164 new_cnode->set_abstract(cnode->abstract());
165 *output = new_cnode;
166 }
167
Visit(const ValueNodePtr & vnode,VectorRef * const values_ref,AnfNodePtr * output) const168 void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
169 values_ref->push_back(GetVar(vnode->value()));
170 const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
171 if (utils::isa<ValuePtr>(value)) {
172 if (output != nullptr) {
173 auto ct = NewValueNode(utils::cast<ValuePtr>(value));
174 ct->set_abstract(vnode->abstract());
175 *output = ct;
176 }
177 return;
178 }
179 MS_LOG(INTERNAL_EXCEPTION) << "Visit result is not ValuePtr.";
180 }
181 } // namespace mindspore
182