• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_
18 
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "ir/graph_utils.h"
#include "utils/ms_utils.h"
#include "backend/optimizer/common/helper.h"
32 
33 namespace mindspore {
34 namespace opt {
35 using PatternListType = std::initializer_list<BaseRef>;
36 
37 class PatternProcessPass : public NodePass {
38  public:
39   explicit PatternProcessPass(const std::string &name = "", bool multigraph = true);
40   ~PatternProcessPass() override = default;
41   virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
42   virtual const BaseRef DefinePattern() const;
43   AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
44 
45  private:
46   void Build();
47 
48   AnfNodePtr pattern_ = nullptr;
49   bool multigraph_ = true;
50   PatternEngine pattern_engine_;
51   PrimitiveVarMapPtr primitive_vars_;
52 };
53 
54 class MultipleOutputPatternProcessPass : public PatternProcessPass {
55  public:
56   explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
PatternProcessPass(name,multigraph)57       : PatternProcessPass(name, multigraph),
58         child_pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
59         child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
60   ~MultipleOutputPatternProcessPass() override = default;
61   virtual BaseRef DefineAnotherPattern() const = 0;
62   // check two patterns whether share the same nodes or not
63   virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;
64 
65  protected:
66   bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
67   PatternEngine child_pattern_engine_;
68   PrimitiveVarMapPtr child_primitive_vars_;
69 };
70 
71 class GraphOptimizer {
72  public:
name_(name)73   explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
74   virtual ~GraphOptimizer() = default;
75 
76   void AddPassManager(const PassManagerPtr &pass_manager);
77   FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);
78 
79  private:
80   const std::string name_ = "graph_optimizer";
81   std::vector<PassManagerPtr> pass_managers_{};
82   bool run_only_once_ = true;
83 };
84 }  // namespace opt
85 }  // namespace mindspore
86 
87 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_
88