• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "backend/optimizer/graph_kernel/model/lite_graph.h"

#include <algorithm>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/model/op_node.h"
#include "backend/optimizer/graph_kernel/model/op_register.h"
24 
25 namespace mindspore {
26 namespace opt {
27 namespace graphkernel {
Dump() const28 std::string LiteGraph::Dump() const {
29   std::ostringstream os;
30   os << name_ << "(";
31   for (size_t i = 0; i < inputs_.size(); i++) {
32     os << inputs_[i]->name();
33     if (i != inputs_.size() - 1) os << ", ";
34   }
35   os << ") -> ";
36   auto &outputs = GetOutputs();
37   for (size_t i = 0; i < outputs.size(); i++) {
38     os << outputs[i]->name();
39     if (i != outputs.size() - 1) os << ", ";
40   }
41   os << " {\n";
42   for (NodePtr op : ops_) {
43     os << "  " << *op << "\n";
44   }
45   os << "}";
46   return os.str();
47 }
48 
GetOrderedNodes()49 const NodePtrList &LiteGraph::GetOrderedNodes() {
50   std::unordered_map<NodePtr, size_t> outdegrees;
51   std::function<void(NodePtr)> dfs;
52   std::set<NodePtr> visited;
53   dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) {
54     (void)visited.insert(node);
55     for (auto &input : node->inputs()) {
56       if (input->NodeType() == NType::Primitive) {
57         ++outdegrees[input];
58         if (visited.count(input) == 0) {
59           dfs(input);
60         }
61       }
62     }
63   };
64   dfs(output_);
65   NodePtrList res;
66   NodePtrList stack;
67   stack.push_back(output_);
68   while (!stack.empty()) {
69     auto cur = stack.back();
70     stack.pop_back();
71     res.push_back(cur);
72     for (auto &input : cur->inputs()) {
73       if (input->NodeType() != NType::Primitive) continue;
74       --outdegrees[input];
75       if (outdegrees[input] == 0) {
76         stack.push_back(input);
77         (void)outdegrees.erase(input);
78       }
79     }
80   }
81   if (!outdegrees.empty()) {
82     MS_LOG(ERROR) << "Circle was found:";
83     for (auto &node : outdegrees) {
84       MS_LOG(ERROR) << "  " << *(node.first);
85     }
86     MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size();
87   }
88   std::reverse(res.begin(), res.end());
89   res.pop_back();  // erase the output node
90   ops_ = std::move(res);
91   return ops_;
92 }
93 
Emit(const std::string & op,const NodePtrList & inputs,const DAttrs & attrs,std::string node_name)94 NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs,
95                                       std::string node_name) {
96   if (node_name.empty()) node_name = NewName();
97   PrimOpPtr op_ptr = CreateOp(op, node_name);
98   auto baseinfo = op_ptr->Infer(inputs, attrs);
99   op_ptr->SetInputs(inputs);
100   op_ptr->SetAttrs(attrs);
101   op_ptr->SetBaseInfo(baseinfo);
102   return graph_->Add(op_ptr);
103 }
104 
Op(const std::string & op,const NodeBase & baseinfo,const NodePtrList & inputs,const DAttrs & attrs,std::string node_name)105 NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
106                                     const DAttrs &attrs, std::string node_name) {
107   if (node_name.empty()) node_name = NewName();
108   PrimOpPtr op_ptr = CreateOp(op, node_name);
109   op_ptr->SetInputs(inputs);
110   op_ptr->SetAttrs(attrs);
111   op_ptr->SetBaseInfo(baseinfo);
112   return graph_->Add(op_ptr);
113 }
114 
CreateOp(const std::string & op,const std::string & node_name)115 PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
116   return OpRegistry::Instance().NewOp(op, node_name);
117 }
118 }  // namespace graphkernel
119 }  // namespace opt
120 }  // namespace mindspore
121