/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_DRAW
#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_
#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_

#include <utility>
#include <vector>
#include <memory>
#include <functional>
#include <string>
#include <unordered_map>
#include "src/common/log_adapter.h"
#include "src/common/draw/adapter_graph.h"
#include "src/common/draw/graphviz_graph_builder.h"
#include "include/errorcode.h"
#include "src/litert/kernel_exec_util.h"
#include "src/executor/kernel_exec.h"
#include "src/executor/sub_graph_kernel.h"
#include "src/common/draw/adapter_graphs/drawer_mark_filter.h"

namespace mindspore::lite {
/// Adapts a kernel::KernelExec to the AdapterNode interface.
///
/// The kernel pointer is stored as-is and never deleted by this class; every accessor dereferences it,
/// so the caller must pass a non-null kernel that stays valid while the node is in use.
class KernelExecAdapterNode : public AdapterNode {
 public:
  /// @param kernel       Kernel to expose as a node (borrowed).
  /// @param mark_filter  Optional predicate invoked with the kernel by IsHighlight(); may be nullptr.
  explicit KernelExecAdapterNode(const kernel::KernelExec *kernel, MarkFilter mark_filter = nullptr)
      : kernel_(kernel), filter_(std::move(mark_filter)) {}

  /// @return The wrapped kernel's name.
  std::string GetName() const override { return kernel_->name(); }

  /// @return The wrapped kernel's input tensors.
  std::vector<Tensor *> GetInputs() const override { return kernel_->in_tensors(); }

  /// @return The input tensor at `index`, or nullptr when `index` is out of range.
  Tensor *GetInput(const size_t &index) const override {
    if (index >= InputSize()) {
      return nullptr;
    }
    return kernel_->in_tensors()[index];
  }

  /// @return Number of input tensors of the wrapped kernel.
  size_t InputSize() const override { return kernel_->in_tensors().size(); }

  /// @return The wrapped kernel's output tensors.
  std::vector<Tensor *> GetOutputs() const override { return kernel_->out_tensors(); }

  /// @return The output tensor at `index`, or nullptr when `index` is out of range.
  Tensor *GetOutput(const size_t &index) const override {
    if (index >= OutputSize()) {
      return nullptr;
    }
    return kernel_->out_tensors()[index];
  }

  /// @return Number of output tensors of the wrapped kernel.
  size_t OutputSize() const override { return kernel_->out_tensors().size(); }

  /// @return The filter's verdict for the wrapped kernel; false when no filter was supplied.
  bool IsHighlight() const override {
    if (filter_ == nullptr) {
      return false;
    }
    return filter_(*kernel_);
  }

 private:
  const kernel::KernelExec *kernel_;
  const MarkFilter filter_;
};

/// Adapts a kernel::SubGraphKernel to the AdapterGraph interface.
///
/// The graph pointer is borrowed. The adapter nodes stored in `nodes_` are allocated in Create() and
/// deleted in the destructor, so this class owns them.
class SubGraphKernelAdapterGraph : public AdapterGraph {
 public:
  /// Builds an adapter graph whose nodes follow the order produced by
  /// kernel::KernelExecUtil::TopologicalSortNodes.
  ///
  /// @param graph        Sub-graph to adapt (borrowed).
  /// @param mark_filter  Optional predicate forwarded to every created node.
  /// @return The adapter graph, or nullptr when `graph` is null or the sort fails.
  static std::shared_ptr<SubGraphKernelAdapterGraph> Create(const kernel::SubGraphKernel *graph,
                                                            const MarkFilter &mark_filter = nullptr) {
    // Guard first: everything below dereferences `graph`.
    if (graph == nullptr) {
      MS_LOG(ERROR) << "Input graph is nullptr.";
      return nullptr;
    }
    auto adapter_graph = std::make_shared<SubGraphKernelAdapterGraph>(graph);
    // Work on a copy: the sort receives a mutable pointer to the list, while `graph` is const.
    auto nodes = graph->immutable_nodes();
    auto ret = kernel::KernelExecUtil::TopologicalSortNodes(&nodes, graph->in_nodes());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "TopologicalSortNodes failed";
      return nullptr;
    }
    // The final size is known, so allocate once instead of growing per element.
    adapter_graph->nodes_.reserve(nodes.size());
    for (auto *node : nodes) {
      adapter_graph->nodes_.emplace_back(new KernelExecAdapterNode(node, mark_filter));
    }
    return adapter_graph;
  }

  /// @param graph  Sub-graph to adapt (borrowed). Prefer Create(), which also populates the nodes.
  explicit SubGraphKernelAdapterGraph(const kernel::SubGraphKernel *graph) : graph_(graph) {}

  // The destructor deletes every entry of `nodes_`; a copy would share those raw pointers and delete
  // them a second time, so copying is disabled.
  SubGraphKernelAdapterGraph(const SubGraphKernelAdapterGraph &) = delete;
  SubGraphKernelAdapterGraph &operator=(const SubGraphKernelAdapterGraph &) = delete;

  /// Releases every adapter node created by Create().
  ~SubGraphKernelAdapterGraph() override {
    for (auto *node : nodes_) {
      delete node;
    }
    nodes_.clear();
  }

  /// @return The wrapped sub-graph's name.
  std::string GetName() const override { return graph_->name(); }

  /// @return The adapter nodes; they remain owned by this graph.
  std::vector<AdapterNode *> GetNodes() const override { return nodes_; }

  /// @return The wrapped sub-graph's input tensors.
  std::vector<Tensor *> GetInputs() const override { return graph_->in_tensors(); }

  /// @return Number of input tensors of the wrapped sub-graph.
  size_t InputSize() const override { return graph_->in_tensors().size(); }

  /// @return The wrapped sub-graph's output tensors.
  std::vector<Tensor *> GetOutputs() const override { return graph_->out_tensors(); }

  /// @return Number of output tensors of the wrapped sub-graph.
  size_t OutputSize() const override { return graph_->out_tensors().size(); }

 private:
  const kernel::SubGraphKernel *graph_;
  std::vector<AdapterNode *> nodes_;
};

/// Builds a GVGraph for `graph` by wrapping it in a SubGraphKernelAdapterGraph and handing that to
/// GVGraphBuilder.
///
/// Declared `inline` because it is defined in a header: without it, every translation unit including
/// this file would emit its own definition and the link would fail.
///
/// @param graph        Sub-graph to draw (borrowed).
/// @param mark_filter  Optional predicate forwarded to the adapter nodes.
/// @return The built graph, or nullptr when the adapter graph or the GVGraph cannot be created.
inline std::shared_ptr<GVGraph> CreateGVGraph(const kernel::SubGraphKernel *graph,
                                              const MarkFilter &mark_filter = nullptr) {
  auto adapter_graph = SubGraphKernelAdapterGraph::Create(graph, mark_filter);
  if (adapter_graph == nullptr) {
    MS_LOG(ERROR) << "Create SubGraphKernelAdapterGraph failed.";
    return nullptr;
  }
  GVGraphBuilder builder;
  auto gv_graph = builder.Build(adapter_graph);
  if (gv_graph == nullptr) {
    MS_LOG(ERROR) << "Build gv_graph failed.";
    return nullptr;
  }
  return gv_graph;
}
}  // namespace mindspore::lite

#endif  // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_
#endif