• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "backend/optimizer/common/visit.h"
20 
21 #include <vector>
22 #include <memory>
23 #include <algorithm>
24 #include "backend/optimizer/common/pattern_engine.h"
25 #include "utils/any.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "utils/log_adapter.h"
29 
30 /* namespace to support utils definition */
31 namespace mindspore {
CheckIfNeedExpand(const std::vector<BaseRef> & list)32 bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
33   return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
34 }
35 
ExpandList(const std::vector<BaseRef> & list)36 std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
37   std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
38   for (auto &item : list) {
39     if (utils::isa<Seq>(item)) {
40       const Seq &seq = utils::cast<Seq>(item);
41       new_list->insert(new_list->end(), seq.begin(), seq.end());
42     } else {
43       new_list->push_back(item);
44     }
45   }
46   return new_list;
47 }
48 
// Resolves a VarNode to the Var it wraps.
// If `x` holds an AnfNodePtr whose node is a VarNode, the wrapped `var_` is
// returned. In every other case `x` is returned unchanged.
static BaseRef GetVar(const BaseRef &x) {
  if (!utils::isa<AnfNodePtr>(x)) {
    return x;
  }
  auto node = utils::cast<AnfNodePtr>(x);
  MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  if (!node->isa<VarNode>()) {
    return x;
  }
  auto var_node = node->cast<VarNodePtr>();
  MS_LOG(DEBUG) << "IsVarNode " + var_node->var_->ToString();
  return var_node->var_;
}
60 
Visit(const VectorRef & v_any,VectorRef * const values_ref,BaseRef * const visit_out) const61 bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
62   std::vector<BaseRef> out;
63   for (const auto &element : v_any) {
64     out.push_back(element);
65     values_ref->push_back(GetVar(element));
66   }
67   if (visit_out != nullptr) {
68     *visit_out = ExpandList(out);
69   }
70   return true;
71 }
72 
Visit(const BaseRef & any,VectorRef * const values_ref,BaseRef * const visit_out) const73 bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
74   if (utils::isa<Seq>(any)) {
75     return Visit(utils::cast<Seq>(any), values_ref, visit_out);
76   } else if (utils::isa<AnfNodePtr>(any)) {
77     auto nodeptr = utils::cast<AnfNodePtr>(any);
78     AnfNodePtr output;
79     AnfNodePtr *p_output = &output;
80     if (visit_out == nullptr) {
81       p_output = nullptr;
82     }
83     Visit(nodeptr, values_ref, p_output);
84     if (visit_out != nullptr) {
85       *visit_out = output;
86     }
87     return true;
88   }
89   MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
90   return false;
91 }
92 
Visit(const AnfNodePtr & node,VectorRef * const values_ref,AnfNodePtr * output) const93 void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
94   if (node->isa<CNode>()) {
95     Visit(node->cast<CNodePtr>(), values_ref, output);
96     return;
97   }
98 
99   if (node->isa<ValueNode>()) {
100     Visit(node->cast<ValueNodePtr>(), values_ref, output);
101     return;
102   }
103 
104   if (output != nullptr) {
105     *output = node;
106   }
107 }
108 
Visit(const CNodePtr & cnode,VectorRef * const values_ref,AnfNodePtr * output) const109 void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
110   // if output is nullptr, it's not required to make the new CNode node.
111   if (output == nullptr) {
112     for (auto &inp : cnode->inputs()) {
113       auto var = GetVar(inp);
114       values_ref->push_back(var);
115     }
116     if (cnode->func_graph() != nullptr) {
117       values_ref->push_back(GetVar(cnode->func_graph()));
118     } else {
119       values_ref->push_back(GetVar(cnode->func_graph_as_var()));
120     }
121     return;
122   }
123 
124   std::vector<AnfNodePtr> new_inputs;
125   std::vector<BaseRef> after_cnode_fn;
126   std::shared_ptr<VectorRef> out;
127   for (auto &input : cnode->inputs()) {
128     after_cnode_fn.push_back(input);
129     values_ref->push_back(GetVar(input));
130   }
131   if (CheckIfNeedExpand(after_cnode_fn)) {
132     out = ExpandList(after_cnode_fn);
133   }
134 
135   std::vector<BaseRef> &outs = after_cnode_fn;
136   if (out != nullptr) {
137     outs = out->elements();
138   }
139 
140   for (auto &any_item : outs) {
141     if (!utils::isa<AnfNodePtr>(any_item)) {
142       MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
143     }
144     new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
145   }
146 
147   BaseRef any_fg;
148   AnfNodePtr new_cnode = nullptr;
149   if (cnode->func_graph() != nullptr) {
150     any_fg = cnode->func_graph();
151     values_ref->push_back(GetVar(any_fg));
152     if (!utils::isa<FuncGraphPtr>(any_fg)) {
153       MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
154     }
155     new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
156   } else {
157     any_fg = cnode->func_graph_as_var();
158     values_ref->push_back(GetVar(any_fg));
159     if (utils::isa<VarPtr>(any_fg)) {
160       new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
161     } else if (utils::isa<FuncGraphPtr>(any_fg)) {
162       new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
163     } else {
164       MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
165     }
166   }
167   new_cnode->set_abstract(cnode->abstract());
168   *output = new_cnode;
169 }
170 
Visit(const ValueNodePtr & vnode,VectorRef * const values_ref,AnfNodePtr * output) const171 void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
172   values_ref->push_back(GetVar(vnode->value()));
173   const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
174   if (utils::isa<ValuePtr>(value)) {
175     if (output != nullptr) {
176       auto ct = NewValueNode(utils::cast<ValuePtr>(value));
177       ct->set_abstract(vnode->abstract());
178       *output = ct;
179     }
180     return;
181   }
182   MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
183 }
184 }  // namespace mindspore
185