/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_
18 
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include "backend/common/graph_kernel/split_model/area.h"
23 #include "ops/math_op_name.h"
24 
25 namespace mindspore::graphkernel::inner {
26 class CircleChecker {
27  public:
28   // whether it will form a circle if the two areas are fused.
29   virtual bool HasCircle(const AreaPtr &a, const AreaPtr &b) const = 0;
30   virtual ~CircleChecker() = default;
31 };
32 using CircleCheckerPtr = std::shared_ptr<CircleChecker>;
33 
// Side of the dominant area on which a pattern looks for areas to fuse.
enum class FuseDirection {
  // fuse with the inputs
  FORWARD,
  // fuse with the outputs
  BACKWARD
};
38 
39 // the base class of fusion patterns
40 class FusePattern {
41  public:
FusePattern(const std::string & name)42   explicit FusePattern(const std::string &name) : name_(name) {}
43   virtual ~FusePattern() = default;
44   // Run the pattern
Run(const AreaPtr & dom)45   bool Run(const AreaPtr &dom) {
46     Reset();
47     return Check(dom) && Match(dom);
48   }
49   std::string ToString() const;
50   // Bind the circle checker
SetCircleChecker(const CircleCheckerPtr & c)51   void SetCircleChecker(const CircleCheckerPtr &c) { circle_checker_ = c; }
52 
name()53   std::string name() const { return name_; }
direction()54   FuseDirection direction() const { return direction_; }
55   std::vector<AreaPtr> fused_areas_;
56 
57  protected:
Reset()58   void Reset() { fused_areas_.clear(); }
59   // Check whether the pattern can handle this area
Check(const AreaPtr &)60   virtual bool Check(const AreaPtr &) { return true; }
61   // Match the ADJACENT areas of `dom`
62   virtual bool Match(const AreaPtr &dom) = 0;
63   // whether it will form a circle if the two areas are fused.
HasCircle(const AreaPtr & a,const AreaPtr & b)64   bool HasCircle(const AreaPtr &a, const AreaPtr &b) const {
65     MS_EXCEPTION_IF_NULL(circle_checker_);
66     return circle_checker_->HasCircle(a, b);
67   }
68 
69   std::string name_;
70   FuseDirection direction_{FuseDirection::FORWARD};
71   CircleCheckerPtr circle_checker_{nullptr};
72 };
73 using FusePatternPtr = std::shared_ptr<FusePattern>;
74 
/* The commonly used fusion patterns are declared from here on. */
// Selects the fusion flavour of a pattern; the patterns in this file append
// "_width" to their name for kWidth and "_depth" otherwise.
enum class FuseType {
  kWidth,
  kDepth
};
77 class FuseReshape : public FusePattern {
78  public:
FuseReshape()79   FuseReshape() : FusePattern("reshape") {}
80   ~FuseReshape() = default;
81 
82  protected:
Check(const AreaPtr & dom)83   bool Check(const AreaPtr &dom) override { return dom->pattern() == NodePattern::RESHAPE; }
84   bool Match(const AreaPtr &dom) override;
85   void KeepMinimumArea(const AreaPtr &a, FuseDirection dir);
86   AreaPtr min_area_;
87 };
88 
89 class FuseIsolateReshape : public FusePattern {
90  public:
FuseIsolateReshape()91   FuseIsolateReshape() : FusePattern("isolate_reshape") {}
92   ~FuseIsolateReshape() = default;
93 
94  protected:
Check(const AreaPtr & dom)95   bool Check(const AreaPtr &dom) override { return dom->pattern() == NodePattern::RESHAPE && dom->size() == 1; }
96   bool Match(const AreaPtr &dom) override;
97 };
98 
99 class FuseElemwiseFwd : public FusePattern {
100  public:
FuseElemwiseFwd(FuseType fuse_type)101   explicit FuseElemwiseFwd(FuseType fuse_type) : FusePattern("elemwise_fwd"), fuse_type_(fuse_type) {
102     direction_ = FuseDirection::FORWARD;
103     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
104   }
105   ~FuseElemwiseFwd() = default;
CreateDepthMatcher()106   static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseElemwiseFwd>(FuseType::kDepth); }
CreateWidthMatcher()107   static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseElemwiseFwd>(FuseType::kWidth); }
108 
109  protected:
110   bool Check(const AreaPtr &dom) override;
111   bool Match(const AreaPtr &dom) override;
112   FuseType fuse_type_;
113 };
114 
115 class FuseElemwiseBroadcastFwd : public FusePattern {
116  public:
FuseElemwiseBroadcastFwd(FuseType fuse_type)117   explicit FuseElemwiseBroadcastFwd(FuseType fuse_type) : FusePattern("elemwise_broadcast_fwd"), fuse_type_(fuse_type) {
118     direction_ = FuseDirection::FORWARD;
119     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
120   }
121   ~FuseElemwiseBroadcastFwd() = default;
CreateDepthMatcher()122   static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseElemwiseBroadcastFwd>(FuseType::kDepth); }
CreateWidthMatcher()123   static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseElemwiseBroadcastFwd>(FuseType::kWidth); }
124 
125  protected:
126   bool Check(const AreaPtr &dom) override;
127   bool Match(const AreaPtr &dom) override;
128   FuseType fuse_type_;
129 };
130 
131 class FuseDynElemwiseBroadcastFwd : public FusePattern {
132  public:
FuseDynElemwiseBroadcastFwd(FuseType fuse_type)133   explicit FuseDynElemwiseBroadcastFwd(FuseType fuse_type)
134       : FusePattern("elemwise_broadcast_fwd"), fuse_type_(fuse_type) {
135     direction_ = FuseDirection::FORWARD;
136     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
137   }
138   ~FuseDynElemwiseBroadcastFwd() = default;
CreateDepthMatcher()139   static FusePatternPtr CreateDepthMatcher() { return std::make_shared<FuseDynElemwiseBroadcastFwd>(FuseType::kDepth); }
CreateWidthMatcher()140   static FusePatternPtr CreateWidthMatcher() { return std::make_shared<FuseDynElemwiseBroadcastFwd>(FuseType::kWidth); }
141 
142  protected:
143   bool Check(const AreaPtr &dom) override;
144   bool Match(const AreaPtr &dom) override;
145   FuseType fuse_type_;
146 };
147 
148 class FuseReduceFwd : public FusePattern {
149  public:
FuseReduceFwd(FuseType fuse_type,size_t size_limit)150   FuseReduceFwd(FuseType fuse_type, size_t size_limit)
151       : FusePattern("reduce_fwd"), fuse_type_(fuse_type), size_limit_(size_limit) {
152     direction_ = FuseDirection::FORWARD;
153     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
154   }
155   ~FuseReduceFwd() = default;
CreateDepthMatcher(size_t size_limit)156   static FusePatternPtr CreateDepthMatcher(size_t size_limit) {
157     return std::make_shared<FuseReduceFwd>(FuseType::kDepth, size_limit);
158   }
CreateWidthMatcher(size_t size_limit)159   static FusePatternPtr CreateWidthMatcher(size_t size_limit) {
160     return std::make_shared<FuseReduceFwd>(FuseType::kWidth, size_limit);
161   }
162 
163  protected:
164   bool Check(const AreaPtr &dom) override;
165   bool Match(const AreaPtr &dom) override;
166   FuseType fuse_type_;
167   size_t size_limit_;
168 };
169 
170 class FuseDynReduceFwd : public FusePattern {
171  public:
FuseDynReduceFwd(FuseType fuse_type,size_t size_limit)172   FuseDynReduceFwd(FuseType fuse_type, size_t size_limit)
173       : FusePattern("reduce_fwd"), fuse_type_(fuse_type), size_limit_(size_limit) {
174     direction_ = FuseDirection::FORWARD;
175     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
176   }
177   ~FuseDynReduceFwd() = default;
CreateDepthMatcher(size_t size_limit)178   static FusePatternPtr CreateDepthMatcher(size_t size_limit) {
179     return std::make_shared<FuseDynReduceFwd>(FuseType::kDepth, size_limit);
180   }
CreateWidthMatcher(size_t size_limit)181   static FusePatternPtr CreateWidthMatcher(size_t size_limit) {
182     return std::make_shared<FuseDynReduceFwd>(FuseType::kWidth, size_limit);
183   }
184 
185  protected:
186   bool Check(const AreaPtr &dom) override;
187   bool Match(const AreaPtr &dom) override;
188   FuseType fuse_type_;
189   size_t size_limit_;
190 };
191 
192 class FuseElemwiseBroadcastBwd : public FusePattern {
193  public:
FuseElemwiseBroadcastBwd(FuseType fuse_type,size_t size_limit)194   FuseElemwiseBroadcastBwd(FuseType fuse_type, size_t size_limit)
195       : FusePattern("elemwise_broadcast_bwd"), fuse_type_(fuse_type), size_limit_(size_limit) {
196     direction_ = FuseDirection::BACKWARD;
197     name_ += (fuse_type == FuseType::kWidth ? "_width" : "_depth");
198   }
199   ~FuseElemwiseBroadcastBwd() = default;
CreateDepthMatcher(size_t size_limit)200   static FusePatternPtr CreateDepthMatcher(size_t size_limit) {
201     return std::make_shared<FuseElemwiseBroadcastBwd>(FuseType::kDepth, size_limit);
202   }
CreateWidthMatcher(size_t size_limit)203   static FusePatternPtr CreateWidthMatcher(size_t size_limit) {
204     return std::make_shared<FuseElemwiseBroadcastBwd>(FuseType::kWidth, size_limit);
205   }
206 
207  protected:
208   bool Check(const AreaPtr &dom) override;
209   bool Match(const AreaPtr &dom) override;
210   FuseType fuse_type_;
211   size_t size_limit_;
212 };
213 
214 // bind the virtual nodes to their inputs
215 class FuseVirtualNode : public FusePattern {
216  public:
FuseVirtualNode()217   FuseVirtualNode() : FusePattern("bind_virtual_node") { direction_ = FuseDirection::FORWARD; }
218   ~FuseVirtualNode() = default;
219 
220  protected:
Check(const AreaPtr & area)221   bool Check(const AreaPtr &area) override { return area->pattern() == NodePattern::VIRTUAL; }
222   bool Match(const AreaPtr &area) override;
223 };
224 
225 namespace ascend {
226 class FuseMatMul : public FusePattern {
227  public:
FuseMatMul()228   FuseMatMul() : FusePattern("matmul_depth") { direction_ = FuseDirection::BACKWARD; }
229   ~FuseMatMul() = default;
230 
231  protected:
Check(const AreaPtr & dom)232   bool Check(const AreaPtr &dom) override {
233     return dom->dom()->op() == kMatMulOpName || dom->dom()->op() == kBatchMatMulOpName;
234   }
235   bool Match(const AreaPtr &dom) override;
236 };
237 }  // namespace ascend
238 }  // namespace mindspore::graphkernel::inner
239 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_FUSE_PATTERN_H_
240