/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_

#include "utils/hash_set.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "include/common/utils/parallel_context.h"
#include "ir/func_graph.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Marks the real output nodes of a func graph flagged with FUNC_GRAPH_OUTPUT_NO_RECOMPUTE so that
// they are excluded from recomputation.
class SetCellOutputNoRecompute : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    const auto no_cell_reuse = context->CellReuseLevel() == CellReuseLevel::kNoCellReuse;
    if (!IsValueNode<FuncGraph>(node)) {
      return nullptr;
    }

    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
      return nullptr;
    }
    auto output = fg->output();
    if (output == nullptr) {
      return nullptr;
    }
    if (output->isa<CNode>()) {
      mindspore::HashSet<CNodePtr> real_outputs;
      GetRealOutputNodes(output, &real_outputs);
      if (OutputAllNodes(real_outputs)) {
        MS_LOG(WARNING)
          << "All nodes in the graph " << fg->ToString()
          << " are the output nodes, which are set to not be recomputed. If you want to set these nodes to "
             "be recomputed, use the api recompute() of Primitive.";
      }
      for (const auto &real_output : real_outputs) {
        // Set the attr of cnode in case of shared primitives.
        if (no_cell_reuse) {
          real_output->AddAttr(kAttrRecompute, MakeValue(false));
        }

        if (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
            parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel) {
          auto prim = GetCNodePrimitive(real_output);
          if (prim != nullptr && prim->HasAttr(kAttrSliceActivation) &&
              GetValue<bool>(prim->GetAttr(kAttrSliceActivation))) {
            real_output->AddAttr(kAttrSliceActivation, MakeValue(true));
          }
        }
      }
    }
    if (no_cell_reuse) {
      fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
    }
    return nullptr;
  }

  // Collect the real output cnodes, looking through Depend, TupleGetItem, MakeTuple and called func graphs.
  void GetRealOutputNodes(const AnfNodePtr &output, mindspore::HashSet<CNodePtr> *real_outputs) {
    MS_EXCEPTION_IF_NULL(output);
    MS_EXCEPTION_IF_NULL(real_outputs);
    auto output_cnode = output->cast<CNodePtr>();
    if (output_cnode == nullptr) {
      return;
    }
    auto input0 = output_cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
    } else if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      auto &inputs = output_cnode->inputs();
      for (size_t i = 1; i < inputs.size(); ++i) {
        GetRealOutputNodes(output_cnode->input(i), real_outputs);
      }
    } else if (IsValueNode<FuncGraph>(input0)) {
      auto fg = GetValueNode<FuncGraphPtr>(input0);
      GetRealOutputNodes(fg->output(), real_outputs);
    } else if (input0->isa<CNode>()) {
      auto abs = input0->abstract();
      if (abs == nullptr || !abs->isa<abstract::AbstractFunction>()) {
        return;
      }
      auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
      if (abs_func->isa<abstract::AbstractFuncUnion>()) {
        auto visit_fn = [this, &real_outputs](const abstract::AbstractFuncAtomPtr &poss) {
          auto abs_fg = GetAbstractFuncGraph(poss);
          if (abs_fg != nullptr) {
            GetRealOutputNodes(abs_fg->output(), real_outputs);
          }
        };
        abs_func->Visit(visit_fn);
        return;
      }
      auto fg = GetAbstractFuncGraph(abs_func);
      if (fg != nullptr) {
        GetRealOutputNodes(fg->output(), real_outputs);
      }
    } else {
      real_outputs->insert(output_cnode);
    }
  }

  // Resolve the func graph held by a FuncGraphAbstractClosure or PartialAbstractClosure, if any.
  FuncGraphPtr GetAbstractFuncGraph(const abstract::AbstractFunctionPtr &abs) const {
    if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
      auto abstract_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
      return abstract_func_graph->func_graph();
    }
    if (abs->isa<abstract::PartialAbstractClosure>()) {
      auto abstract_partial_func = abs->cast<abstract::PartialAbstractClosurePtr>();
      auto abstract_fn = abstract_partial_func->fn();
      if (abstract_fn != nullptr && abstract_fn->isa<abstract::FuncGraphAbstractClosure>()) {
        auto abstract_func_graph = abstract_fn->cast<abstract::FuncGraphAbstractClosurePtr>();
        return abstract_func_graph->func_graph();
      }
    }
    return nullptr;
  }

  // Return true only if every cnode input (other than Load) of the real outputs is itself a real output,
  // i.e. the graph consists solely of output nodes.
  bool OutputAllNodes(const mindspore::HashSet<CNodePtr> &real_outputs) const {
    for (const auto &cnode : real_outputs) {
      const auto &inputs = cnode->inputs();
      for (const auto &input : inputs) {
        auto input_cnode = input->cast<CNodePtr>();
        if (input_cnode == nullptr || IsPrimitiveCNode(input_cnode, prim::kPrimLoad)) {
          continue;
        }
        if (real_outputs.find(input_cnode) == real_outputs.end()) {
          return false;
        }
      }
    }
    return true;
  }
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_