• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_
18 #define MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <unordered_map>
24 #include "src/common/log_adapter.h"
25 #include "src/tensor.h"
26 #include "include/errorcode.h"
27 
28 namespace mindspore::lite {
// Integer node-kind tags. They presumably correspond to the `type` argument that
// GVNode's CreateCNode / CreateInput / CreateOutput / CreateWeight factories pass to
// the protected constructor, and to the value reported by GVNode::type() — TODO confirm in the .cc.
// `inline` (C++17) makes each constant a single entity shared by every translation unit
// that includes this header, instead of one internal-linkage copy per TU.
inline constexpr int kNodeTypeCNode = 0;
inline constexpr int kNodeTypeInput = 1;
inline constexpr int kNodeTypeOutput = 2;
inline constexpr int kNodeTypeWeight = 3;
class GVNode;

/// A connection leaving output port `from_port` of node `from`. Further consumers are
/// recorded with AppendOutput(to, port), so one Edge may fan out to several destinations.
/// Edge declares no destructor, so it never deletes the GVNode pointers it stores;
/// their lifetime is managed elsewhere.
class Edge {
 public:
  /// @param name      identifier of the edge (moved into the object).
  /// @param from      producing node; stored as a non-owning pointer.
  /// @param from_port output port index on `from`. It is a cheap scalar, so it is
  ///                  taken by value rather than by const reference.
  /// @param info      free-form text attached to the edge (moved into the object).
  Edge(std::string name, const GVNode *from, size_t from_port, std::string info)
      : name_(std::move(name)), from_(from), from_port_(from_port), info_(std::move(info)) {}

  // The four members below are defined out of line. From() and Code() presumably
  // render the edge's Graphviz (DOT) text — TODO confirm in the .cc.
  std::string From() const;
  std::string name() const;
  void AppendOutput(const GVNode *to, size_t port);
  std::string Code() const;

 private:
  std::string name_;
  const GVNode *from_{nullptr};  // non-owning
  const size_t from_port_{};
  // Destinations added by AppendOutput. Presumably tos_[i] is paired with
  // to_ports_[i] — verify against the implementation.
  std::vector<const GVNode *> tos_{};
  std::vector<size_t> to_ports_{};
  std::string info_;
};
53 
54 class GVNode {
55  public:
56   static GVNode *CreateCNode(const std::string &id, const std::string &label, size_t input_size,
57                              const std::vector<std::string> &output_names, const std::vector<std::string> &output_infos,
58                              bool highlight = false);
59   static GVNode *CreateInput(const std::string &id, const std::vector<std::string> &output_names,
60                              const std::vector<std::string> &output_infos, bool highlight = false);
61   static GVNode *CreateOutput(const std::string &id, size_t input_size, bool highlight = false);
62   static GVNode *CreateWeight(const std::string &id, const std::string &label,
63                               const std::vector<std::string> &output_names,
64                               const std::vector<std::string> &output_infos, bool highlight = false);
65   virtual ~GVNode();
66 
type()67   int type() const { return this->type_; }
prefix()68   std::string prefix() const { return this->prefix_; }
name()69   std::string name() const { return this->id_; }
input_size()70   size_t input_size() const { return input_size_; }
output_size()71   size_t output_size() const { return output_size_; }
inputs()72   const std::vector<Edge *> &inputs() const { return inputs_; }
outputs()73   const std::vector<Edge *> &outputs() const { return outputs_; }
AppendInput(Edge * edge)74   void AppendInput(Edge *edge) { this->inputs_.emplace_back(edge); }
75   std::string Code() const;
76 
77  protected:
78   GVNode(std::string id, std::string label, int type, size_t input_size, size_t output_size, bool highlight = false)
id_(std::move (id))79       : id_(std::move(id)),
80         label_(std::move(label)),
81         type_(type),
82         input_size_(input_size),
83         output_size_(output_size),
84         highlight_(highlight) {}
85   void Init(const std::vector<std::string> &output_names, const std::vector<std::string> &output_infos);
86   size_t FindCols() const;
87 
88  private:
89   std::string id_;
90   std::string label_;
91   int type_;
92   std::string prefix_;
93   std::string color_ = "white";
94   size_t input_size_{0};
95   size_t output_size_{0};
96   std::string shape_;
97   bool highlight_{false};
98   std::vector<Edge *> inputs_{};
99   std::vector<Edge *> outputs_{};
100 };
101 
102 class GVGraph {
103  public:
GVGraph(std::string name)104   explicit GVGraph(std::string name) : name_{std::move(name)} {};
105   virtual ~GVGraph();
106 
107   void AppendNode(GVNode *node);
108   int Link(const std::string &from_name, size_t from_port, const std::string &to_name, size_t to_port);
109   std::string Code() const;
110 
111  private:
112   std::string name_;
113   std::vector<GVNode *> nodes_;
114   std::unordered_map<std::string, GVNode *> node_map_;
115 };
116 }  // namespace mindspore::lite
117 
118 #endif  // MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_
119