• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
19 
20 #include "utils/hash_set.h"
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "frontend/optimizer/irpass.h"
24 #include "frontend/optimizer/optimizer.h"
25 #include "frontend/optimizer/anf_visitor.h"
26 #include "include/common/utils/parallel_context.h"
27 #include "ir/func_graph.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace irpass {
32 class SetCellOutputNoRecompute : public AnfVisitor {
33  public:
operator()34   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
35     auto context = MsContext::GetInstance();
36     MS_EXCEPTION_IF_NULL(context);
37     const auto no_cell_reuse = context->CellReuseLevel() == CellReuseLevel::kNoCellReuse;
38     if (!IsValueNode<FuncGraph>(node)) {
39       return nullptr;
40     }
41 
42     auto fg = GetValueNode<FuncGraphPtr>(node);
43     if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
44       return nullptr;
45     }
46     auto output = fg->output();
47     if (output == nullptr) {
48       return nullptr;
49     }
50     if (output->isa<CNode>()) {
51       mindspore::HashSet<CNodePtr> real_outputs;
52       GetRealOutputNodes(output, &real_outputs);
53       if (OutputAllNodes(real_outputs)) {
54         MS_LOG(WARNING)
55           << "All nodes in the graph " << fg->ToString()
56           << " are the output nodes, which are set to not be recomputed. If you want to set these nodes to "
57              "be recomputed, use the api recompute() of Primitive.";
58       }
59       for (const auto &real_output : real_outputs) {
60         // Set the attr of cnode in case of shared primitives.
61         if (no_cell_reuse) {
62           real_output->AddAttr(kAttrRecompute, MakeValue(false));
63         }
64 
65         if (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
66             parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel) {
67           auto prim = GetCNodePrimitive(real_output);
68           if (prim->HasAttr(kAttrSliceActivation) && GetValue<bool>(prim->GetAttr(kAttrSliceActivation))) {
69             real_output->AddAttr(kAttrSliceActivation, MakeValue(true));
70           }
71         }
72       }
73     }
74     if (no_cell_reuse) {
75       fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
76     }
77     return nullptr;
78   }
79 
GetRealOutputNodes(const AnfNodePtr & output,mindspore::HashSet<CNodePtr> * real_outputs)80   void GetRealOutputNodes(const AnfNodePtr &output, mindspore::HashSet<CNodePtr> *real_outputs) {
81     MS_EXCEPTION_IF_NULL(output);
82     MS_EXCEPTION_IF_NULL(real_outputs);
83     auto output_cnode = output->cast<CNodePtr>();
84     if (output_cnode == nullptr) {
85       return;
86     }
87     auto input0 = output_cnode->input(0);
88     MS_EXCEPTION_IF_NULL(input0);
89     if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimTupleGetItem)) {
90       GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
91     } else if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
92       auto &inputs = output_cnode->inputs();
93       for (size_t i = 1; i < inputs.size(); ++i) {
94         GetRealOutputNodes(output_cnode->input(i), real_outputs);
95       }
96     } else if (IsValueNode<FuncGraph>(input0)) {
97       auto fg = GetValueNode<FuncGraphPtr>(input0);
98       GetRealOutputNodes(fg->output(), real_outputs);
99     } else if (input0->isa<CNode>()) {
100       auto abs = input0->abstract();
101       if (abs == nullptr || !abs->isa<abstract::AbstractFunction>()) {
102         return;
103       }
104       auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
105       if (abs_func->isa<abstract::AbstractFuncUnion>()) {
106         auto visit_fn = [this, &real_outputs](const abstract::AbstractFuncAtomPtr &poss) {
107           auto abs_fg = GetAbstractFuncGraph(poss);
108           if (abs_fg != nullptr) {
109             GetRealOutputNodes(abs_fg->output(), real_outputs);
110           }
111         };
112         abs_func->Visit(visit_fn);
113         return;
114       }
115       auto fg = GetAbstractFuncGraph(abs_func);
116       if (fg != nullptr) {
117         GetRealOutputNodes(fg->output(), real_outputs);
118       }
119     } else {
120       real_outputs->insert(output_cnode);
121     }
122   }
123 
GetAbstractFuncGraph(const abstract::AbstractFunctionPtr & abs)124   FuncGraphPtr GetAbstractFuncGraph(const abstract::AbstractFunctionPtr &abs) const {
125     if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
126       auto abstract_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
127       return abstract_func_graph->func_graph();
128     }
129     if (abs->isa<abstract::PartialAbstractClosure>()) {
130       auto abstract_partial_func = abs->cast<abstract::PartialAbstractClosurePtr>();
131       auto abstract_fn = abstract_partial_func->fn();
132       if (abstract_fn != nullptr && abstract_fn->isa<abstract::FuncGraphAbstractClosure>()) {
133         auto abstract_func_graph = abstract_fn->cast<abstract::FuncGraphAbstractClosurePtr>();
134         return abstract_func_graph->func_graph();
135       }
136     }
137     return nullptr;
138   }
139 
OutputAllNodes(const mindspore::HashSet<CNodePtr> & real_outputs)140   bool OutputAllNodes(const mindspore::HashSet<CNodePtr> &real_outputs) const {
141     for (const auto &cnode : real_outputs) {
142       const auto &inputs = cnode->inputs();
143       for (const auto &input : inputs) {
144         auto input_cnode = input->cast<CNodePtr>();
145         if (input_cnode == nullptr || IsPrimitiveCNode(input_cnode, prim::kPrimLoad)) {
146           continue;
147         }
148         if (real_outputs.find(input_cnode) == real_outputs.end()) {
149           return false;
150         }
151       }
152     }
153     return true;
154   }
155 };
156 }  // namespace irpass
157 }  // namespace opt
158 }  // namespace mindspore
159 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
160