1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "backend/common/graph_kernel/kernel_packet/kernel_packet_engine.h" 17 #include "utils/anf_utils.h" 18 19 namespace mindspore { 20 namespace graphkernel { 21 namespace packet { CloneAllAbstracts(const FuncGraphPtr & func_graph)22void CloneAllAbstracts(const FuncGraphPtr &func_graph) { 23 auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude); 24 for (auto &node : nodes) { 25 auto old_abs = node->abstract(); 26 if (old_abs == nullptr) { 27 continue; 28 } 29 auto new_abs = old_abs->Clone(); 30 new_abs->SetSymbolicShape(nullptr); 31 new_abs->SetSymbolicValue(nullptr); 32 node->set_abstract(new_abs); 33 } 34 } 35 SetBaseNodeDepend(const CNodePtr & basenode)36void KernelPacketEngine::SetBaseNodeDepend(const CNodePtr &basenode) { 37 depend_status_map_[basenode].shape = true; 38 for (size_t i = 1; i < basenode->size(); i++) { 39 if (basenode->input(i)->isa<CNode>()) { 40 depend_status_map_[basenode->input(i)].value = true; 41 } 42 } 43 } 44 Build(const FuncGraphPtr & func_graph)45KernelPacketEnginePtr KernelPacketEngine::Build(const FuncGraphPtr &func_graph) { 46 CloneAllAbstracts(func_graph); 47 auto engine = std::make_shared<KernelPacketEngine>(func_graph); 48 func_graph->set_symbol_engine(engine); 49 auto basenode = func_graph->output()->cast<CNodePtr>(); 50 MS_EXCEPTION_IF_NULL(basenode); 51 engine->SetBaseNodeDepend(basenode); 52 engine->PreBuild(); 53 engine->BuildImpl(); 54 return engine; 55 } 56 } // namespace packet 57 } // namespace graphkernel 58 } // namespace mindspore 59