• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "ir/func_graph.h"
27 #include "ir/func_graph_cloner.h"
28 #include "frontend/optimizer/optimizer_caller.h"
29 #include "frontend/optimizer/anf_visitor.h"
30 #include "frontend/operator/ops.h"
31 #include "frontend/optimizer/irpass.h"
32 #include "frontend/optimizer/optimizer.h"
33 #include "frontend/optimizer/graph_transform.h"
34 
35 namespace mindspore {
36 namespace opt {
37 namespace irpass {
38 // {G, Xs}-->transform graph call tuple inputs to flat inputs.
39 class GraphCallTupleTransform : public AnfVisitor {
40  public:
GraphCallTupleTransform(GraphTupleParamTransform & transformer)41   explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
42   ~GraphCallTupleTransform() override = default;
operator()43   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
44     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
45       return nullptr;
46     }
47 
48     auto cnode = node->cast<CNodePtr>();
49     MS_EXCEPTION_IF_NULL(cnode);
50     auto &inputs = cnode->inputs();
51     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
52     if (fg == nullptr) {
53       return nullptr;
54     }
55     if (!CNodeHasTupleInput(cnode)) {
56       return nullptr;
57     }
58     FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager());
59     auto new_node = TransformCallGraph(transformed_fg, cnode);
60     return new_node;
61   }
62 
63  private:
64   GraphTupleParamTransform &graph_transform_;
65 };
66 
67 // {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs.
68 class SwitchCallTupleTransform : public AnfVisitor {
69  public:
SwitchCallTupleTransform(GraphTupleParamTransform & transformer)70   explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
71   ~SwitchCallTupleTransform() override = default;
operator()72   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
73     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
74       return nullptr;
75     }
76     auto switch_call_cnode = node->cast<CNodePtr>();
77     auto call_inputs = switch_call_cnode->inputs();
78     if (call_inputs.size() < 1) {
79       return nullptr;
80     }
81     if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) {
82       return nullptr;
83     }
84     auto swich_cnode = call_inputs[0]->cast<CNodePtr>();
85     auto switch_inputs = swich_cnode->inputs();
86     if (switch_inputs.size() != 4) {
87       return nullptr;
88     }
89 
90     AnfNodePtr transformed = nullptr;
91     bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed);
92     if (true_br_changed) {
93       switch_inputs[2] = transformed;
94     }
95     bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed);
96     if (false_br_changed) {
97       switch_inputs[3] = transformed;
98     }
99     if (true_br_changed || false_br_changed) {
100       call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs);
101     }
102     if (CNodeHasTupleInput(switch_call_cnode)) {
103       return TransformSwitchCall(call_inputs[0], switch_call_cnode);
104     }
105     if (true_br_changed || false_br_changed) {
106       return switch_call_cnode->func_graph()->NewCNode(call_inputs);
107     }
108     return nullptr;
109   }
110 
TransformBranchNode(AnfNodePtr node,FuncGraphManagerPtr mng,AnfNodePtr * trans_node)111   bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
112     if (IsValueNode<FuncGraph>(node)) {
113       FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
114       if (FuncGraphHasTupleInput(fg)) {
115         FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
116         *trans_node = NewValueNode(transformed_fg);
117         return true;
118       }
119       return false;
120     }
121     if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
122       auto partial_inputs = node->cast<CNodePtr>()->inputs();
123       if (IsValueNode<FuncGraph>(partial_inputs[1])) {
124         FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]);
125         if (FuncGraphHasTupleInput(fg)) {
126           fg = graph_transform_(fg, mng);
127         }
128         if (CNodeHasTupleInput(node->cast<CNodePtr>())) {
129           *trans_node = TransformPartial(fg, node->cast<CNodePtr>());
130           return true;
131         }
132       }
133       return false;
134     }
135 
136     MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString();
137     return false;
138   }
139 
140  private:
141   GraphTupleParamTransform &graph_transform_;
142 };
143 
144 // {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} ->
145 // transform switch layer graph call tuple inputs to flat inputs.
146 class SwitchLayerCallTupleTransform : public AnfVisitor {
147  public:
SwitchLayerCallTupleTransform(GraphTupleParamTransform & transformer)148   explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
149   ~SwitchLayerCallTupleTransform() override = default;
operator()150   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
151     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
152       return nullptr;
153     }
154     auto switch_layer_call_cnode = node->cast<CNodePtr>();
155     auto call_inputs = switch_layer_call_cnode->inputs();
156     if (call_inputs.size() < 1) {
157       return nullptr;
158     }
159     if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) {
160       return nullptr;
161     }
162     auto swich_layer_cnode = call_inputs[0]->cast<CNodePtr>();
163     auto switch_layer_inputs = swich_layer_cnode->inputs();
164     if (switch_layer_inputs.size() != 3) {
165       return nullptr;
166     }
167 
168     AnfNodePtr transformed = nullptr;
169     bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed);
170     if (layer_changed) {
171       switch_layer_inputs[2] = transformed;
172       call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs);
173     }
174     if (CNodeHasTupleInput(switch_layer_call_cnode)) {
175       return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode);
176     }
177     if (layer_changed) {
178       return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs);
179     }
180     return nullptr;
181   }
182 
TransformLayerNode(AnfNodePtr node,FuncGraphManagerPtr mng,AnfNodePtr * trans_node)183   bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
184     if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
185       MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple";
186       return false;
187     }
188     auto tuple_inputs = node->cast<CNodePtr>()->inputs();
189     bool changed = false;
190     for (size_t i = 1; i < tuple_inputs.size(); i++) {
191       if (!IsValueNode<FuncGraph>(tuple_inputs[i])) {
192         MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph";
193         return false;
194       }
195       FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
196       if (FuncGraphHasTupleInput(fg)) {
197         FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
198         tuple_inputs[i] = NewValueNode(transformed_fg);
199         changed = true;
200       }
201     }
202     if (changed) {
203       *trans_node = node->func_graph()->NewCNode(tuple_inputs);
204     }
205     return changed;
206   }
207 
208  private:
209   GraphTupleParamTransform &graph_transform_;
210 };
211 
212 class CallGraphTupleTransform : public OptimizerCaller {
213  public:
CallGraphTupleTransform()214   CallGraphTupleTransform()
215       : graph_transformer_(),
216         graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)),
217         switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)),
218         switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) {
219     transformers_.emplace_back(graph_call_transform_);
220     transformers_.emplace_back(switch_call_transform_);
221     transformers_.emplace_back(switch_layer_call_transform_);
222   }
223   ~CallGraphTupleTransform() = default;
224 
operator()225   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
226     AnfNodePtr new_node;
227     for (auto &transform : transformers_) {
228       new_node = (*transform)(optimizer, node);
229       if (new_node != nullptr) {
230         return new_node;
231       }
232     }
233     return nullptr;
234   }
235 
236  private:
237   GraphTupleParamTransform graph_transformer_;
238   OptimizerCallerPtr graph_call_transform_;
239   OptimizerCallerPtr switch_call_transform_;
240   OptimizerCallerPtr switch_layer_call_transform_;
241   std::vector<OptimizerCallerPtr> transformers_{};
242 };
243 }  // namespace irpass
244 }  // namespace opt
245 }  // namespace mindspore
246 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
247