1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "backend/optimizer/common/visit.h"
20
21 #include <vector>
22 #include <memory>
23 #include <algorithm>
24 #include "backend/optimizer/common/pattern_engine.h"
25 #include "utils/any.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "utils/log_adapter.h"
29
30 /* namespace to support utils definition */
31 namespace mindspore {
CheckIfNeedExpand(const std::vector<BaseRef> & list)32 bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
33 return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
34 }
35
ExpandList(const std::vector<BaseRef> & list)36 std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
37 std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
38 for (auto &item : list) {
39 if (utils::isa<Seq>(item)) {
40 const Seq &seq = utils::cast<Seq>(item);
41 new_list->insert(new_list->end(), seq.begin(), seq.end());
42 } else {
43 new_list->push_back(item);
44 }
45 }
46 return new_list;
47 }
48
// Returns the Var stored in `var_` when `x` wraps a VarNode; otherwise returns
// `x` unchanged. For AnfNode inputs, the node's type name (and, for VarNodes,
// the Var) is written to the DEBUG log.
static BaseRef GetVar(const BaseRef &x) {
  if (!utils::isa<AnfNodePtr>(x)) {
    return x;
  }
  auto node = utils::cast<AnfNodePtr>(x);
  MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  if (!node->isa<VarNode>()) {
    return x;
  }
  const auto var_node = node->cast<VarNodePtr>();
  MS_LOG(DEBUG) << "IsVarNode " + var_node->var_->ToString();
  return var_node->var_;
}
60
Visit(const VectorRef & v_any,VectorRef * const values_ref,BaseRef * const visit_out) const61 bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
62 std::vector<BaseRef> out;
63 for (const auto &element : v_any) {
64 out.push_back(element);
65 values_ref->push_back(GetVar(element));
66 }
67 if (visit_out != nullptr) {
68 *visit_out = ExpandList(out);
69 }
70 return true;
71 }
72
Visit(const BaseRef & any,VectorRef * const values_ref,BaseRef * const visit_out) const73 bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
74 if (utils::isa<Seq>(any)) {
75 return Visit(utils::cast<Seq>(any), values_ref, visit_out);
76 } else if (utils::isa<AnfNodePtr>(any)) {
77 auto nodeptr = utils::cast<AnfNodePtr>(any);
78 AnfNodePtr output;
79 AnfNodePtr *p_output = &output;
80 if (visit_out == nullptr) {
81 p_output = nullptr;
82 }
83 Visit(nodeptr, values_ref, p_output);
84 if (visit_out != nullptr) {
85 *visit_out = output;
86 }
87 return true;
88 }
89 MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
90 return false;
91 }
92
Visit(const AnfNodePtr & node,VectorRef * const values_ref,AnfNodePtr * output) const93 void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
94 if (node->isa<CNode>()) {
95 Visit(node->cast<CNodePtr>(), values_ref, output);
96 return;
97 }
98
99 if (node->isa<ValueNode>()) {
100 Visit(node->cast<ValueNodePtr>(), values_ref, output);
101 return;
102 }
103
104 if (output != nullptr) {
105 *output = node;
106 }
107 }
108
Visit(const CNodePtr & cnode,VectorRef * const values_ref,AnfNodePtr * output) const109 void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
110 // if output is nullptr, it's not required to make the new CNode node.
111 if (output == nullptr) {
112 for (auto &inp : cnode->inputs()) {
113 auto var = GetVar(inp);
114 values_ref->push_back(var);
115 }
116 if (cnode->func_graph() != nullptr) {
117 values_ref->push_back(GetVar(cnode->func_graph()));
118 } else {
119 values_ref->push_back(GetVar(cnode->func_graph_as_var()));
120 }
121 return;
122 }
123
124 std::vector<AnfNodePtr> new_inputs;
125 std::vector<BaseRef> after_cnode_fn;
126 std::shared_ptr<VectorRef> out;
127 for (auto &input : cnode->inputs()) {
128 after_cnode_fn.push_back(input);
129 values_ref->push_back(GetVar(input));
130 }
131 if (CheckIfNeedExpand(after_cnode_fn)) {
132 out = ExpandList(after_cnode_fn);
133 }
134
135 std::vector<BaseRef> &outs = after_cnode_fn;
136 if (out != nullptr) {
137 outs = out->elements();
138 }
139
140 for (auto &any_item : outs) {
141 if (!utils::isa<AnfNodePtr>(any_item)) {
142 MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
143 }
144 new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
145 }
146
147 BaseRef any_fg;
148 AnfNodePtr new_cnode = nullptr;
149 if (cnode->func_graph() != nullptr) {
150 any_fg = cnode->func_graph();
151 values_ref->push_back(GetVar(any_fg));
152 if (!utils::isa<FuncGraphPtr>(any_fg)) {
153 MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
154 }
155 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
156 } else {
157 any_fg = cnode->func_graph_as_var();
158 values_ref->push_back(GetVar(any_fg));
159 if (utils::isa<VarPtr>(any_fg)) {
160 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
161 } else if (utils::isa<FuncGraphPtr>(any_fg)) {
162 new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
163 } else {
164 MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
165 }
166 }
167 new_cnode->set_abstract(cnode->abstract());
168 *output = new_cnode;
169 }
170
Visit(const ValueNodePtr & vnode,VectorRef * const values_ref,AnfNodePtr * output) const171 void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
172 values_ref->push_back(GetVar(vnode->value()));
173 const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
174 if (utils::isa<ValuePtr>(value)) {
175 if (output != nullptr) {
176 auto ct = NewValueNode(utils::cast<ValuePtr>(value));
177 ct->set_abstract(vnode->abstract());
178 *output = ct;
179 }
180 return;
181 }
182 MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
183 }
184 } // namespace mindspore
185