1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/model/lite_graph.h"
17
18 #include <set>
19 #include <utility>
20
21 #include "backend/optimizer/graph_kernel/model/node.h"
22 #include "backend/optimizer/graph_kernel/model/op_node.h"
23 #include "backend/optimizer/graph_kernel/model/op_register.h"
24
25 namespace mindspore {
26 namespace opt {
27 namespace graphkernel {
Dump() const28 std::string LiteGraph::Dump() const {
29 std::ostringstream os;
30 os << name_ << "(";
31 for (size_t i = 0; i < inputs_.size(); i++) {
32 os << inputs_[i]->name();
33 if (i != inputs_.size() - 1) os << ", ";
34 }
35 os << ") -> ";
36 auto &outputs = GetOutputs();
37 for (size_t i = 0; i < outputs.size(); i++) {
38 os << outputs[i]->name();
39 if (i != outputs.size() - 1) os << ", ";
40 }
41 os << " {\n";
42 for (NodePtr op : ops_) {
43 os << " " << *op << "\n";
44 }
45 os << "}";
46 return os.str();
47 }
48
GetOrderedNodes()49 const NodePtrList &LiteGraph::GetOrderedNodes() {
50 std::unordered_map<NodePtr, size_t> outdegrees;
51 std::function<void(NodePtr)> dfs;
52 std::set<NodePtr> visited;
53 dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) {
54 (void)visited.insert(node);
55 for (auto &input : node->inputs()) {
56 if (input->NodeType() == NType::Primitive) {
57 ++outdegrees[input];
58 if (visited.count(input) == 0) {
59 dfs(input);
60 }
61 }
62 }
63 };
64 dfs(output_);
65 NodePtrList res;
66 NodePtrList stack;
67 stack.push_back(output_);
68 while (!stack.empty()) {
69 auto cur = stack.back();
70 stack.pop_back();
71 res.push_back(cur);
72 for (auto &input : cur->inputs()) {
73 if (input->NodeType() != NType::Primitive) continue;
74 --outdegrees[input];
75 if (outdegrees[input] == 0) {
76 stack.push_back(input);
77 (void)outdegrees.erase(input);
78 }
79 }
80 }
81 if (!outdegrees.empty()) {
82 MS_LOG(ERROR) << "Circle was found:";
83 for (auto &node : outdegrees) {
84 MS_LOG(ERROR) << " " << *(node.first);
85 }
86 MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size();
87 }
88 std::reverse(res.begin(), res.end());
89 res.pop_back(); // erase the output node
90 ops_ = std::move(res);
91 return ops_;
92 }
93
Emit(const std::string & op,const NodePtrList & inputs,const DAttrs & attrs,std::string node_name)94 NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs,
95 std::string node_name) {
96 if (node_name.empty()) node_name = NewName();
97 PrimOpPtr op_ptr = CreateOp(op, node_name);
98 auto baseinfo = op_ptr->Infer(inputs, attrs);
99 op_ptr->SetInputs(inputs);
100 op_ptr->SetAttrs(attrs);
101 op_ptr->SetBaseInfo(baseinfo);
102 return graph_->Add(op_ptr);
103 }
104
Op(const std::string & op,const NodeBase & baseinfo,const NodePtrList & inputs,const DAttrs & attrs,std::string node_name)105 NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
106 const DAttrs &attrs, std::string node_name) {
107 if (node_name.empty()) node_name = NewName();
108 PrimOpPtr op_ptr = CreateOp(op, node_name);
109 op_ptr->SetInputs(inputs);
110 op_ptr->SetAttrs(attrs);
111 op_ptr->SetBaseInfo(baseinfo);
112 return graph_->Add(op_ptr);
113 }
114
CreateOp(const std::string & op,const std::string & node_name)115 PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
116 return OpRegistry::Instance().NewOp(op, node_name);
117 }
118 } // namespace graphkernel
119 } // namespace opt
120 } // namespace mindspore
121