• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * Copyright 2021-2022 Huawei Technologies Co., Ltd
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_
19 
20 #include <vector>
21 #include "include/backend/optimizer/pass.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "ir/func_graph.h"
24 
25 namespace mindspore::graphkernel {
26 /**
27  * @brief Spread the input tuple of UpdateState
28  * @example
29  *   %1 = op1
30  *   %2 = op2
31  *   %3 = make_tuple(%1, %2)
32  *   UpdateState(U, %3)
33  *   -->
34  *   %1 = op1
35  *   %2 = op2
36  *   UpdateState(U, %1, %2)
37  */
38 class SpreadUpdateState : public opt::Pass {
39  public:
SpreadUpdateState()40   SpreadUpdateState() : Pass("spread_update_state") {}
41   ~SpreadUpdateState() override = default;
42   AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) const;
43   bool Run(const FuncGraphPtr &func_graph) override;
44 };
45 
46 /**
47  * @brief Shrink the inputs of UpdateState to a tuple
48  * @example
49  *   %1 = op1
50  *   %2 = op2
51  *   UpdateState(U, %1, %2)
52  *   -->
53  *   %1 = op1
54  *   %2 = op2
55  *   %3 = make_tuple(%1, %2)
56  *   UpdateState(U, %3)
57  */
58 class ShrinkUpdateState : public opt::Pass {
59  public:
ShrinkUpdateState()60   ShrinkUpdateState() : Pass("shrink_update_state") {}
61   ~ShrinkUpdateState() override = default;
62   bool Run(const FuncGraphPtr &func_graph) override;
63 };
64 
65 /**
66  * @brief Extend the getitem for UpdateState
67  * @example
68  *   In this example, the Cast is an output of GraphKernel and only links to an UpdateState,
69  *   it has two users in GraphKernel, Add and Sub, which are all outputs.
70  *   after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState.
71  *
72  *   graph_kernel:
73  *      %1 = Cast(p1)
74  *      %2 = Add(%1, p2)   // depends on Cast
75  *      %3 = Sub(%2, p3)   // depends on Cast
76  *      %4 = Mul(p1, p2)   // not depends on Cast
77  *      return make_tuple(%1, %2, %3, %4)
78  *   main graph:
79  *      %1 = call @graph_kernel(p1, p2)
80  *      %2 = tuple_getitem(%1, 0)  // The Cast
81  *      %3 = UpdateState(U, %2)
82  *  -->
83  *   graph_kernel:
84  *      %1 = Cast(p1)
85  *      %2 = Add(%1, p2)   // depends on Cast
86  *      %3 = Sub(%2, p3)   // depends on Cast
87  *      %4 = Mul(p1, p2)   // not depends on Cast
88  *      return make_tuple(%2, %3, %4)  // the Cast was eliminated from output list
89  *   main graph:
90  *      %1 = call @graph_kernel(p1, p2)
91  *      %2 = tuple_getitem(%1, 0)   // the Add
92  *      %3 = tuple_getitem(%1, 1)   // the Sub
93  *      %4 = UpdateState(U, %2, %3)
94  */
95 class ExtendOutputForUpdateState : public opt::Pass {
96  public:
ExtendOutputForUpdateState()97   ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
98   ~ExtendOutputForUpdateState() = default;
99   bool Run(const FuncGraphPtr &func_graph) override;
100 
101  private:
102   // Get the nodes that have external UpdateState user.
103   void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
104   void FilterIndexes(const FuncGraphPtr &func_graph);
105   // Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node.
106   std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
107   bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
108 
109   enum ExternalUserType {
110     kNormalOp,     // only has normal operators
111     kUpdateState,  // only has UpdateState(s)
112     kMix,          // UpdateState mix with normal operator
113   };
114   AnfNodePtrList getitems_;                           // Users of the GraphKernel nodes.
115   std::vector<size_t> indexes_;                       // Indexes of GetItem to be processed.
116   std::vector<ExternalUserType> external_user_type_;  // The type of getitem's users.
117 };
118 
119 /**
120  * @brief Merge UpdateState's inputs which link to the same node
121  * @example
122  *   graph_kernel:
123  *      %1 = Cast(p1)
124  *      %2 = Add(%1, p2)
125  *      %3 = Sub(%2, p3)
126  *      %4 = Mul(p1, p2)
127  *      return make_tuple(%1, %2, %3, %4)
128  *   main graph:
129  *      %1 = call @graph_kernel(p1, p2)
130  *      %2 = tuple_getitem(%1, 0)
131  *      %3 = tuple_getitem(%1, 1)
132  *      %4 = tuple_getitem(%1, 2)
133  *      %5 = UpdateState(U, %2, %3, %4)  // the %2 %3 %4 are all link to %1
134  * -->
135  *   main graph:
136  *      %1 = call @graph_kernel(p1, p2)
137  *      %2 = tuple_getitem(%1, 0)
138  *      %3 = tuple_getitem(%1, 1)
139  *      %4 = tuple_getitem(%1, 2)
140  *      %5 = UpdateState(U, %2)  // only keep %2
141  */
142 class MergeOutputForUpdateState : public opt::Pass {
143  public:
MergeOutputForUpdateState()144   MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
145   ~MergeOutputForUpdateState() = default;
146   bool Run(const FuncGraphPtr &func_graph) override;
147 };
148 }  // namespace mindspore::graphkernel
149 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_UPDATE_STATE_FORMATTER_H_
150