• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/kernel_packet/kernel_packet_engine.h"
17 #include "utils/anf_utils.h"
18 
19 namespace mindspore {
20 namespace graphkernel {
21 namespace packet {
CloneAllAbstracts(const FuncGraphPtr & func_graph)22 void CloneAllAbstracts(const FuncGraphPtr &func_graph) {
23   auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
24   for (auto &node : nodes) {
25     auto old_abs = node->abstract();
26     if (old_abs == nullptr) {
27       continue;
28     }
29     auto new_abs = old_abs->Clone();
30     new_abs->SetSymbolicShape(nullptr);
31     new_abs->SetSymbolicValue(nullptr);
32     node->set_abstract(new_abs);
33   }
34 }
35 
SetBaseNodeDepend(const CNodePtr & basenode)36 void KernelPacketEngine::SetBaseNodeDepend(const CNodePtr &basenode) {
37   depend_status_map_[basenode].shape = true;
38   for (size_t i = 1; i < basenode->size(); i++) {
39     if (basenode->input(i)->isa<CNode>()) {
40       depend_status_map_[basenode->input(i)].value = true;
41     }
42   }
43 }
44 
Build(const FuncGraphPtr & func_graph)45 KernelPacketEnginePtr KernelPacketEngine::Build(const FuncGraphPtr &func_graph) {
46   CloneAllAbstracts(func_graph);
47   auto engine = std::make_shared<KernelPacketEngine>(func_graph);
48   func_graph->set_symbol_engine(engine);
49   auto basenode = func_graph->output()->cast<CNodePtr>();
50   MS_EXCEPTION_IF_NULL(basenode);
51   engine->SetBaseNodeDepend(basenode);
52   engine->PreBuild();
53   engine->BuildImpl();
54   return engine;
55 }
56 }  // namespace packet
57 }  // namespace graphkernel
58 }  // namespace mindspore
59