• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_
18 
19 #include <memory>
20 #include <vector>
21 #include <utility>
22 #include <string>
23 #include "utils/hash_map.h"
24 #include "backend/common/graph_kernel/model/op_node.h"
25 
26 namespace mindspore::graphkernel::inner {
27 using NodePattern = PrimOp::ComputeType;
// EdgeRelation indicates the pattern of a node's input edges:
//   INJECTIVE - the input is sent directly into the kernel.
//   BROADCAST - the input is implicitly broadcast.
//
// Note: it should be distinguished from PrimOp::ComputeType,
// which indicates the INNER logic of kernels.
enum class EdgeRelation : int { INJECTIVE = 0, BROADCAST = 1 };
35 
// AreaMode indicates the final mode of kernels:
//   BASIC     - the node(s) of the area will be inlined into the main graph.
//   COMPOSITE - the node(s) of the area will be kept as a GraphKernel node.
enum class AreaMode { BASIC, COMPOSITE };
40 
41 class Area;
42 using AreaPtr = std::shared_ptr<Area>;
43 using AreaWithRelation = std::pair<AreaPtr, EdgeRelation>;
44 
45 // Area is used to maintain the operator set that was fused.
46 class Area : public std::enable_shared_from_this<Area> {
47   // NodeHandle is used to maintain the input and user edges of areas.
48   // The handle's inputs should be other areas' handle.
49   //
50   // This class is derived from PrimOp, to reuse the compute_type field
51   // and to avoid overriding pure virtual functions (if exists).
52   //
53   // This class is not visible outside the class Area.
54   class NodeHandle : public PrimOp {
55    public:
NodeHandle(Area * area,const PrimOpPtr & p)56     NodeHandle(Area *area, const PrimOpPtr &p) : PrimOp("", p->compute_type()), area_(area) {}
57     ~NodeHandle() = default;
58     using PrimOp::compute_type_;
area()59     AreaPtr area() const { return area_->shared_from_this(); }
60 
61    private:
62     Area *const area_;
63   };  // class Area::NodeHandle
64 
65  public:
66   Area(size_t id, const PrimOpPtr &prim_op, bool is_output, const HashMap<NodePtr, AreaPtr> &node_area_map);
67   ~Area() = default;
68 
id()69   size_t id() const { return unique_id_; }
input(size_t i)70   const AreaPtr &input(size_t i) const { return inputs_with_relation_[i].first; }
71   std::vector<AreaPtr> inputs() const;
input_relation(size_t i)72   EdgeRelation input_relation(size_t i) const { return inputs_with_relation_[i].second; }
inputs_with_relation()73   const std::vector<AreaWithRelation> &inputs_with_relation() const { return inputs_with_relation_; }
input_num()74   size_t input_num() const { return inputs_with_relation_.size(); }
75   // get the number of operators in the area
size()76   size_t size() const { return ops_.size(); }
77   std::vector<AreaPtr> users() const;
78   std::vector<AreaWithRelation> users_with_relation() const;
user_num()79   size_t user_num() const { return hd_->users().size(); }
mode()80   AreaMode mode() const { return mode_; }
81   // get the dominant op node
dom()82   PrimOpPtr dom() const { return IsAlive() ? ops_[0] : nullptr; }
pattern()83   NodePattern pattern() const { return hd_->compute_type(); }
ops()84   const std::vector<PrimOpPtr> &ops() const { return ops_; }
is_output()85   bool is_output() const { return is_output_; }
86   int64_t compute_size() const;
87   bool ComputeSizeEqual(const AreaPtr &other) const;
88 
89   // check whether the area is alive(true) or is fused(false)
IsAlive()90   bool IsAlive() const { return !ops_.empty(); }
91   std::string ToString() const;
SetOps(const std::vector<PrimOpPtr> & ops)92   void SetOps(const std::vector<PrimOpPtr> &ops) { ops_ = ops; }
SetMode(AreaMode mode)93   void SetMode(AreaMode mode) { mode_ = mode; }
94   // fuse `input_area` into `this` area. after that, the `input_area` will be discarded.
95   // the `input_area` node should be in the input list of `this` area.
96   void FuseInput(const AreaPtr &input_area);
97 
98  protected:
99   // Make the inputs unique, and sync the inputs to NodeHandle
100   void MakeUniqueAndSyncInputs();
101   // Relink the `input_area`'s users to `this` area
102   void UpdateUsersRelation(const AreaPtr &input_area);
103 
104   std::shared_ptr<NodeHandle> hd_;
105   const size_t unique_id_;
106   bool is_output_;
107   std::vector<PrimOpPtr> ops_;
108   AreaMode mode_{AreaMode::BASIC};
109   // The `inputs_with_relation_.first` stores the input area of `this` area.
110   // The `hd_->inputs` stores the NodeHandle of `this` area, to maintain the user edges.
111   // They should always be in sync.
112   std::vector<AreaWithRelation> inputs_with_relation_;
113 };
114 }  // namespace mindspore::graphkernel::inner
115 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_
116