1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ 17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ 18 #include <string> 19 #include <memory> 20 #include <vector> 21 #include <unordered_map> 22 23 #include "base/base.h" 24 #include "ir/anf.h" 25 #include "ir/tensor.h" 26 #include "pybind_api/ir/primitive_py.h" 27 #include "pybind_api/ir/tensor_py.h" 28 29 namespace mindspore { 30 namespace opt { 31 namespace python_pass { 32 using std::string; 33 using std::vector; 34 35 class MatchResult; 36 using MatchResultPtr = std::shared_ptr<MatchResult>; 37 class Pattern; 38 using PatternPtr = std::shared_ptr<Pattern>; 39 class Prim; 40 using PrimPtr = std::shared_ptr<Prim>; 41 class Call; 42 using CallPtr = std::shared_ptr<Call>; 43 class NewTensor; 44 using NewTensorPtr = std::shared_ptr<NewTensor>; 45 class NewParameter; 46 using NewParameterPtr = std::shared_ptr<NewParameter>; 47 class Imm; 48 using ImmPtr = std::shared_ptr<Imm>; 49 struct PatternHasher; 50 struct PatternEqual; 51 using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>; 52 53 class Pattern : public Base { 54 public: Pattern()55 Pattern() : unique_name_(std::to_string(g_id_++)) {} 56 ~Pattern() = default; match(const AnfNodePtr & node)57 virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; } 58 virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } unique_name()59 string unique_name() const { return unique_name_; } inputs()60 vector<PatternPtr> inputs() { return inputs_; } reset()61 virtual void reset() {} reset_gid()62 static void reset_gid() { g_id_ = 0; } 63 64 protected: 65 static int64_t g_id_; 66 // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed 67 string unique_name_; 68 vector<PatternPtr> inputs_; 69 }; 70 71 struct PatternEqual { operatorPatternEqual72 bool operator()(PatternPtr const &p1, PatternPtr const &p2) const { 73 MS_EXCEPTION_IF_NULL(p1); 74 MS_EXCEPTION_IF_NULL(p2); 75 return p1->unique_name() == p2->unique_name(); 76 } 77 }; 78 79 struct PatternHasher { operatorPatternHasher80 std::size_t operator()(PatternPtr const &p) const { 81 MS_EXCEPTION_IF_NULL(p); 82 return std::hash<string>()(p->unique_name()); 83 } 84 }; 85 86 class Prim : public Pattern { 87 public: Prim()88 Prim() { unique_name_ = std::to_string(g_id_++); } 89 ~Prim() = default; Prim(vector<py::object> prim_objs,string name)90 Prim(vector<py::object> prim_objs, string name) : name_(name) { 91 unique_name_ = std::to_string(g_id_++) + "Prim_" + name; 92 for (auto &prim_obj : prim_objs) { 93 if (py::isinstance<PrimitivePyAdapter>(prim_obj)) { 94 auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>(); 95 primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter)); 96 } else if (py::isinstance<py::str>(prim_obj)) { 97 std::string prim_name = prim_obj.cast<py::str>(); 98 primitives_.push_back(std::make_shared<PrimitivePy>(prim_name)); 99 } else { 100 MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input."; 101 } 102 } 103 // Default using the first prim to build target 104 matched_prim_ = primitives_[0]; 105 } 106 MS_DECLARE_PARENT(Prim, Pattern); 107 MatchResultPtr match(const AnfNodePtr &node) override; matched_primitive()108 PrimitivePyPtr matched_primitive() { return matched_prim_; } reset()109 void reset() override { 110 // Init before reset 111 MS_EXCEPTION_IF_NULL(matched_prim_); 112 matched_prim_ = primitives_[0]; 113 } 114 115 private: 116 vector<PrimitivePyPtr> primitives_; 117 string name_; 118 PrimitivePyPtr matched_prim_{nullptr}; 119 }; 120 121 class Call : public Pattern { 122 public: Call()123 Call() { unique_name_ = std::to_string(g_id_++); } 124 ~Call() = default; Call(PatternPtr prim_pattern,vector<PatternPtr> inputs)125 Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) { 126 // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting 127 prim_pattern_ = prim_pattern; 128 unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name(); 129 inputs_ = inputs; 130 } Call(py::object prim_obj,vector<PatternPtr> inputs)131 Call(py::object prim_obj, vector<PatternPtr> inputs) { 132 if (py::isinstance<PrimitivePyAdapter>(prim_obj)) { 133 auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>(); 134 prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter); 135 } else if (py::isinstance<py::str>(prim_obj)) { 136 std::string prim_name = prim_obj.cast<py::str>(); 137 prim_ = std::make_shared<PrimitivePy>(prim_name); 138 } else { 139 MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input."; 140 } 141 unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString(); 142 inputs_ = inputs; 143 } 144 MS_DECLARE_PARENT(Call, Pattern); 145 MatchResultPtr match(const AnfNodePtr &node) override; prim_value()146 PrimitivePtr prim_value() { return prim_; } prim_pattern()147 PatternPtr prim_pattern() { return prim_pattern_; } 148 149 private: 150 PatternPtr prim_pattern_ = nullptr; 151 PrimitivePtr prim_ = nullptr; 152 vector<string> types_; 153 string name_; 154 }; 155 156 class OneOf : public Pattern { 157 public: OneOf()158 OneOf() { unique_name_ = std::to_string(g_id_++); } 159 ~OneOf() = default; OneOf(vector<PatternPtr> patterns)160 explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) { 161 unique_name_ = std::to_string(g_id_++) + "OneOf"; 162 for (auto &iter : patterns) { 163 unique_name_ = unique_name_ + "_" + iter->unique_name(); 164 } 165 } 166 MS_DECLARE_PARENT(OneOf, Pattern); 167 MatchResultPtr match(const AnfNodePtr &node) override; 168 169 private: 170 vector<PatternPtr> patterns_; 171 }; 172 173 class NoneOf : public Pattern { 174 public: NoneOf()175 NoneOf() { unique_name_ = std::to_string(g_id_++); } 176 ~NoneOf() = default; NoneOf(vector<PatternPtr> patterns)177 explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) { 178 unique_name_ = std::to_string(g_id_++) + "NoneOf"; 179 for (auto &iter : patterns) { 180 unique_name_ = unique_name_ + "_" + iter->unique_name(); 181 } 182 } 183 MS_DECLARE_PARENT(NoneOf, Pattern); 184 MatchResultPtr match(const AnfNodePtr &node) override; 185 186 private: 187 vector<PatternPtr> patterns_; 188 }; 189 190 class Any : public Pattern { 191 public: Any()192 Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; } 193 ~Any() = default; 194 MS_DECLARE_PARENT(Any, Pattern); 195 MatchResultPtr match(const AnfNodePtr &node) override; 196 }; 197 198 class NewTensor : public Pattern { 199 public: NewTensor()200 NewTensor() { unique_name_ = std::to_string(g_id_++); } 201 ~NewTensor() = default; NewTensor(tensor::TensorPtr input_tensor)202 explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { 203 unique_name_ = std::to_string(g_id_++) + "NewTensor"; 204 } 205 MS_DECLARE_PARENT(NewTensor, Pattern); match(const AnfNodePtr & node)206 MatchResultPtr match(const AnfNodePtr &node) override { 207 MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; 208 } input_tensor()209 tensor::TensorPtr input_tensor() { return input_tensor_; } 210 211 private: 212 tensor::TensorPtr input_tensor_; 213 }; 214 215 class NewParameter : public Pattern { 216 public: NewParameter()217 NewParameter() { unique_name_ = std::to_string(g_id_++); } NewParameter(string para_name,tensor::TensorPtr default_tensor,bool requires_grad,bool layerwise_parallel)218 explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) 219 : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { 220 unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; 221 default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); 222 built_ = false; 223 } 224 ~NewParameter() = default; 225 MS_DECLARE_PARENT(NewParameter, Pattern); match(const AnfNodePtr & node)226 MatchResultPtr match(const AnfNodePtr &node) override { 227 MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n"; 228 } para_name()229 string para_name() { return para_name_; } default_tensor()230 tensor::TensorPtr default_tensor() { return default_tensor_; } requires_grad()231 bool requires_grad() { return requires_grad_; } layerwise_parallel()232 bool layerwise_parallel() { return layerwise_parallel_; } built()233 bool built() { return built_; } set_built(bool built)234 void set_built(bool built) { built_ = built; } reset()235 void reset() override { built_ = false; } should_last()236 bool should_last() { return last_across_passes_; } set_last(bool last)237 void set_last(bool last) { last_across_passes_ = last; } 238 239 private: 240 string para_name_; 241 bool requires_grad_; 242 bool layerwise_parallel_; 243 bool last_across_passes_{false}; 244 bool built_; 245 tensor::TensorPtr default_tensor_; 246 }; 247 248 class Imm : public Pattern { 249 public: Imm()250 Imm() { unique_name_ = std::to_string(g_id_++); } Imm(int value)251 explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); } 252 ~Imm() = default; 253 MS_DECLARE_PARENT(Imm, Pattern); 254 MatchResultPtr match(const AnfNodePtr &node) override; value()255 int value() { return value_; } 256 257 private: 258 int64_t value_; 259 }; 260 261 class MatchResult { 262 public: MatchResult()263 MatchResult() {} 264 ~MatchResult() = default; add_entry(PatternPtr pattern,AnfNodePtr node)265 void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } result()266 const PatternNodeMap &result() { return match_result_; } 267 AnfNodePtr get_node(const PatternPtr &pattern); 268 void merge(const MatchResultPtr &other_result); clear()269 void clear() { match_result_.clear(); } dump()270 void dump() { 271 MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n"; 272 for (auto &iter : match_result_) { 273 MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n"; 274 } 275 } 276 277 private: 278 PatternNodeMap match_result_; 279 }; 280 } // namespace python_pass 281 } // namespace opt 282 } // namespace mindspore 283 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ 284