1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 18 19 #include <deque> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include <functional> 24 25 #include "base/base.h" 26 #include "ir/manager.h" 27 #include "utils/hash_map.h" 28 #include "ir/anf.h" 29 #include "ir/func_graph.h" 30 #include "frontend/optimizer/optimizer_caller.h" 31 #include "frontend/operator/ops.h" 32 33 namespace mindspore { 34 /* namespace to support opt */ 35 namespace opt { 36 37 // Define the interaction mode between an Optimize pass and Renormalize pass 38 // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed 39 // CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted 40 enum RenormAction : int64_t { FORCE_RENORM = 0, CHECK_RENORM }; 41 42 class Substitution { 43 public: 44 OptimizerCallerPtr transform_; 45 std::string name_; 46 PredicateFuncType predicate_{nullptr}; 47 // An enum to mark this Substitution relation to renormalize pass. 48 RenormAction renorm_action_; 49 // Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others. 50 bool has_priority_pattern_{false}; 51 Substitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)52 Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, 53 const RenormAction &renorm_action, bool has_priority_pattern) 54 : transform_(transform), 55 name_(name), 56 predicate_(predicate), 57 renorm_action_(renorm_action), 58 has_priority_pattern_(has_priority_pattern) {} 59 ~Substitution() = default; 60 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); 61 }; 62 63 using SubstitutionPtr = std::shared_ptr<Substitution>; 64 65 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, 66 const RenormAction &renorm_action = CHECK_RENORM, bool has_priority_pattern = false); 67 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, 68 const std::vector<PrimitivePtr> &prims, 69 const RenormAction &renorm_action = CHECK_RENORM, bool has_priority_pattern = false); 70 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, 71 const PredicateFuncType &predicate, const RenormAction &renorm_action = CHECK_RENORM, 72 bool has_priority_pattern = false); 73 74 enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR }; 75 76 class SubstitutionList { 77 public: 78 explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false, 79 bool global_sensitive = false) list_(patterns)80 : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {} 81 ~SubstitutionList() = default; 82 83 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; 84 85 private: 86 bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; 87 bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph, 88 const SubstitutionPtr &substitution) const; 89 bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; 90 void DisplayStatusOfSubstitution(const mindspore::HashMap<std::string, std::vector<bool>> &status, 91 const OptimizerPtr &optimizer, size_t space) const; 92 93 std::vector<SubstitutionPtr> list_; 94 // a flag to mark this list of Substitution can only be executed only once 95 bool is_once_{false}; 96 bool global_sensitive_{false}; 97 }; 98 99 // SimpleRewriter simply rewrites a graph according to the node rewriter defined by derived class. 100 class SimpleRewriter { 101 public: SimpleRewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)102 SimpleRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager) 103 : root_graph_(root_graph), manager_(manager) {} 104 virtual ~SimpleRewriter() = default; 105 bool Run(); 106 107 protected: 108 virtual AnfNodePtr NodeRewrite(const AnfNodePtr &node) = 0; 109 FuncGraphPtr root_graph_; 110 FuncGraphManagerPtr manager_; 111 }; 112 } // namespace opt 113 } // namespace mindspore 114 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 115