1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_ 19 #include <string> 20 #include <set> 21 #include <utility> 22 #include <vector> 23 #include <memory> 24 #include "include/backend/optimizer/pass.h" 25 #include "mindspore/core/ops/framework_ops.h" 26 #include "tools/converter/ops/ops_def.h" 27 #include "tools/optimizer/common/gllo_utils.h" 28 #include "include/registry/converter_context.h" 29 30 using mindspore::converter::FmkType; 31 namespace mindspore::opt { 32 using AimFunc = std::function<bool(const AnfNodePtr &)>; 33 class FunctionalizeControlOpPass : public Pass { 34 public: FunctionalizeControlOpPass()35 FunctionalizeControlOpPass() : Pass("functionalize_control_op_pass") {} 36 ~FunctionalizeControlOpPass() override = default; 37 bool Run(const FuncGraphPtr &graph) override; 38 static FuncGraphPtr NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type); IsMerge(const AnfNodePtr & node)39 static bool IsMerge(const AnfNodePtr &node) { return CheckPrimitiveType(node, prim::kPrimMerge); } IsLoopCond(const AnfNodePtr & node)40 static bool IsLoopCond(const AnfNodePtr &node) { 41 return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameLoopCond)); 42 } IsEnter(const AnfNodePtr & node)43 static bool IsEnter(const AnfNodePtr &node) { 44 return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameEnter)); 45 } IsExit(const AnfNodePtr & node)46 static bool IsExit(const AnfNodePtr &node) { 47 return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameExit)); 48 } IsSwitch(const AnfNodePtr & node)49 static bool IsSwitch(const AnfNodePtr &node) { return CheckPrimitiveType(node, prim::kPrimSwitch); } IsNextIteration(const AnfNodePtr & node)50 static bool IsNextIteration(const AnfNodePtr &node) { 51 return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameNextIteration)); 52 } IsControlFlowOp(const AnfNodePtr & node)53 static bool IsControlFlowOp(const AnfNodePtr &node) { 54 return IsLoopCond(node) || IsEnter(node) || IsMerge(node) || IsSwitch(node) || IsExit(node) || 55 IsNextIteration(node); 56 } 57 static CNodePtr BelongToWhichNode(const CNodePtr &node, const AimFunc &aim_func, 58 const FilterFunc &filter_func = nullptr); GetSubgraphIndex()59 static int GetSubgraphIndex() { 60 static int subgraph_index = 1; 61 return subgraph_index++; 62 } 63 // The names of nodes with the same prefix are a cluster. 64 static std::string NodeClusterName(const AnfNodePtr &node); 65 void InitNodeClusters(const FuncGraphPtr &func_graph); 66 // return the position in node_clusters_ 67 size_t WhichCluster(const std::string &cluster_name); 68 69 protected: 70 STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph); 71 static STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph); 72 std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{}; 73 std::vector<CNodePtr> loop_cond_nodes_{}; 74 }; 75 } // namespace mindspore::opt 76 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_ 77