/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
18 #include <string>
19 #include <memory>
20 #include <vector>
21 #include <unordered_map>
22 
23 #include "base/base.h"
24 #include "ir/anf.h"
25 #include "ir/tensor.h"
26 #include "pybind_api/ir/primitive_py.h"
27 #include "pybind_api/ir/tensor_py.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace python_pass {
32 using std::string;
33 using std::vector;
34 
35 class MatchResult;
36 using MatchResultPtr = std::shared_ptr<MatchResult>;
37 class Pattern;
38 using PatternPtr = std::shared_ptr<Pattern>;
39 class Prim;
40 using PrimPtr = std::shared_ptr<Prim>;
41 class Call;
42 using CallPtr = std::shared_ptr<Call>;
43 class NewTensor;
44 using NewTensorPtr = std::shared_ptr<NewTensor>;
45 class NewParameter;
46 using NewParameterPtr = std::shared_ptr<NewParameter>;
47 class Imm;
48 using ImmPtr = std::shared_ptr<Imm>;
49 struct PatternHasher;
50 struct PatternEqual;
51 using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
52 
53 class Pattern : public Base {
54  public:
Pattern()55   Pattern() : unique_name_(std::to_string(g_id_++)) {}
56   ~Pattern() = default;
match(const AnfNodePtr & node)57   virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
58   virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
unique_name()59   string unique_name() const { return unique_name_; }
inputs()60   vector<PatternPtr> inputs() { return inputs_; }
reset()61   virtual void reset() {}
reset_gid()62   static void reset_gid() { g_id_ = 0; }
63 
64  protected:
65   static int64_t g_id_;
66   // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
67   string unique_name_;
68   vector<PatternPtr> inputs_;
69 };
70 
71 struct PatternEqual {
operatorPatternEqual72   bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
73     MS_EXCEPTION_IF_NULL(p1);
74     MS_EXCEPTION_IF_NULL(p2);
75     return p1->unique_name() == p2->unique_name();
76   }
77 };
78 
79 struct PatternHasher {
operatorPatternHasher80   std::size_t operator()(PatternPtr const &p) const {
81     MS_EXCEPTION_IF_NULL(p);
82     return std::hash<string>()(p->unique_name());
83   }
84 };
85 
86 class Prim : public Pattern {
87  public:
Prim()88   Prim() { unique_name_ = std::to_string(g_id_++); }
89   ~Prim() = default;
Prim(vector<py::object> prim_objs,string name)90   Prim(vector<py::object> prim_objs, string name) : name_(name) {
91     unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
92     for (auto &prim_obj : prim_objs) {
93       if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
94         auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
95         primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter));
96       } else if (py::isinstance<py::str>(prim_obj)) {
97         std::string prim_name = prim_obj.cast<py::str>();
98         primitives_.push_back(std::make_shared<PrimitivePy>(prim_name));
99       } else {
100         MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input.";
101       }
102     }
103     // Default using the first prim to build target
104     matched_prim_ = primitives_[0];
105   }
106   MS_DECLARE_PARENT(Prim, Pattern);
107   MatchResultPtr match(const AnfNodePtr &node) override;
matched_primitive()108   PrimitivePyPtr matched_primitive() { return matched_prim_; }
reset()109   void reset() override {
110     // Init before reset
111     MS_EXCEPTION_IF_NULL(matched_prim_);
112     matched_prim_ = primitives_[0];
113   }
114 
115  private:
116   vector<PrimitivePyPtr> primitives_;
117   string name_;
118   PrimitivePyPtr matched_prim_{nullptr};
119 };
120 
121 class Call : public Pattern {
122  public:
Call()123   Call() { unique_name_ = std::to_string(g_id_++); }
124   ~Call() = default;
Call(PatternPtr prim_pattern,vector<PatternPtr> inputs)125   Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
126     // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
127     prim_pattern_ = prim_pattern;
128     unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
129     inputs_ = inputs;
130   }
Call(py::object prim_obj,vector<PatternPtr> inputs)131   Call(py::object prim_obj, vector<PatternPtr> inputs) {
132     if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
133       auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
134       prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter);
135     } else if (py::isinstance<py::str>(prim_obj)) {
136       std::string prim_name = prim_obj.cast<py::str>();
137       prim_ = std::make_shared<PrimitivePy>(prim_name);
138     } else {
139       MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input.";
140     }
141     unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
142     inputs_ = inputs;
143   }
144   MS_DECLARE_PARENT(Call, Pattern);
145   MatchResultPtr match(const AnfNodePtr &node) override;
prim_value()146   PrimitivePtr prim_value() { return prim_; }
prim_pattern()147   PatternPtr prim_pattern() { return prim_pattern_; }
148 
149  private:
150   PatternPtr prim_pattern_ = nullptr;
151   PrimitivePtr prim_ = nullptr;
152   vector<string> types_;
153   string name_;
154 };
155 
156 class OneOf : public Pattern {
157  public:
OneOf()158   OneOf() { unique_name_ = std::to_string(g_id_++); }
159   ~OneOf() = default;
OneOf(vector<PatternPtr> patterns)160   explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
161     unique_name_ = std::to_string(g_id_++) + "OneOf";
162     for (auto &iter : patterns) {
163       unique_name_ = unique_name_ + "_" + iter->unique_name();
164     }
165   }
166   MS_DECLARE_PARENT(OneOf, Pattern);
167   MatchResultPtr match(const AnfNodePtr &node) override;
168 
169  private:
170   vector<PatternPtr> patterns_;
171 };
172 
173 class NoneOf : public Pattern {
174  public:
NoneOf()175   NoneOf() { unique_name_ = std::to_string(g_id_++); }
176   ~NoneOf() = default;
NoneOf(vector<PatternPtr> patterns)177   explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
178     unique_name_ = std::to_string(g_id_++) + "NoneOf";
179     for (auto &iter : patterns) {
180       unique_name_ = unique_name_ + "_" + iter->unique_name();
181     }
182   }
183   MS_DECLARE_PARENT(NoneOf, Pattern);
184   MatchResultPtr match(const AnfNodePtr &node) override;
185 
186  private:
187   vector<PatternPtr> patterns_;
188 };
189 
190 class Any : public Pattern {
191  public:
Any()192   Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
193   ~Any() = default;
194   MS_DECLARE_PARENT(Any, Pattern);
195   MatchResultPtr match(const AnfNodePtr &node) override;
196 };
197 
198 class NewTensor : public Pattern {
199  public:
NewTensor()200   NewTensor() { unique_name_ = std::to_string(g_id_++); }
201   ~NewTensor() = default;
NewTensor(tensor::TensorPtr input_tensor)202   explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
203     unique_name_ = std::to_string(g_id_++) + "NewTensor";
204   }
205   MS_DECLARE_PARENT(NewTensor, Pattern);
match(const AnfNodePtr & node)206   MatchResultPtr match(const AnfNodePtr &node) override {
207     MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
208   }
input_tensor()209   tensor::TensorPtr input_tensor() { return input_tensor_; }
210 
211  private:
212   tensor::TensorPtr input_tensor_;
213 };
214 
215 class NewParameter : public Pattern {
216  public:
NewParameter()217   NewParameter() { unique_name_ = std::to_string(g_id_++); }
NewParameter(string para_name,tensor::TensorPtr default_tensor,bool requires_grad,bool layerwise_parallel)218   explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
219       : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
220     unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
221     default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
222     built_ = false;
223   }
224   ~NewParameter() = default;
225   MS_DECLARE_PARENT(NewParameter, Pattern);
match(const AnfNodePtr & node)226   MatchResultPtr match(const AnfNodePtr &node) override {
227     MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n";
228   }
para_name()229   string para_name() { return para_name_; }
default_tensor()230   tensor::TensorPtr default_tensor() { return default_tensor_; }
requires_grad()231   bool requires_grad() { return requires_grad_; }
layerwise_parallel()232   bool layerwise_parallel() { return layerwise_parallel_; }
built()233   bool built() { return built_; }
set_built(bool built)234   void set_built(bool built) { built_ = built; }
reset()235   void reset() override { built_ = false; }
should_last()236   bool should_last() { return last_across_passes_; }
set_last(bool last)237   void set_last(bool last) { last_across_passes_ = last; }
238 
239  private:
240   string para_name_;
241   bool requires_grad_;
242   bool layerwise_parallel_;
243   bool last_across_passes_{false};
244   bool built_;
245   tensor::TensorPtr default_tensor_;
246 };
247 
248 class Imm : public Pattern {
249  public:
Imm()250   Imm() { unique_name_ = std::to_string(g_id_++); }
Imm(int value)251   explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
252   ~Imm() = default;
253   MS_DECLARE_PARENT(Imm, Pattern);
254   MatchResultPtr match(const AnfNodePtr &node) override;
value()255   int value() { return value_; }
256 
257  private:
258   int64_t value_;
259 };
260 
261 class MatchResult {
262  public:
MatchResult()263   MatchResult() {}
264   ~MatchResult() = default;
add_entry(PatternPtr pattern,AnfNodePtr node)265   void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
result()266   const PatternNodeMap &result() { return match_result_; }
267   AnfNodePtr get_node(const PatternPtr &pattern);
268   void merge(const MatchResultPtr &other_result);
clear()269   void clear() { match_result_.clear(); }
dump()270   void dump() {
271     MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
272     for (auto &iter : match_result_) {
273       MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
274     }
275   }
276 
277  private:
278   PatternNodeMap match_result_;
279 };
280 }  // namespace python_pass
281 }  // namespace opt
282 }  // namespace mindspore
283 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
284