1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <vector> 25 26 #include "ir/func_graph.h" 27 #include "ir/func_graph_cloner.h" 28 #include "frontend/optimizer/optimizer_caller.h" 29 #include "frontend/optimizer/anf_visitor.h" 30 #include "frontend/operator/ops.h" 31 #include "frontend/optimizer/irpass.h" 32 #include "frontend/optimizer/optimizer.h" 33 #include "frontend/optimizer/graph_transform.h" 34 35 namespace mindspore { 36 namespace opt { 37 namespace irpass { 38 // {G, Xs}-->transform graph call tuple inputs to flat inputs. 
39 class GraphCallTupleTransform : public AnfVisitor { 40 public: GraphCallTupleTransform(GraphTupleParamTransform & transformer)41 explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} 42 ~GraphCallTupleTransform() override = default; operator()43 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { 44 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 45 return nullptr; 46 } 47 48 auto cnode = node->cast<CNodePtr>(); 49 MS_EXCEPTION_IF_NULL(cnode); 50 auto &inputs = cnode->inputs(); 51 auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); 52 if (fg == nullptr) { 53 return nullptr; 54 } 55 if (!CNodeHasTupleInput(cnode)) { 56 return nullptr; 57 } 58 FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager()); 59 auto new_node = TransformCallGraph(transformed_fg, cnode); 60 return new_node; 61 } 62 63 private: 64 GraphTupleParamTransform &graph_transform_; 65 }; 66 67 // {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs. 
class SwitchCallTupleTransform : public AnfVisitor {
 public:
  explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
  ~SwitchCallTupleTransform() override = default;

  // Flattens tuple arguments for a call of the form
  // {{switch, cond, true_branch, false_branch}, Xs}: both branches are
  // flattened first, then the outer call itself if it passes tuples.
  // Returns the rewritten call node, or nullptr when nothing changed.
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto switch_call_cnode = node->cast<CNodePtr>();
    auto call_inputs = switch_call_cnode->inputs();
    if (call_inputs.size() < 1) {
      return nullptr;
    }
    // The callee (input 0) must be a switch primitive CNode with exactly
    // {switch, cond, true_branch, false_branch}.
    if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) {
      return nullptr;
    }
    auto swich_cnode = call_inputs[0]->cast<CNodePtr>();
    auto switch_inputs = swich_cnode->inputs();
    if (switch_inputs.size() != 4) {
      return nullptr;
    }

    // Flatten each branch independently; `transformed` receives the
    // replacement branch node when a branch was rewritten.
    AnfNodePtr transformed = nullptr;
    bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed);
    if (true_br_changed) {
      switch_inputs[2] = transformed;
    }
    bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed);
    if (false_br_changed) {
      switch_inputs[3] = transformed;
    }
    // Rebuild the switch CNode if either branch changed.
    if (true_br_changed || false_br_changed) {
      call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs);
    }
    // If the outer call passes tuples, flatten the call site as well
    // (TransformSwitchCall also incorporates the rebuilt switch).
    if (CNodeHasTupleInput(switch_call_cnode)) {
      return TransformSwitchCall(call_inputs[0], switch_call_cnode);
    }
    // Branches changed but the call site itself has no tuples: just re-emit
    // the call with the new switch node.
    if (true_br_changed || false_br_changed) {
      return switch_call_cnode->func_graph()->NewCNode(call_inputs);
    }
    return nullptr;
  }

  // Flattens one switch branch. A branch may be either a FuncGraph value node
  // or a {partial, fg, args...} CNode. On success writes the replacement node
  // to *trans_node and returns true; returns false when no rewrite applies.
  bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
    if (IsValueNode<FuncGraph>(node)) {
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
      if (FuncGraphHasTupleInput(fg)) {
        FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
        *trans_node = NewValueNode(transformed_fg);
        return true;
      }
      return false;
    }
    if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
      // partial_inputs[1] is the wrapped graph, the rest are bound arguments.
      auto partial_inputs = node->cast<CNodePtr>()->inputs();
      if (IsValueNode<FuncGraph>(partial_inputs[1])) {
        FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]);
        if (FuncGraphHasTupleInput(fg)) {
          fg = graph_transform_(fg, mng);
        }
        // NOTE(review): if fg was transformed above but the partial itself
        // carries no tuple arguments, the transformed fg is discarded here and
        // false is returned — presumably the call-site transform covers that
        // case (or graph_transform_ caches the result); verify.
        if (CNodeHasTupleInput(node->cast<CNodePtr>())) {
          *trans_node = TransformPartial(fg, node->cast<CNodePtr>());
          return true;
        }
      }
      return false;
    }

    MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString();
    return false;
  }

 private:
  GraphTupleParamTransform &graph_transform_;  // shared flattening helper, not owned
};

