• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
19 
20 #include <deque>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <unordered_map>
25 
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "frontend/optimizer/optimizer_caller.h"
29 #include "frontend/operator/ops.h"
30 
31 namespace mindspore {
32 /* namespace to support opt */
33 namespace opt {
34 
// Interaction mode between an Optimize pass and the Renormalize pass.
//   FORCE_RENORM: if the pass modified the graph, the next Renormalize will be executed.
//   CHECK_RENORM: check whether the new node is un-typed to decide if the next Renormalize will be executed.
// The enumerator values are spelled out explicitly because the enum is unscoped with an int64_t
// underlying type, so code elsewhere may rely on the numeric values.
enum RenormAction : int64_t { FORCE_RENORM = 0, CHECK_RENORM = 1 };
39 
40 class Substitution {
41  public:
42   OptimizerCallerPtr transform_;
43   std::string name_;
44   PredicateFuncType predicate_{nullptr};
45   // An enum to mark this Substitution relation to renormalize pass.
46   RenormAction renorm_action_;
47   // Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
48   bool has_priority_pattern_{false};
49 
Substitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)50   Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
51                const RenormAction &renorm_action, bool has_priority_pattern)
52       : transform_(transform),
53         name_(name),
54         predicate_(predicate),
55         renorm_action_(renorm_action),
56         has_priority_pattern_(has_priority_pattern) {}
57   ~Substitution() = default;
58   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
59 };
60 
61 using SubstitutionPtr = std::shared_ptr<Substitution>;
62 
63 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
64                                  const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
65 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
66                                  const std::vector<PrimitivePtr> &prims,
67                                  const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
68 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
69                                  const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM,
70                                  bool has_priority_pattern = false);
71 
// Traversal order used when applying a list of substitutions.
// NOTE(review): the names mirror SubstitutionList::ApplyIRToSubstitutions / ApplySubstitutionsToIR.
// Presumably this enum selects between them; confirm in opt.cc.
enum OptTraverseSubstitutionsMode {
  kOptTraverseFromIRToSubstitutions = 0,  // outer walk over the IR, trying substitutions per node
  kOptTraverseFromSubstitutionsToIR = 1,  // outer walk over the substitutions, applying each to the IR
};
73 
74 class SubstitutionList {
75  public:
76   explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false,
77                             bool global_sensitive = false)
list_(patterns)78       : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {}
79   ~SubstitutionList() = default;
80 
81   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
82 
83  private:
84   bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
85   bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
86                              const SubstitutionPtr &sub) const;
87   bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
88   void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
89                                    const OptimizerPtr &optimizer, size_t space) const;
90 
91   std::vector<SubstitutionPtr> list_;
92   // a flag to mark this list of Substitution can only be executed only once
93   bool is_once_{false};
94   bool global_sensitive_{false};
95 };
96 }  // namespace opt
97 }  // namespace mindspore
98 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
99