• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_
19 #include <string>
20 #include <set>
21 #include <utility>
22 #include <vector>
23 #include <memory>
24 #include "include/backend/optimizer/pass.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "tools/converter/ops/ops_def.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "include/registry/converter_context.h"
29 
30 using mindspore::converter::FmkType;
31 namespace mindspore::opt {
32 using AimFunc = std::function<bool(const AnfNodePtr &)>;
33 class FunctionalizeControlOpPass : public Pass {
34  public:
FunctionalizeControlOpPass()35   FunctionalizeControlOpPass() : Pass("functionalize_control_op_pass") {}
36   ~FunctionalizeControlOpPass() override = default;
37   bool Run(const FuncGraphPtr &graph) override;
38   static FuncGraphPtr NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type);
IsMerge(const AnfNodePtr & node)39   static bool IsMerge(const AnfNodePtr &node) { return CheckPrimitiveType(node, prim::kPrimMerge); }
IsLoopCond(const AnfNodePtr & node)40   static bool IsLoopCond(const AnfNodePtr &node) {
41     return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameLoopCond));
42   }
IsEnter(const AnfNodePtr & node)43   static bool IsEnter(const AnfNodePtr &node) {
44     return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameEnter));
45   }
IsExit(const AnfNodePtr & node)46   static bool IsExit(const AnfNodePtr &node) {
47     return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameExit));
48   }
IsSwitch(const AnfNodePtr & node)49   static bool IsSwitch(const AnfNodePtr &node) { return CheckPrimitiveType(node, prim::kPrimSwitch); }
IsNextIteration(const AnfNodePtr & node)50   static bool IsNextIteration(const AnfNodePtr &node) {
51     return CheckPrimitiveType(node, std::make_shared<Primitive>(lite::kNameNextIteration));
52   }
IsControlFlowOp(const AnfNodePtr & node)53   static bool IsControlFlowOp(const AnfNodePtr &node) {
54     return IsLoopCond(node) || IsEnter(node) || IsMerge(node) || IsSwitch(node) || IsExit(node) ||
55            IsNextIteration(node);
56   }
57   static CNodePtr BelongToWhichNode(const CNodePtr &node, const AimFunc &aim_func,
58                                     const FilterFunc &filter_func = nullptr);
GetSubgraphIndex()59   static int GetSubgraphIndex() {
60     static int subgraph_index = 1;
61     return subgraph_index++;
62   }
63   // The names of nodes with the same prefix are a cluster.
64   static std::string NodeClusterName(const AnfNodePtr &node);
65   void InitNodeClusters(const FuncGraphPtr &func_graph);
66   // return the position in node_clusters_
67   size_t WhichCluster(const std::string &cluster_name);
68 
69  protected:
70   STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph);
71   static STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph);
72   std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{};
73   std::vector<CNodePtr> loop_cond_nodes_{};
74 };
75 }  // namespace mindspore::opt
76 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_FUNCTIONALIZE_CONTROL_OP_PASS_H_
77