1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_ 18 19 #include <string> 20 #include <vector> 21 #include <memory> 22 #include "backend/common/graph_kernel/split_model/area.h" 23 #include "ops/math_op_name.h" 24 25 namespace mindspore::graphkernel::inner { 26 class CircleChecker { 27 public: 28 // whether it will form a circle if the two areas are fused. 29 virtual bool HasCircle(const AreaPtr &a, const AreaPtr &b) const = 0; 30 virtual ~CircleChecker() = default; 31 }; 32 using CircleCheckerPtr = std::shared_ptr<CircleChecker>; 33 34 enum class FuseDirection { 35 FORWARD, // fuse with inputs 36 BACKWARD // fuse with outputs 37 }; 38 39 // the base class of fusion patterns 40 class FusePattern { 41 public: FusePattern(const std::string & name)42 explicit FusePattern(const std::string &name) : name_(name) {} 43 virtual ~FusePattern() = default; 44 // Run the pattern Run(const AreaPtr & dom)45 bool Run(const AreaPtr &dom) { 46 Reset(); 47 return Check(dom) && Match(dom); 48 } 49 std::string ToString() const; 50 // Bind the circle checker SetCircleChecker(const CircleCheckerPtr & c)51 void SetCircleChecker(const CircleCheckerPtr &c) { circle_checker_ = c; } 52 name()53 std::string name() const { return name_; } direction()54 FuseDirection direction() const { return direction_; } 55 std::vector<AreaPtr> fused_areas_; 56 57 protected: Reset()58 void Reset() { fused_areas_.clear(); } 59 // Check whether the pattern can handle this area Check(const AreaPtr &)60 virtual bool Check(const AreaPtr &) { return true; } 61 // Match the ADJACENT areas of `dom` 62 virtual bool Match(const AreaPtr &dom) = 0; 63 // whether it will form a circle if the two areas are fused. HasCircle(const AreaPtr & a,const AreaPtr & b)64 bool HasCircle(const AreaPtr &a, const AreaPtr &b) const { 65 MS_EXCEPTION_IF_NULL(circle_checker_); 66 return circle_checker_->HasCircle(a, b); 67 } 68 69 std::string name_; 70 FuseDirection direction_{FuseDirection::FORWARD}; 71 CircleCheckerPtr circle_checker_{nullptr}; 72 }; 73 using FusePatternPtr = std::shared_ptr<FusePattern>; 74 75 /* some common patterns are defined below */ 76 enum class FuseType { kWidth, kDepth }; 77 class FuseReshape : public FusePattern { 78 public: FuseReshape()79 FuseReshape() : FusePattern("reshape") {} 80 ~FuseReshape() = default; 81 82 protected: Check(const AreaPtr & dom)83 bool Check(const AreaPtr &dom) override { return dom->pattern() == NodePattern::RESHAPE; } 84 bool Match(const AreaPtr &dom) override; 85 void KeepMinimumArea(const AreaPtr &a, FuseDirection dir); 86 AreaPtr min_area_; 87 }; 88 89 class FuseIsolateReshape : public FusePattern { 90 public: FuseIsolateReshape()91 FuseIsolateReshape() : FusePattern("isolate_reshape") {} 92 ~FuseIsolateReshape() = default; 93 94 protected: Check(const AreaPtr & dom)95 bool Check(const AreaPtr &dom) override { return dom->pattern() == NodePattern::RESHAPE && dom->size() == 1; } 96 bool Match(const AreaPtr &dom) override; 97 }; 98 99 class FuseElemwiseFwd : public FusePattern { 100 public: FuseElemwiseFwd(FuseType fuse_type)101 explicit FuseElemwiseFwd(FuseType fuse_type) : FusePattern("elemwise_fwd"), fuse_type_(fuse_type) { 102 direction_ = FuseDirection::FORWARD; 103 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 104 } 105 ~FuseElemwiseFwd() = default; CreateDepthMatcher()106 static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseElemwiseFwd>(FuseType::kDepth); } CreateWidthMatcher()107 static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseElemwiseFwd>(FuseType::kWidth); } 108 109 protected: 110 bool Check(const AreaPtr &dom) override; 111 bool Match(const AreaPtr &dom) override; 112 FuseType fuse_type_; 113 }; 114 115 class FuseElemwiseBroadcastFwd : public FusePattern { 116 public: FuseElemwiseBroadcastFwd(FuseType fuse_type)117 explicit FuseElemwiseBroadcastFwd(FuseType fuse_type) : FusePattern("elemwise_broadcast_fwd"), fuse_type_(fuse_type) { 118 direction_ = FuseDirection::FORWARD; 119 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 120 } 121 ~FuseElemwiseBroadcastFwd() = default; CreateDepthMatcher()122 static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseElemwiseBroadcastFwd>(FuseType::kDepth); } CreateWidthMatcher()123 static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseElemwiseBroadcastFwd>(FuseType::kWidth); } 124 125 protected: 126 bool Check(const AreaPtr &dom) override; 127 bool Match(const AreaPtr &dom) override; 128 FuseType fuse_type_; 129 }; 130 131 class FuseDynElemwiseBroadcastFwd : public FusePattern { 132 public: FuseDynElemwiseBroadcastFwd(FuseType fuse_type)133 explicit FuseDynElemwiseBroadcastFwd(FuseType fuse_type) 134 : FusePattern("elemwise_broadcast_fwd"), fuse_type_(fuse_type) { 135 direction_ = FuseDirection::FORWARD; 136 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 137 } 138 ~FuseDynElemwiseBroadcastFwd() = default; CreateDepthMatcher()139 static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseDynElemwiseBroadcastFwd>(FuseType::kDepth); } CreateWidthMatcher()140 static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseDynElemwiseBroadcastFwd>(FuseType::kWidth); } 141 142 protected: 143 bool Check(const AreaPtr &dom) override; 144 bool Match(const AreaPtr &dom) override; 145 FuseType fuse_type_; 146 }; 147 148 class FuseReduceFwd : public FusePattern { 149 public: FuseReduceFwd(FuseType fuse_type,size_t size_limit)150 FuseReduceFwd(FuseType fuse_type, size_t size_limit) 151 : FusePattern("reduce_fwd"), fuse_type_(fuse_type), size_limit_(size_limit) { 152 direction_ = FuseDirection::FORWARD; 153 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 154 } 155 ~FuseReduceFwd() = default; CreateDepthMatcher(size_t size_limit)156 static FusePatternPtr CreateDepthMatcher(size_t size_limit) { 157 return std::make_shared<FuseReduceFwd>(FuseType::kDepth, size_limit); 158 } CreateWidthMatcher(size_t size_limit)159 static FusePatternPtr CreateWidthMatcher(size_t size_limit) { 160 return std::make_shared<FuseReduceFwd>(FuseType::kWidth, size_limit); 161 } 162 163 protected: 164 bool Check(const AreaPtr &dom) override; 165 bool Match(const AreaPtr &dom) override; 166 FuseType fuse_type_; 167 size_t size_limit_; 168 }; 169 170 class FuseDynReduceFwd : public FusePattern { 171 public: FuseDynReduceFwd(FuseType fuse_type,size_t size_limit)172 FuseDynReduceFwd(FuseType fuse_type, size_t size_limit) 173 : FusePattern("reduce_fwd"), fuse_type_(fuse_type), size_limit_(size_limit) { 174 direction_ = FuseDirection::FORWARD; 175 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 176 } 177 ~FuseDynReduceFwd() = default; CreateDepthMatcher(size_t size_limit)178 static FusePatternPtr CreateDepthMatcher(size_t size_limit) { 179 return std::make_shared<FuseDynReduceFwd>(FuseType::kDepth, size_limit); 180 } CreateWidthMatcher(size_t size_limit)181 static FusePatternPtr CreateWidthMatcher(size_t size_limit) { 182 return std::make_shared<FuseDynReduceFwd>(FuseType::kWidth, size_limit); 183 } 184 185 protected: 186 bool Check(const AreaPtr &dom) override; 187 bool Match(const AreaPtr &dom) override; 188 FuseType fuse_type_; 189 size_t size_limit_; 190 }; 191 192 class FuseElemwiseBroadcastBwd : public FusePattern { 193 public: FuseElemwiseBroadcastBwd(FuseType fuse_type,size_t size_limit)194 FuseElemwiseBroadcastBwd(FuseType fuse_type, size_t size_limit) 195 : FusePattern("elemwise_broadcast_bwd"), fuse_type_(fuse_type), size_limit_(size_limit) { 196 direction_ = FuseDirection::BACKWARD; 197 name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth"); 198 } 199 ~FuseElemwiseBroadcastBwd() = default; CreateDepthMatcher(size_t size_limit)200 static FusePatternPtr CreateDepthMatcher(size_t size_limit) { 201 return std::make_shared<FuseElemwiseBroadcastBwd>(FuseType::kDepth, size_limit); 202 } CreateWidthMatcher(size_t size_limit)203 static FusePatternPtr CreateWidthMatcher(size_t size_limit) { 204 return std::make_shared<FuseElemwiseBroadcastBwd>(FuseType::kWidth, size_limit); 205 } 206 207 protected: 208 bool Check(const AreaPtr &dom) override; 209 bool Match(const AreaPtr &dom) override; 210 FuseType fuse_type_; 211 size_t size_limit_; 212 }; 213 214 // bind the virtual nodes to their inputs 215 class FuseVirtualNode : public FusePattern { 216 public: FuseVirtualNode()217 FuseVirtualNode() : FusePattern("bind_virtual_node") { direction_ = FuseDirection::FORWARD; } 218 ~FuseVirtualNode() = default; 219 220 protected: Check(const AreaPtr & area)221 bool Check(const AreaPtr &area) override { return area->pattern() == NodePattern::VIRTUAL; } 222 bool Match(const AreaPtr &area) override; 223 }; 224 225 namespace ascend { 226 class FuseMatMul : public FusePattern { 227 public: FuseMatMul()228 FuseMatMul() : FusePattern("matmul_depth") { direction_ = FuseDirection::BACKWARD; } 229 ~FuseMatMul() = default; 230 231 protected: Check(const AreaPtr & dom)232 bool Check(const AreaPtr &dom) override { 233 return dom->dom()->op() == kMatMulOpName || dom->dom()->op() == kBatchMatMulOpName; 234 } 235 bool Match(const AreaPtr &dom) override; 236 }; 237 } // namespace ascend 238 } // namespace mindspore::graphkernel::inner 239 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_ 240