• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_DRAW
18 #ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_
19 #define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_
20 
21 #include <utility>
22 #include <vector>
23 #include <memory>
24 #include <functional>
25 #include <string>
26 #include <unordered_map>
27 #include "src/common/log_adapter.h"
28 #include "src/common/draw/adapter_graph.h"
29 #include "src/common/draw/graphviz_graph_builder.h"
30 #include "include/errorcode.h"
31 #include "src/litert/kernel_exec_util.h"
32 #include "src/executor/kernel_exec.h"
33 #include "src/executor/sub_graph_kernel.h"
34 #include "src/common/draw/adapter_graphs/drawer_mark_filter.h"
35 
36 namespace mindspore::lite {
37 class KernelExecAdapterNode : public AdapterNode {
38  public:
39   explicit KernelExecAdapterNode(const kernel::KernelExec *kernel, MarkFilter mark_filter = nullptr)
kernel_(kernel)40       : kernel_(kernel), filter_(std::move(mark_filter)) {}
41 
GetName()42   std::string GetName() const override { return kernel_->name(); }
GetInputs()43   std::vector<Tensor *> GetInputs() const override { return kernel_->in_tensors(); }
GetInput(const size_t & index)44   Tensor *GetInput(const size_t &index) const override {
45     if (index >= InputSize()) {
46       return nullptr;
47     }
48     return kernel_->in_tensors()[index];
49   }
InputSize()50   size_t InputSize() const override { return kernel_->in_tensors().size(); }
GetOutputs()51   std::vector<Tensor *> GetOutputs() const override { return kernel_->out_tensors(); }
GetOutput(const size_t & index)52   Tensor *GetOutput(const size_t &index) const override {
53     if (index >= OutputSize()) {
54       return nullptr;
55     }
56     return kernel_->out_tensors()[index];
57   }
OutputSize()58   size_t OutputSize() const override { return kernel_->out_tensors().size(); }
59 
IsHighlight()60   bool IsHighlight() const override {
61     if (filter_ == nullptr) {
62       return false;
63     }
64     return filter_(*kernel_);
65   }
66 
67  private:
68   const kernel::KernelExec *kernel_;
69   const MarkFilter filter_;
70 };
71 
72 class SubGraphKernelAdapterGraph : public AdapterGraph {
73  public:
74   static std::shared_ptr<SubGraphKernelAdapterGraph> Create(const kernel::SubGraphKernel *graph,
75                                                             const MarkFilter &mark_filter = nullptr) {
76     auto adapter_graph = std::make_shared<SubGraphKernelAdapterGraph>(graph);
77     auto nodes = graph->immutable_nodes();
78     auto ret = kernel::KernelExecUtil::TopologicalSortNodes(&nodes, graph->in_nodes());
79     if (ret != RET_OK) {
80       MS_LOG(ERROR) << "TopologicalSortNodes failed";
81       return nullptr;
82     }
83     for (auto node : nodes) {
84       adapter_graph->nodes_.emplace_back(new KernelExecAdapterNode(node, mark_filter));
85     }
86     return adapter_graph;
87   }
88 
SubGraphKernelAdapterGraph(const kernel::SubGraphKernel * graph)89   explicit SubGraphKernelAdapterGraph(const kernel::SubGraphKernel *graph) : graph_(graph) {}
~SubGraphKernelAdapterGraph()90   ~SubGraphKernelAdapterGraph() override {
91     for (auto node : nodes_) {
92       delete node;
93     }
94     nodes_.clear();
95   }
GetName()96   std::string GetName() const override { return graph_->name(); }
GetNodes()97   std::vector<AdapterNode *> GetNodes() const override { return nodes_; }
GetInputs()98   std::vector<Tensor *> GetInputs() const override { return graph_->in_tensors(); }
InputSize()99   size_t InputSize() const override { return graph_->in_tensors().size(); }
GetOutputs()100   std::vector<Tensor *> GetOutputs() const override { return graph_->out_tensors(); }
OutputSize()101   size_t OutputSize() const override { return graph_->out_tensors().size(); }
102 
103  private:
104   const kernel::SubGraphKernel *graph_;
105   std::vector<AdapterNode *> nodes_;
106 };
107 
108 std::shared_ptr<GVGraph> CreateGVGraph(const kernel::SubGraphKernel *graph, const MarkFilter &mark_filter = nullptr) {
109   auto adapter_graph = SubGraphKernelAdapterGraph::Create(graph, mark_filter);
110   if (adapter_graph == nullptr) {
111     MS_LOG(ERROR) << "Create SubGraphKernelAdapterGraph failed.";
112     return nullptr;
113   }
114   GVGraphBuilder builder;
115   auto gv_graph = builder.Build(adapter_graph);
116   if (gv_graph == nullptr) {
117     MS_LOG(ERROR) << "Build gv_graph failed.";
118     return nullptr;
119   }
120   return gv_graph;
121 }
122 }  // namespace mindspore::lite
123 
124 #endif  // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_
125 #endif
126