• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
18 
19 #include <deque>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <functional>
24 
25 #include "base/base.h"
26 #include "ir/manager.h"
27 #include "utils/hash_map.h"
28 #include "ir/anf.h"
29 #include "ir/func_graph.h"
30 #include "frontend/optimizer/optimizer_caller.h"
31 #include "frontend/operator/ops.h"
32 
33 namespace mindspore {
34 /* namespace to support opt */
35 namespace opt {
36 
// RenormAction describes how an optimize pass interacts with the Renormalize pass that follows it:
//   FORCE_RENORM - if the pass modified the graph, the next Renormalize will be executed.
//   CHECK_RENORM - whether the next Renormalize is executed is decided by checking if the new node is un-typed.
enum RenormAction : int64_t {
  FORCE_RENORM = 0,
  CHECK_RENORM = 1,
};
41 
42 class Substitution {
43  public:
44   OptimizerCallerPtr transform_;
45   std::string name_;
46   PredicateFuncType predicate_{nullptr};
47   // An enum to mark this Substitution relation to renormalize pass.
48   RenormAction renorm_action_;
49   // Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
50   bool has_priority_pattern_{false};
51 
Substitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)52   Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
53                const RenormAction &renorm_action, bool has_priority_pattern)
54       : transform_(transform),
55         name_(name),
56         predicate_(predicate),
57         renorm_action_(renorm_action),
58         has_priority_pattern_(has_priority_pattern) {}
59   ~Substitution() = default;
60   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
61 };
62 
63 using SubstitutionPtr = std::shared_ptr<Substitution>;
64 
65 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
66                                  const RenormAction &renorm_action = CHECK_RENORM, bool has_priority_pattern = false);
67 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
68                                  const std::vector<PrimitivePtr> &prims,
69                                  const RenormAction &renorm_action = CHECK_RENORM, bool has_priority_pattern = false);
70 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
71                                  const PredicateFuncType &predicate, const RenormAction &renorm_action = CHECK_RENORM,
72                                  bool has_priority_pattern = false);
73 
// Selects a traversal strategy for applying substitutions.
// NOTE(review): the names suggest whether IR nodes (FromIRToSubstitutions) or substitutions
// (FromSubstitutionsToIR) form the outer loop -- confirm against SubstitutionList's definition.
enum OptTraverseSubstitutionsMode {
  kOptTraverseFromIRToSubstitutions = 0,
  kOptTraverseFromSubstitutionsToIR = 1,
};
75 
76 class SubstitutionList {
77  public:
78   explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false,
79                             bool global_sensitive = false)
list_(patterns)80       : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {}
81   ~SubstitutionList() = default;
82 
83   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
84 
85  private:
86   bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
87   bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
88                              const SubstitutionPtr &substitution) const;
89   bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
90   void DisplayStatusOfSubstitution(const mindspore::HashMap<std::string, std::vector<bool>> &status,
91                                    const OptimizerPtr &optimizer, size_t space) const;
92 
93   std::vector<SubstitutionPtr> list_;
94   // a flag to mark this list of Substitution can only be executed only once
95   bool is_once_{false};
96   bool global_sensitive_{false};
97 };
98 
99 // SimpleRewriter simply rewrites a graph according to the node rewriter defined by derived class.
100 class SimpleRewriter {
101  public:
SimpleRewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)102   SimpleRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
103       : root_graph_(root_graph), manager_(manager) {}
104   virtual ~SimpleRewriter() = default;
105   bool Run();
106 
107  protected:
108   virtual AnfNodePtr NodeRewrite(const AnfNodePtr &node) = 0;
109   FuncGraphPtr root_graph_;
110   FuncGraphManagerPtr manager_;
111 };
112 }  // namespace opt
113 }  // namespace mindspore
114 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
115