1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_ 19 20 #include <unordered_set> 21 #include "frontend/optimizer/irpass.h" 22 #include "frontend/optimizer/optimizer.h" 23 #include "frontend/optimizer/anf_visitor.h" 24 #include "ir/func_graph.h" 25 26 namespace mindspore { 27 namespace opt { 28 namespace irpass { 29 class SetCellOutputNoRecompute : public AnfVisitor { 30 public: operator()31 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 32 if (!IsValueNode<FuncGraph>(node)) { 33 return nullptr; 34 } 35 36 auto fg = GetValueNode<FuncGraphPtr>(node); 37 if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) { 38 return nullptr; 39 } 40 auto output = fg->output(); 41 if (output == nullptr) { 42 return nullptr; 43 } 44 if (output->isa<CNode>()) { 45 std::unordered_set<CNodePtr> real_outputs; 46 GetRealOutputNodes(output, &real_outputs); 47 for (const auto &real_output : real_outputs) { 48 // Set the attr of cnode in case of shared primitives. 49 real_output->AddAttr(kAttrRecompute, MakeValue(false)); 50 } 51 } 52 fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE); 53 return nullptr; 54 } 55 GetRealOutputNodes(const AnfNodePtr & output,std::unordered_set<CNodePtr> * real_outputs)56 void GetRealOutputNodes(const AnfNodePtr &output, std::unordered_set<CNodePtr> *real_outputs) { 57 MS_EXCEPTION_IF_NULL(output); 58 MS_EXCEPTION_IF_NULL(real_outputs); 59 if (!output->isa<CNode>()) { 60 return; 61 } 62 auto output_cnode = output->cast<CNodePtr>(); 63 if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) { 64 GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs); 65 } else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) { 66 auto &inputs = output_cnode->inputs(); 67 for (size_t i = 1; i < inputs.size(); ++i) { 68 GetRealOutputNodes(output_cnode->input(i), real_outputs); 69 } 70 } else { 71 real_outputs->insert(output_cnode); 72 } 73 } 74 }; 75 } // namespace irpass 76 } // namespace opt 77 } // namespace mindspore 78 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_ 79