• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_SPLIT_MODEL_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_SPLIT_MODEL_H_
18 
19 #include <vector>
20 #include <list>
21 #include <memory>
22 #include <set>
23 #include <utility>
24 #include "backend/common/graph_kernel/model/lite_graph.h"
25 #include "backend/common/graph_kernel/split_model/area.h"
26 #include "backend/common/graph_kernel/split_model/fuse_pattern.h"
27 
28 namespace mindspore::graphkernel::inner {
29 class ReachTable : public CircleChecker {
30  public:
31   explicit ReachTable(size_t size);
32   virtual ~ReachTable() = default;
33   bool HasCircle(const AreaPtr &a, const AreaPtr &b) const override;
34 
35   // Link area from `from` to `to`.
36   void Link(size_t from, size_t to);
37 
38   // Fuse the area `target` and `other`. After that, the `other` area will be discarded.
39   void FuseArea(size_t target, size_t other);
40 
41  private:
42   // check the reachability from `from` to `to`
Reachable(size_t from,size_t to)43   bool Reachable(size_t from, size_t to) const { return reach_[from][to]; }
44 
45   size_t size_;
46   std::vector<std::vector<bool>> reach_;
47   std::set<size_t> alive_;
48 };
49 
50 class SplitModel {
51  public:
52   void Run(const LiteGraphPtr &litegraph);
areas()53   const std::list<AreaPtr> &areas() const { return areas_; }
54   SplitModel() = default;
55   virtual ~SplitModel() = default;
56 
57  protected:
58   // transform the litegraph to areas, and initialize inner tables.
59   void InitGraph(const LiteGraphPtr &litegraph);
60   // Push leading "1" to shapes to facilitate pattern match.
61   void AlignShape(const LiteGraphPtr &litegraph) const;
62   // initialize fusion pattern list.
63   virtual void InitFusePatterns() = 0;
64   bool RunOnePattern(const FusePatternPtr &pattern);
65   // fuse areas by pattern
66   void RunFusePatterns();
67   // set default area mode when the area has only one node.
SetDefaultAreaMode(const AreaPtr & area)68   void SetDefaultAreaMode(const AreaPtr &area) const { area->SetMode(GetDefaultAreaMode(area->dom())); }
69   // get default area mode of the dominant node
70   virtual AreaMode GetDefaultAreaMode(const PrimOpPtr &node) const = 0;
71   // add new pattern
72   void AddPattern(const std::shared_ptr<FusePattern> &pn, bool enable = true);
73   // fuse areas
74   void FuseAreas(const AreaPtr &dom, const std::vector<AreaPtr> &areas, FuseDirection direction);
75   // create new area
76   AreaPtr NewArea(const PrimOpPtr &op, bool is_output);
77   // limit the area's size
78   void LimitAreaSize(const AreaPtr &dom, std::vector<AreaPtr> *areas) const;
79 
80   std::list<AreaPtr> areas_;  // use std::list to accelerate the "erase"
81   std::shared_ptr<ReachTable> reach_table_{nullptr};
82   HashMap<NodePtr, AreaPtr> node_area_map_;
83 
84  private:
85   size_t cur_area_id_{0};
86   std::vector<std::pair<FusePatternPtr, bool>> patterns_;
87 };
// Alias for a std::shared_ptr that holds a SplitModel.
88 using SplitModelPtr = std::shared_ptr<SplitModel>;
89 }  // namespace mindspore::graphkernel::inner
90 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_SPLIT_MODEL_H_
91