1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_ 18 19 #include <memory> 20 #include <vector> 21 #include <utility> 22 #include <string> 23 #include "utils/hash_map.h" 24 #include "backend/common/graph_kernel/model/op_node.h" 25 26 namespace mindspore::graphkernel::inner { 27 using NodePattern = PrimOp::ComputeType; 28 // EdgeRelation indicates the pattern of node's input edges. 29 // the INJECTIVE means the input is directly sent into the kernel, 30 // the BROADCAST means the input is implicit broadcasted. 31 // 32 // Note, it should be distinguished from the PrimOp::ComputeType, 33 // which indicates the INNER logic of kernels. 34 enum class EdgeRelation : int { INJECTIVE = 0, BROADCAST = 1 }; 35 36 // AreaMode indicates the finally mode of kernels. 37 // the BASIC means the node(s) of area will be inlined into the main graph 38 // the COMPOSITE means the node(s) of area will be kept as a GraphKernel node. 39 enum class AreaMode { BASIC, COMPOSITE }; 40 41 class Area; 42 using AreaPtr = std::shared_ptr<Area>; 43 using AreaWithRelation = std::pair<AreaPtr, EdgeRelation>; 44 45 // Area is used to maintain the operator set that was fused. 46 class Area : public std::enable_shared_from_this<Area> { 47 // NodeHandle is used to maintain the input and user edges of areas. 48 // The handle's inputs should be other areas' handle. 49 // 50 // This class is derived from PrimOp, to reuse the compute_type field 51 // and to avoid overriding pure virtual functions (if exists). 52 // 53 // This class is not visible outside the class Area. 54 class NodeHandle : public PrimOp { 55 public: NodeHandle(Area * area,const PrimOpPtr & p)56 NodeHandle(Area *area, const PrimOpPtr &p) : PrimOp("", p->compute_type()), area_(area) {} 57 ~NodeHandle() = default; 58 using PrimOp::compute_type_; area()59 AreaPtr area() const { return area_->shared_from_this(); } 60 61 private: 62 Area *const area_; 63 }; // class Area::NodeHandle 64 65 public: 66 Area(size_t id, const PrimOpPtr &prim_op, bool is_output, const HashMap<NodePtr, AreaPtr> &node_area_map); 67 ~Area() = default; 68 id()69 size_t id() const { return unique_id_; } input(size_t i)70 const AreaPtr &input(size_t i) const { return inputs_with_relation_[i].first; } 71 std::vector<AreaPtr> inputs() const; input_relation(size_t i)72 EdgeRelation input_relation(size_t i) const { return inputs_with_relation_[i].second; } inputs_with_relation()73 const std::vector<AreaWithRelation> &inputs_with_relation() const { return inputs_with_relation_; } input_num()74 size_t input_num() const { return inputs_with_relation_.size(); } 75 // get the number of operators in the area size()76 size_t size() const { return ops_.size(); } 77 std::vector<AreaPtr> users() const; 78 std::vector<AreaWithRelation> users_with_relation() const; user_num()79 size_t user_num() const { return hd_->users().size(); } mode()80 AreaMode mode() const { return mode_; } 81 // get the dominant op node dom()82 PrimOpPtr dom() const { return IsAlive() ? ops_[0] : nullptr; } pattern()83 NodePattern pattern() const { return hd_->compute_type(); } ops()84 const std::vector<PrimOpPtr> &ops() const { return ops_; } is_output()85 bool is_output() const { return is_output_; } 86 int64_t compute_size() const; 87 bool ComputeSizeEqual(const AreaPtr &other) const; 88 89 // check whether the area is alive(true) or is fused(false) IsAlive()90 bool IsAlive() const { return !ops_.empty(); } 91 std::string ToString() const; SetOps(const std::vector<PrimOpPtr> & ops)92 void SetOps(const std::vector<PrimOpPtr> &ops) { ops_ = ops; } SetMode(AreaMode mode)93 void SetMode(AreaMode mode) { mode_ = mode; } 94 // fuse `input_area` into `this` area. after that, the `input_area` will be discarded. 95 // the `input_area` node should be in the input list of `this` area. 96 void FuseInput(const AreaPtr &input_area); 97 98 protected: 99 // Make the inputs unique, and sync the inputs to NodeHandle 100 void MakeUniqueAndSyncInputs(); 101 // Relink the `input_area`'s users to `this` area 102 void UpdateUsersRelation(const AreaPtr &input_area); 103 104 std::shared_ptr<NodeHandle> hd_; 105 const size_t unique_id_; 106 bool is_output_; 107 std::vector<PrimOpPtr> ops_; 108 AreaMode mode_{AreaMode::BASIC}; 109 // The `inputs_with_relation_.first` stores the input area of `this` area. 110 // The `hd_->inputs` stores the NodeHandle of `this` area, to maintain the user edges. 111 // They should always be in sync. 112 std::vector<AreaWithRelation> inputs_with_relation_; 113 }; 114 } // namespace mindspore::graphkernel::inner 115 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_MODEL_AREA_H_ 116