• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
19 
20 #include <unordered_set>
21 #include "frontend/optimizer/irpass.h"
22 #include "frontend/optimizer/optimizer.h"
23 #include "frontend/optimizer/anf_visitor.h"
24 #include "ir/func_graph.h"
25 
26 namespace mindspore {
27 namespace opt {
28 namespace irpass {
29 class SetCellOutputNoRecompute : public AnfVisitor {
30  public:
operator()31   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
32     if (!IsValueNode<FuncGraph>(node)) {
33       return nullptr;
34     }
35 
36     auto fg = GetValueNode<FuncGraphPtr>(node);
37     if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
38       return nullptr;
39     }
40     auto output = fg->output();
41     if (output == nullptr) {
42       return nullptr;
43     }
44     if (output->isa<CNode>()) {
45       std::unordered_set<CNodePtr> real_outputs;
46       GetRealOutputNodes(output, &real_outputs);
47       for (const auto &real_output : real_outputs) {
48         // Set the attr of cnode in case of shared primitives.
49         real_output->AddAttr(kAttrRecompute, MakeValue(false));
50       }
51     }
52     fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
53     return nullptr;
54   }
55 
GetRealOutputNodes(const AnfNodePtr & output,std::unordered_set<CNodePtr> * real_outputs)56   void GetRealOutputNodes(const AnfNodePtr &output, std::unordered_set<CNodePtr> *real_outputs) {
57     MS_EXCEPTION_IF_NULL(output);
58     MS_EXCEPTION_IF_NULL(real_outputs);
59     if (!output->isa<CNode>()) {
60       return;
61     }
62     auto output_cnode = output->cast<CNodePtr>();
63     if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
64       GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
65     } else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
66       auto &inputs = output_cnode->inputs();
67       for (size_t i = 1; i < inputs.size(); ++i) {
68         GetRealOutputNodes(output_cnode->input(i), real_outputs);
69       }
70     } else {
71       real_outputs->insert(output_cnode);
72     }
73   }
74 };
75 }  // namespace irpass
76 }  // namespace opt
77 }  // namespace mindspore
78 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
79