/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "ops/array_op_name.h"
namespace mindspore::graphkernel {
31 struct Branch {
BranchBranch32   Branch(AnfNodePtrList lst, int pos) : ops(lst), target_op_pos(pos) {}
33   AnfNodePtrList ops;
34   int target_op_pos;  // -1 means no target op in this branch
35   AnfNodePtr root_data{nullptr};
sizeBranch36   size_t size() { return ops.size(); }
GetTargetOpBranch37   AnfNodePtr GetTargetOp() { return GetOp(target_op_pos); }
GetOpBranch38   AnfNodePtr GetOp(int depth) {
39     if (depth < 0 || depth >= static_cast<int>(ops.size())) {
40       return nullptr;
41     }
42     return ops[depth];
43   }
44 
GetRootDataBranch45   AnfNodePtr GetRootData() { return root_data; }
SetDataRootBranch46   void SetDataRoot(AnfNodePtr data) { root_data = data; }
ToStringBranch47   std::string ToString() {
48     std::string res;
49     res += "RootData: ";
50     res += root_data->fullname_with_scope();
51     res += "; Ops: [";
52     for (size_t i = 0; i < ops.size(); ++i) {
53       auto op = ops[i];
54       res += op->fullname_with_scope();
55       if (static_cast<int>(i) == target_op_pos) {
56         res += "(LEAD OP)";
57       }
58       res += ", ";
59     }
60     res += "]";
61     return res;
62   }
63 };
// A set of parallel branches sharing a common data root; the unit that the
// combiner tries to fuse into a single wider op.
using Group = std::vector<Branch>;
// Predicate: is node `n` an op the concrete combiner supports?
using FIsSupportedOp = std::function<bool(const AnfNodePtr &n)>;
// Predicate: may ops `a` and `b` be combined into one op?
using FAreCompatibleOps = std::function<bool(const AnfNodePtr &a, const AnfNodePtr &b)>;
// Maps an original node to its substitute in the rewritten graph.
using AnfNodePtrSubstMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
using AnfNodePtrSet = std::unordered_set<AnfNodePtr>;
// Discovers groups of parallel branches rooted at a common node.
// Behavior is driven by the two injected predicates: `fis_supported_op`
// filters candidate ops, `fare_compatible_ops` decides whether two branches'
// ops may be merged. (Definitions live in the corresponding .cc file.)
class BranchGroupFinder {
 public:
  BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops);
  // Collects branch groups reachable from `start_node`; `func_graph` is
  // optional — presumably the enclosing graph is derived from the node when
  // omitted (verify against the implementation).
  std::vector<Group> Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph = nullptr);
  // node -> set of its child nodes, populated during Find.
  std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;

 private:
  std::string op_name_;          // name of the op kind being grouped
  AnfNodePtrSet op_roots_;       // root nodes from which branches grow
  FIsSupportedOp fis_supported_op_;
  FAreCompatibleOps fare_compatible_ops_;
  // Builds the branch whose lead op is `lead_op`.
  Branch CreateBranch(AnfNodePtr lead_op);
  // All consumer edges of `producer` according to the graph manager.
  AnfNodeIndexSet GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer);
};
83 
// Abstract base for passes that combine parallel branches of a given op into
// one wider op. Concrete subclasses supply op-specific policy via the pure
// virtual hooks below; the base class drives the traversal and rewrite.
class ParallelOpCombiner {
 public:
  // `min_num_branches` is the minimum parallelism worth combining;
  // `layout` is the data layout string used by the concrete combiner.
  explicit ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout);
  // Entry point: attempts the combine rooted at `root`; returns the combined
  // node (contract of the return on failure is defined in the .cc file).
  AnfNodePtr Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph = nullptr);
  virtual ~ParallelOpCombiner() = default;

 protected:
  // --- op-specific policy hooks, implemented by subclasses ---
  virtual bool IsSupportedOp(const AnfNodePtr n) = 0;
  virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0;
  virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0;
  virtual bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) = 0;
  // Builds the combined node for the ops found at `depth` below the lead op.
  virtual AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
                                                            size_t depth) = 0;
  // Rewires consumers of the original branch outputs to the combined output.
  virtual void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) = 0;
  // Refreshes abstract/build info on a freshly created combined node.
  bool AutoUpdateInfo(const CNodePtr &to_update);

  // Inputs at `depth` that differ across branches, keyed by input index.
  std::map<size_t, AnfNodePtrList> GetUniqueInputs(const Group &branches, size_t depth) const;

  FuncGraphPtr main_graph_;
  AnfNodePtr combined_;  // the combined op produced by the last Combine call
  std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;
  // Ops that stop branch growth; Transpose/Reshape are layout-sensitive.
  std::unordered_set<std::string> unsupported_ops_{mindspore::kTransposeOpName, mindspore::kReshapeOpName};

 private:
  void CombineBranches(const Group &branches);
  // Checks whether the ops at `depth` across all branches can be combined.
  bool CheckLevel(const Group &branches, size_t depth);

  std::string op_name_;
  uint64_t min_num_branches_{2};
  std::string layout_;
};
115 
// Stateless helpers for constructing the glue nodes (tuple/split/concat/
// reshape/transpose) needed when stitching combined ops into the graph,
// plus shape-inference utilities for the reshape/transpose cases.
class GraphBuilder {
 public:
  static CNodePtr NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs);
  // Splits `input_node` along `split_dim` into `split_num` outputs.
  static CNodePtr NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
                               size_t split_num);
  // Concatenates `input_num` nodes along `concat_dim`.
  static CNodePtr NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node, size_t concat_dim,
                                size_t input_num);
  static CNodePtr NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs);
  // `orig_node` supplies the reshape target metadata — see the .cc file.
  static CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
                                 const AnfNodePtr &orig_node);
  static CNodePtr NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs);

  // Derives the output shape of a reshape when its input shape changes from
  // `orig_reshape_in` to `new_reshape_in`, given the original in/out pair.
  static ShapeVector InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                     const ShapeVector &new_reshape_in);
  // Same derivation for a reshape that follows a concat of branches.
  static ShapeVector InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                           const ShapeVector &new_reshape_in);
  // Applies permutation `perm` to `in_shape`.
  static ShapeVector InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm);
};
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_