1 2 /** 3 * Copyright 2021-2022 Huawei Technologies Co., Ltd 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_ 19 20 #include <vector> 21 #include "include/backend/optimizer/pass.h" 22 #include "mindspore/core/ops/framework_ops.h" 23 #include "ir/func_graph.h" 24 25 namespace mindspore::graphkernel { 26 /** 27 * @brief Spread the input tuple of UpdateState 28 * @example 29 * %1 = op1 30 * %2 = op2 31 * %3 = make_tuple(%1, %2) 32 * UpdateState(U, %3) 33 * --> 34 * %1 = op1 35 * %2 = op2 36 * UpdateState(U, %1, %2) 37 */ 38 class SpreadUpdateState : public opt::Pass { 39 public: SpreadUpdateState()40 SpreadUpdateState() : Pass("spread_update_state") {} 41 ~SpreadUpdateState() override = default; 42 AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) const; 43 bool Run(const FuncGraphPtr &func_graph) override; 44 }; 45 46 /** 47 * @brief Shrink the inputs of UpdateState to a tuple 48 * @example 49 * %1 = op1 50 * %2 = op2 51 * UpdateState(U, %1, %2) 52 * --> 53 * %1 = op1 54 * %2 = op2 55 * %3 = make_tuple(%1, %2) 56 * UpdateState(U, %3) 57 */ 58 class ShrinkUpdateState : public opt::Pass { 59 public: ShrinkUpdateState()60 ShrinkUpdateState() : Pass("shrink_update_state") {} 61 ~ShrinkUpdateState() override = default; 62 bool Run(const FuncGraphPtr &func_graph) override; 63 }; 64 65 /** 66 * @brief Extend the getitem for UpdateState 67 * @example 68 * In this example, the Cast is an output of GraphKernel and only links to an UpdateState, 69 * it has two users in GraphKernel, Add and Sub, which are all outputs. 70 * after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState. 71 * 72 * graph_kernel: 73 * %1 = Cast(p1) 74 * %2 = Add(%1, p2) // depends on Cast 75 * %3 = Sub(%2, p3) // depends on Cast 76 * %4 = Mul(p1, p2) // not depends on Cast 77 * return make_tuple(%1, %2, %3, %4) 78 * main graph: 79 * %1 = call @graph_kernel(p1, p2) 80 * %2 = tuple_getitem(%1, 0) // The Cast 81 * %3 = UpdateState(U, %2) 82 * --> 83 * graph_kernel: 84 * %1 = Cast(p1) 85 * %2 = Add(%1, p2) // depends on Cast 86 * %3 = Sub(%2, p3) // depends on Cast 87 * %4 = Mul(p1, p2) // not depends on Cast 88 * return make_tuple(%2, %3, %4) // the Cast was eliminated from output list 89 * main graph: 90 * %1 = call @graph_kernel(p1, p2) 91 * %2 = tuple_getitem(%1, 0) // the Add 92 * %3 = tuple_getitem(%1, 1) // the Sub 93 * %4 = UpdateState(U, %2, %3) 94 */ 95 class ExtendOutputForUpdateState : public opt::Pass { 96 public: ExtendOutputForUpdateState()97 ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {} 98 ~ExtendOutputForUpdateState() = default; 99 bool Run(const FuncGraphPtr &func_graph) override; 100 101 private: 102 // Get the nodes that have external UpdateState user. 103 void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng); 104 void FilterIndexes(const FuncGraphPtr &func_graph); 105 // Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node. 106 std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index); 107 bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index); 108 109 enum ExternalUserType { 110 kNormalOp, // only has normal operators 111 kUpdateState, // only has UpdateState(s) 112 kMix, // UpdateState mix with normal operator 113 }; 114 AnfNodePtrList getitems_; // Users of the GraphKernel nodes. 115 std::vector<size_t> indexes_; // Indexes of GetItem to be processed. 116 std::vector<ExternalUserType> external_user_type_; // The type of getitem's users. 117 }; 118 119 /** 120 * @brief Merge UpdateState's inputs which link to the same node 121 * @example 122 * graph_kernel: 123 * %1 = Cast(p1) 124 * %2 = Add(%1, p2) 125 * %3 = Sub(%2, p3) 126 * %4 = Mul(p1, p2) 127 * return make_tuple(%1, %2, %3, %4) 128 * main graph: 129 * %1 = call @graph_kernel(p1, p2) 130 * %2 = tuple_getitem(%1, 0) 131 * %3 = tuple_getitem(%1, 1) 132 * %4 = tuple_getitem(%1, 2) 133 * %5 = UpdateState(U, %2, %3, %4) // the %2 %3 %4 are all link to %1 134 * --> 135 * main graph: 136 * %1 = call @graph_kernel(p1, p2) 137 * %2 = tuple_getitem(%1, 0) 138 * %3 = tuple_getitem(%1, 1) 139 * %4 = tuple_getitem(%1, 2) 140 * %5 = UpdateState(U, %2) // only keep %2 141 */ 142 class MergeOutputForUpdateState : public opt::Pass { 143 public: MergeOutputForUpdateState()144 MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {} 145 ~MergeOutputForUpdateState() = default; 146 bool Run(const FuncGraphPtr &func_graph) override; 147 }; 148 } // namespace mindspore::graphkernel 149 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_ 150