1 2 /** 3 * Copyright 2021 Huawei Technologies Co., Ltd 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_ 19 20 #include <vector> 21 #include "backend/optimizer/common/pass.h" 22 #include "ir/func_graph.h" 23 24 namespace mindspore { 25 namespace opt { 26 /** 27 * @brief Spread the input tuple of UpdateState 28 * @example 29 * %1 = op1 30 * %2 = op2 31 * %3 = make_tuple(%1, %2) 32 * UpdateState(U, %3) 33 * --> 34 * %1 = op1 35 * %2 = op2 36 * UpdateState(U, %1, %2) 37 */ 38 class SpreadUpdateState : public Pass { 39 public: SpreadUpdateState()40 SpreadUpdateState() : Pass("spread_update_state") {} 41 ~SpreadUpdateState() override = default; 42 AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph); 43 bool Run(const FuncGraphPtr &func_graph) override; 44 }; 45 46 /** 47 * @brief Shrink the inputs of UpdateState to a tuple 48 * @example 49 * %1 = op1 50 * %2 = op2 51 * UpdateState(U, %1, %2) 52 * --> 53 * %1 = op1 54 * %2 = op2 55 * %3 = make_tuple(%1, %2) 56 * UpdateState(U, %3) 57 */ 58 class ShrinkUpdateState : public Pass { 59 public: ShrinkUpdateState()60 ShrinkUpdateState() : Pass("shrink_update_state") {} 61 ~ShrinkUpdateState() override = default; 62 bool Run(const FuncGraphPtr &func_graph) override; 63 }; 64 65 /** 66 * @brief Spread the MakeTuple in node list 67 * @param nodes 68 * @param begin_index 69 * @example 70 * input 71 * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ] 72 * begin_index: 1 73 * output 74 * [b, i, j, c, d, x, y, z] 75 * @return std::vector<AnfNodePtr> 76 */ 77 AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0); 78 79 /** 80 * @brief Extend the getitem for UpdateState 81 * @example 82 * In this example, the Cast is an output of GraphKernel and only links to an UpdateState, 83 * it has two users in GraphKernel, Add and Sub, which are all outputs. 84 * after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState. 85 * 86 * graph_kernel: 87 * %1 = Cast(p1) 88 * %2 = Add(%1, p2) // depends on Cast 89 * %3 = Sub(%2, p3) // depends on Cast 90 * %4 = Mul(p1, p2) // not depends on Cast 91 * return make_tuple(%1, %2, %3, %4) 92 * main graph: 93 * %1 = call @graph_kernel(p1, p2) 94 * %2 = tuple_getitem(%1, 0) // The Cast 95 * %3 = UpdateState(U, %2) 96 * --> 97 * graph_kernel: 98 * %1 = Cast(p1) 99 * %2 = Add(%1, p2) // depends on Cast 100 * %3 = Sub(%2, p3) // depends on Cast 101 * %4 = Mul(p1, p2) // not depends on Cast 102 * return make_tuple(%2, %3, %4) // the Cast was eliminated from output list 103 * main graph: 104 * %1 = call @graph_kernel(p1, p2) 105 * %2 = tuple_getitem(%1, 0) // the Add 106 * %3 = tuple_getitem(%1, 1) // the Sub 107 * %4 = UpdateState(U, %2, %3) 108 */ 109 class ExtendOutputForUpdateState : public Pass { 110 public: ExtendOutputForUpdateState()111 ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {} 112 ~ExtendOutputForUpdateState() = default; 113 bool Run(const FuncGraphPtr &func_graph) override; 114 115 private: 116 // Get the nodes that have external UpdateState user. 117 void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng); 118 void FilterIndexes(const FuncGraphPtr &func_graph); 119 // Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node. 120 std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index); 121 bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index); 122 123 enum ExternalUserType { 124 kNormalOp, // only has normal operators 125 kUpdateState, // only has UpdateState(s) 126 kMix, // UpdateState mix with normal operator 127 }; 128 AnfNodePtrList getitems_; // Users of the GraphKernel nodes. 129 std::vector<size_t> indexes_; // Indexes of GetItem to be processed. 130 std::vector<ExternalUserType> external_user_type_; // The type of getitem's users. 131 }; 132 133 /** 134 * @brief Merge UpdateState's inputs which link to the same node 135 * @example 136 * graph_kernel: 137 * %1 = Cast(p1) 138 * %2 = Add(%1, p2) 139 * %3 = Sub(%2, p3) 140 * %4 = Mul(p1, p2) 141 * return make_tuple(%1, %2, %3, %4) 142 * main graph: 143 * %1 = call @graph_kernel(p1, p2) 144 * %2 = tuple_getitem(%1, 0) 145 * %3 = tuple_getitem(%1, 1) 146 * %4 = tuple_getitem(%1, 2) 147 * %5 = UpdateState(U, %2, %3, %4) // the %2 %3 %4 are all link to %1 148 * --> 149 * main graph: 150 * %1 = call @graph_kernel(p1, p2) 151 * %2 = tuple_getitem(%1, 0) 152 * %3 = tuple_getitem(%1, 1) 153 * %4 = tuple_getitem(%1, 2) 154 * %5 = UpdateState(U, %2) // only keep %2 155 */ 156 class MergeOutputForUpdateState : public Pass { 157 public: MergeOutputForUpdateState()158 MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {} 159 ~MergeOutputForUpdateState() = default; 160 bool Run(const FuncGraphPtr &func_graph) override; 161 }; 162 } // namespace opt 163 } // namespace mindspore 164 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_ 165