• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_DRAW
18 #ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_
19 #define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_
20 
21 #include <utility>
22 #include <vector>
23 #include <memory>
24 #include <functional>
25 #include <string>
26 #include <unordered_map>
27 #include "src/common/log_adapter.h"
28 #include "src/common/draw/adapter_graph.h"
29 #include "src/common/draw/graphviz_graph_builder.h"
30 #include "include/errorcode.h"
31 #include "src/extendrt/graph_compiler/compile_result.h"
32 
33 namespace mindspore::lite {
34 class CompileNodeAdapterNode : public AdapterNode {
35  public:
CompileNodeAdapterNode(CompileNodePtr node)36   explicit CompileNodeAdapterNode(CompileNodePtr node) : node_(std::move(node)) {}
37 
GetName()38   std::string GetName() const override { return node_->GetName(); }
GetInputs()39   std::vector<Tensor *> GetInputs() const override { return node_->GetInputs(); }
GetInput(const size_t & index)40   Tensor *GetInput(const size_t &index) const override {
41     if (index >= InputSize()) {
42       return nullptr;
43     }
44     return node_->GetInput(index);
45   }
InputSize()46   size_t InputSize() const override { return node_->InputSize(); }
GetOutputs()47   std::vector<Tensor *> GetOutputs() const override { return node_->GetOutputs(); }
GetOutput(const size_t & index)48   Tensor *GetOutput(const size_t &index) const override {
49     if (index >= OutputSize()) {
50       return nullptr;
51     }
52     return node_->GetOutput(index);
53   }
OutputSize()54   size_t OutputSize() const override { return node_->OutputSize(); }
55 
56  private:
57   const CompileNodePtr node_;
58 };
59 
60 class CompileResultAdapterGraph : public AdapterGraph {
61  public:
Create(const CompileResult * graph)62   static std::shared_ptr<CompileResultAdapterGraph> Create(const CompileResult *graph) {
63     auto adapter_graph = std::make_shared<CompileResultAdapterGraph>(graph);
64     for (const auto &node : graph->GetNodes()) {
65       adapter_graph->nodes_.emplace_back(new CompileNodeAdapterNode(node));
66     }
67     return adapter_graph;
68   }
69 
CompileResultAdapterGraph(const CompileResult * graph)70   explicit CompileResultAdapterGraph(const CompileResult *graph) : graph_(graph) {}
~CompileResultAdapterGraph()71   ~CompileResultAdapterGraph() override {
72     for (auto node : nodes_) {
73       delete node;
74     }
75     nodes_.clear();
76   }
GetName()77   std::string GetName() const override { return "CompileResult"; }
GetNodes()78   std::vector<AdapterNode *> GetNodes() const override { return nodes_; }
GetInputs()79   std::vector<Tensor *> GetInputs() const override { return graph_->GetInputs(); }
InputSize()80   size_t InputSize() const override { return graph_->InputSize(); }
GetOutputs()81   std::vector<Tensor *> GetOutputs() const override { return graph_->GetOutputs(); }
OutputSize()82   size_t OutputSize() const override { return graph_->OutputSize(); }
83 
84  private:
85   const CompileResult *graph_;
86   std::vector<AdapterNode *> nodes_;
87 };
88 
CreateGVGraph(const CompileResult * graph)89 std::shared_ptr<GVGraph> CreateGVGraph(const CompileResult *graph) {
90   auto adapter_graph = CompileResultAdapterGraph::Create(graph);
91   if (adapter_graph == nullptr) {
92     MS_LOG(ERROR) << "Create CompileResultAdapterGraph failed.";
93     return nullptr;
94   }
95   GVGraphBuilder builder;
96   auto gv_graph = builder.Build(adapter_graph);
97   if (gv_graph == nullptr) {
98     MS_LOG(ERROR) << "Build gv_graph failed.";
99     return nullptr;
100   }
101   return gv_graph;
102 }
103 }  // namespace mindspore::lite
104 
105 #endif  // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_
106 #endif
107