// {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} ->
// transform switch layer graph call tuple inputs to flat inputs.
class SwitchLayerCallTupleTransform : public AnfVisitor {
 public:
  explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
  ~SwitchLayerCallTupleTransform() override = default;

  // Flattens tuple arguments for a switch_layer call: every graph in the
  // branch tuple is flattened first, then the outer call site if it passes
  // tuples. Returns the rewritten call node, or nullptr when nothing changed.
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto switch_layer_call_cnode = node->cast<CNodePtr>();
    auto call_inputs = switch_layer_call_cnode->inputs();
    if (call_inputs.size() < 1) {
      return nullptr;
    }
    // The callee must be {switch_layer, index, branch_tuple} — exactly 3 inputs.
    if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) {
      return nullptr;
    }
    auto swich_layer_cnode = call_inputs[0]->cast<CNodePtr>();
    auto switch_layer_inputs = swich_layer_cnode->inputs();
    if (switch_layer_inputs.size() != 3) {
      return nullptr;
    }

    AnfNodePtr transformed = nullptr;
    bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed);
    if (layer_changed) {
      // Rebuild the switch_layer CNode around the new branch tuple.
      switch_layer_inputs[2] = transformed;
      call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs);
    }
    // Flatten the outer call site too when it passes tuple arguments.
    if (CNodeHasTupleInput(switch_layer_call_cnode)) {
      return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode);
    }
    // Only the branches changed: re-emit the call with the new switch_layer.
    if (layer_changed) {
      return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs);
    }
    return nullptr;
  }

  // Flattens every FuncGraph inside the {make_tuple, br1, br2, ...} branch
  // node of a switch_layer. On change, writes the rebuilt make_tuple to
  // *trans_node and returns true. Returns false (with a warning) if the node
  // does not have the expected shape.
  bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple";
      return false;
    }
    auto tuple_inputs = node->cast<CNodePtr>()->inputs();
    bool changed = false;
    // Index 0 is the make_tuple primitive; branches start at 1.
    for (size_t i = 1; i < tuple_inputs.size(); i++) {
      if (!IsValueNode<FuncGraph>(tuple_inputs[i])) {
        MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph";
        return false;
      }
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
      if (FuncGraphHasTupleInput(fg)) {
        FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
        tuple_inputs[i] = NewValueNode(transformed_fg);
        changed = true;
      }
    }
    if (changed) {
      *trans_node = node->func_graph()->NewCNode(tuple_inputs);
    }
    return changed;
  }

 private:
  GraphTupleParamTransform &graph_transform_;  // shared flattening helper, not owned
};

// Facade pass: tries each of the three call-flattening transforms in order
// (plain graph call, switch call, switch_layer call) and returns the first
// rewrite that applies. All three share one GraphTupleParamTransform instance.
class CallGraphTupleTransform : public OptimizerCaller {
 public:
  CallGraphTupleTransform()
      : graph_transformer_(),
        graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)),
        switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)),
        switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) {
    transformers_.emplace_back(graph_call_transform_);
    transformers_.emplace_back(switch_call_transform_);
    transformers_.emplace_back(switch_layer_call_transform_);
  }
  ~CallGraphTupleTransform() = default;

  // Applies the sub-transforms in registration order; returns the first
  // non-null rewrite, or nullptr when no transform matched the node.
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    AnfNodePtr new_node;
    for (auto &transform : transformers_) {
      new_node = (*transform)(optimizer, node);
      if (new_node != nullptr) {
        return new_node;
      }
    }
    return nullptr;
  }

 private:
  GraphTupleParamTransform graph_transformer_;  // owning instance referenced by all sub-transforms
  OptimizerCallerPtr graph_call_transform_;
  OptimizerCallerPtr switch_call_transform_;
  OptimizerCallerPtr switch_layer_call_transform_;
  std::vector<OptimizerCallerPtr> transformers_{};  // dispatch order: graph, switch, switch_layer
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_