1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either AnfNodePtress or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_ 18 19 #include <map> 20 #include <unordered_set> 21 #include <unordered_map> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 #include "include/backend/optimizer/pass.h" 26 #include "ir/func_graph.h" 27 #include "backend/common/graph_kernel/graph_kernel_helper.h" 28 #include "ops/array_op_name.h" 29 30 namespace mindspore::graphkernel { 31 struct Branch { BranchBranch32 Branch(AnfNodePtrList lst, int pos) : ops(lst), target_op_pos(pos) {} 33 AnfNodePtrList ops; 34 int target_op_pos; // -1 means no target op in this branch 35 AnfNodePtr root_data{nullptr}; sizeBranch36 size_t size() { return ops.size(); } GetTargetOpBranch37 AnfNodePtr GetTargetOp() { return GetOp(target_op_pos); } GetOpBranch38 AnfNodePtr GetOp(int depth) { 39 if (depth < 0 || depth >= static_cast<int>(ops.size())) { 40 return nullptr; 41 } 42 return ops[depth]; 43 } 44 GetRootDataBranch45 AnfNodePtr GetRootData() { return root_data; } SetDataRootBranch46 void SetDataRoot(AnfNodePtr data) { root_data = data; } ToStringBranch47 std::string ToString() { 48 std::string res; 49 res += "RootData: "; 50 res += root_data->fullname_with_scope(); 51 res += "; Ops: ["; 52 for (size_t i = 0; i < ops.size(); ++i) { 53 auto op = ops[i]; 54 res += op->fullname_with_scope(); 55 if (static_cast<int>(i) == target_op_pos) { 56 res += "(LEAD OP)"; 57 } 58 res += ", "; 59 } 60 res += "]"; 61 return res; 62 } 63 }; 64 using Group = std::vector<Branch>; 65 using FIsSupportedOp = std::function<bool(const AnfNodePtr &n)>; 66 using FAreCompatibleOps = std::function<bool(const AnfNodePtr &a, const AnfNodePtr &b)>; 67 using AnfNodePtrSubstMap = std::unordered_map<AnfNodePtr, AnfNodePtr>; 68 using AnfNodePtrSet = std::unordered_set<AnfNodePtr>; 69 class BranchGroupFinder { 70 public: 71 BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); 72 std::vector<Group> Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph = nullptr); 73 std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_; 74 75 private: 76 std::string op_name_; 77 AnfNodePtrSet op_roots_; 78 FIsSupportedOp fis_supported_op_; 79 FAreCompatibleOps fare_compatible_ops_; 80 Branch CreateBranch(AnfNodePtr lead_op); 81 AnfNodeIndexSet GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer); 82 }; 83 84 class ParallelOpCombiner { 85 public: 86 explicit ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout); 87 AnfNodePtr Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph = nullptr); 88 virtual ~ParallelOpCombiner() = default; 89 90 protected: 91 virtual bool IsSupportedOp(const AnfNodePtr n) = 0; 92 virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0; 93 virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0; 94 virtual bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) = 0; 95 virtual AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches, 96 size_t depth) = 0; 97 virtual void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) = 0; 98 bool AutoUpdateInfo(const CNodePtr &to_update); 99 100 std::map<size_t, AnfNodePtrList> GetUniqueInputs(const Group &branches, size_t depth) const; 101 102 FuncGraphPtr main_graph_; 103 AnfNodePtr combined_; 104 std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_; 105 std::unordered_set<std::string> unsupported_ops_{mindspore::kTransposeOpName, mindspore::kReshapeOpName}; 106 107 private: 108 void CombineBranches(const Group &branches); 109 bool CheckLevel(const Group &branches, size_t depth); 110 111 std::string op_name_; 112 uint64_t min_num_branches_{2}; 113 std::string layout_; 114 }; 115 116 class GraphBuilder { 117 public: 118 static CNodePtr NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs); 119 static CNodePtr NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim, 120 size_t split_num); 121 static CNodePtr NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node, size_t concat_dim, 122 size_t input_num); 123 static CNodePtr NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs); 124 static CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs, 125 const AnfNodePtr &orig_node); 126 static CNodePtr NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs); 127 128 static ShapeVector InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out, 129 const ShapeVector &new_reshape_in); 130 static ShapeVector InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out, 131 const ShapeVector &new_reshape_in); 132 static ShapeVector InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm); 133 }; 134 } // namespace mindspore::graphkernel 135 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_ 136