/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_DRAW
#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_
#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_

#include <utility>
#include <vector>
#include <memory>
#include <functional>
#include <string>
#include <unordered_map>
#include "src/common/log_adapter.h"
#include "src/common/draw/adapter_graph.h"
#include "src/common/draw/graphviz_graph_builder.h"
#include "include/errorcode.h"
#include "src/extendrt/graph_compiler/compile_result.h"

namespace mindspore::lite {
/// AdapterNode implementation that forwards every query to a wrapped CompileNodePtr.
class CompileNodeAdapterNode : public AdapterNode {
 public:
  /// Takes its own copy of the node handle (moved into node_).
  explicit CompileNodeAdapterNode(CompileNodePtr node) : node_(std::move(node)) {}

  /// Name reported by the wrapped node.
  std::string GetName() const override { return node_->GetName(); }

  /// All input tensors of the wrapped node.
  std::vector<Tensor *> GetInputs() const override { return node_->GetInputs(); }

  /// Input tensor at `index`, or nullptr when `index` is not below InputSize().
  Tensor *GetInput(const size_t &index) const override {
    if (index >= InputSize()) {
      return nullptr;
    }
    return node_->GetInput(index);
  }

  /// Number of inputs of the wrapped node.
  size_t InputSize() const override { return node_->InputSize(); }

  /// All output tensors of the wrapped node.
  std::vector<Tensor *> GetOutputs() const override { return node_->GetOutputs(); }

  /// Output tensor at `index`, or nullptr when `index` is not below OutputSize().
  Tensor *GetOutput(const size_t &index) const override {
    if (index >= OutputSize()) {
      return nullptr;
    }
    return node_->GetOutput(index);
  }

  /// Number of outputs of the wrapped node.
  size_t OutputSize() const override { return node_->OutputSize(); }

 private:
  // Dereferenced by every accessor without a null check.
  const CompileNodePtr node_;
};

/// AdapterGraph implementation that exposes a CompileResult and owns one
/// CompileNodeAdapterNode per node of that CompileResult.
class CompileResultAdapterGraph : public AdapterGraph {
 public:
  /// Builds an adapter graph for `graph`, wrapping each entry of graph->GetNodes().
  ///
  /// @param graph compile result to adapt; it is only read here and is never deleted by the adapter.
  /// @return the adapter graph, or nullptr when `graph` is nullptr.
  static std::shared_ptr<CompileResultAdapterGraph> Create(const CompileResult *graph) {
    if (graph == nullptr) {
      MS_LOG(ERROR) << "Input compile result is nullptr.";
      return nullptr;
    }
    auto adapter_graph = std::make_shared<CompileResultAdapterGraph>(graph);
    for (const auto &node : graph->GetNodes()) {
      // Keep the new adapter node under RAII until the vector has stored it, so it is not
      // leaked if the insertion throws.
      auto adapter_node = std::make_unique<CompileNodeAdapterNode>(node);
      adapter_graph->nodes_.push_back(adapter_node.get());
      (void)adapter_node.release();
    }
    return adapter_graph;
  }

  /// Stores `graph` without taking ownership; the accessors below dereference it,
  /// so it has to stay valid for as long as they are called.
  explicit CompileResultAdapterGraph(const CompileResult *graph) : graph_(graph) {}

  // nodes_ holds raw owning pointers that the destructor deletes; a copy would delete them twice.
  CompileResultAdapterGraph(const CompileResultAdapterGraph &) = delete;
  CompileResultAdapterGraph &operator=(const CompileResultAdapterGraph &) = delete;

  /// Deletes every adapter node created by Create().
  // NOTE(review): deletion goes through AdapterNode *, which relies on AdapterNode having a
  // virtual destructor - confirm in adapter_graph.h.
  ~CompileResultAdapterGraph() override {
    for (auto node : nodes_) {
      delete node;
    }
    nodes_.clear();
  }

  /// Fixed graph name.
  std::string GetName() const override { return "CompileResult"; }

  /// Copy of the adapter-node pointer list; the pointers stay owned by this graph.
  std::vector<AdapterNode *> GetNodes() const override { return nodes_; }

  /// Input tensors of the wrapped compile result.
  std::vector<Tensor *> GetInputs() const override { return graph_->GetInputs(); }

  /// Number of inputs of the wrapped compile result.
  size_t InputSize() const override { return graph_->InputSize(); }

  /// Output tensors of the wrapped compile result.
  std::vector<Tensor *> GetOutputs() const override { return graph_->GetOutputs(); }

  /// Number of outputs of the wrapped compile result.
  size_t OutputSize() const override { return graph_->OutputSize(); }

 private:
  const CompileResult *graph_;        // not deleted by this class
  std::vector<AdapterNode *> nodes_;  // deleted in the destructor
};

/// Wraps `graph` in a CompileResultAdapterGraph and passes it to GVGraphBuilder::Build.
///
/// Declared inline because this header-defined function would otherwise be emitted once per
/// including translation unit.
///
/// @param graph compile result to draw.
/// @return the built graph, or nullptr (after logging an error) when the adapter graph
///         cannot be created or the builder returns nullptr.
inline std::shared_ptr<GVGraph> CreateGVGraph(const CompileResult *graph) {
  auto adapter_graph = CompileResultAdapterGraph::Create(graph);
  if (adapter_graph == nullptr) {
    MS_LOG(ERROR) << "Create CompileResultAdapterGraph failed.";
    return nullptr;
  }
  GVGraphBuilder builder;
  auto gv_graph = builder.Build(adapter_graph);
  if (gv_graph == nullptr) {
    MS_LOG(ERROR) << "Build gv_graph failed.";
    return nullptr;
  }
  return gv_graph;
}
}  // namespace mindspore::lite

#endif  // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_
#endif  // ENABLE_DRAW