• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * Copyright 2021 Huawei Technologies Co., Ltd
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
19 
20 #include <vector>
21 #include "backend/optimizer/common/pass.h"
22 #include "ir/func_graph.h"
23 
24 namespace mindspore {
25 namespace opt {
26 /**
27  * @brief Spread the input tuple of UpdateState
28  * @example
29  *   %1 = op1
30  *   %2 = op2
31  *   %3 = make_tuple(%1, %2)
32  *   UpdateState(U, %3)
33  *   -->
34  *   %1 = op1
35  *   %2 = op2
36  *   UpdateState(U, %1, %2)
37  */
38 class SpreadUpdateState : public Pass {
39  public:
SpreadUpdateState()40   SpreadUpdateState() : Pass("spread_update_state") {}
41   ~SpreadUpdateState() override = default;
42   AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
43   bool Run(const FuncGraphPtr &func_graph) override;
44 };
45 
46 /**
47  * @brief Shrink the inputs of UpdateState to a tuple
48  * @example
49  *   %1 = op1
50  *   %2 = op2
51  *   UpdateState(U, %1, %2)
52  *   -->
53  *   %1 = op1
54  *   %2 = op2
55  *   %3 = make_tuple(%1, %2)
56  *   UpdateState(U, %3)
57  */
58 class ShrinkUpdateState : public Pass {
59  public:
ShrinkUpdateState()60   ShrinkUpdateState() : Pass("shrink_update_state") {}
61   ~ShrinkUpdateState() override = default;
62   bool Run(const FuncGraphPtr &func_graph) override;
63 };
64 
65 /**
66  * @brief Spread the MakeTuple in node list
67  * @param nodes
68  * @param begin_index
69  * @example
70  *   input
71  *     nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
72  *     begin_index: 1
73  *   output
74  *     [b, i, j, c, d, x, y, z]
75  * @return std::vector<AnfNodePtr>
76  */
77 AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
78 
79 /**
80  * @brief Extend the getitem for UpdateState
81  * @example
82  *   In this example, the Cast is an output of GraphKernel and only links to an UpdateState,
83  *   it has two users in GraphKernel, Add and Sub, which are all outputs.
84  *   after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState.
85  *
86  *   graph_kernel:
87  *      %1 = Cast(p1)
88  *      %2 = Add(%1, p2)   // depends on Cast
89  *      %3 = Sub(%2, p3)   // depends on Cast
90  *      %4 = Mul(p1, p2)   // not depends on Cast
91  *      return make_tuple(%1, %2, %3, %4)
92  *   main graph:
93  *      %1 = call @graph_kernel(p1, p2)
94  *      %2 = tuple_getitem(%1, 0)  // The Cast
95  *      %3 = UpdateState(U, %2)
96  *  -->
97  *   graph_kernel:
98  *      %1 = Cast(p1)
99  *      %2 = Add(%1, p2)   // depends on Cast
100  *      %3 = Sub(%2, p3)   // depends on Cast
101  *      %4 = Mul(p1, p2)   // not depends on Cast
102  *      return make_tuple(%2, %3, %4)  // the Cast was eliminated from output list
103  *   main graph:
104  *      %1 = call @graph_kernel(p1, p2)
105  *      %2 = tuple_getitem(%1, 0)   // the Add
106  *      %3 = tuple_getitem(%1, 1)   // the Sub
107  *      %4 = UpdateState(U, %2, %3)
108  */
109 class ExtendOutputForUpdateState : public Pass {
110  public:
ExtendOutputForUpdateState()111   ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
112   ~ExtendOutputForUpdateState() = default;
113   bool Run(const FuncGraphPtr &func_graph) override;
114 
115  private:
116   // Get the nodes that have external UpdateState user.
117   void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
118   void FilterIndexes(const FuncGraphPtr &func_graph);
119   // Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node.
120   std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
121   bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
122 
123   enum ExternalUserType {
124     kNormalOp,     // only has normal operators
125     kUpdateState,  // only has UpdateState(s)
126     kMix,          // UpdateState mix with normal operator
127   };
128   AnfNodePtrList getitems_;                           // Users of the GraphKernel nodes.
129   std::vector<size_t> indexes_;                       // Indexes of GetItem to be processed.
130   std::vector<ExternalUserType> external_user_type_;  // The type of getitem's users.
131 };
132 
133 /**
134  * @brief Merge UpdateState's inputs which link to the same node
135  * @example
136  *   graph_kernel:
137  *      %1 = Cast(p1)
138  *      %2 = Add(%1, p2)
139  *      %3 = Sub(%2, p3)
140  *      %4 = Mul(p1, p2)
141  *      return make_tuple(%1, %2, %3, %4)
142  *   main graph:
143  *      %1 = call @graph_kernel(p1, p2)
144  *      %2 = tuple_getitem(%1, 0)
145  *      %3 = tuple_getitem(%1, 1)
146  *      %4 = tuple_getitem(%1, 2)
147  *      %5 = UpdateState(U, %2, %3, %4)  // the %2 %3 %4 are all link to %1
148  * -->
149  *   main graph:
150  *      %1 = call @graph_kernel(p1, p2)
151  *      %2 = tuple_getitem(%1, 0)
152  *      %3 = tuple_getitem(%1, 1)
153  *      %4 = tuple_getitem(%1, 2)
154  *      %5 = UpdateState(U, %2)  // only keep %2
155  */
156 class MergeOutputForUpdateState : public Pass {
157  public:
MergeOutputForUpdateState()158   MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
159   ~MergeOutputForUpdateState() = default;
160   bool Run(const FuncGraphPtr &func_graph) override;
161 };
162 }  // namespace opt
163 }  // namespace mindspore
164 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
165