1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "src/extendrt/graph_compiler/compile_result.h"
#include <algorithm>
#include <string>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ops/base_operator.h"
#include "utils/hash_map.h"
#include "include/api/status.h"
#include "ir/primitive.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "src/common/file_utils.h"
31
32 namespace mindspore {
33 namespace lite {
34 namespace {
constexpr char tab[] = " ";

// Returns the leading padding for one dump line: `indent` copies of `tab` concatenated.
// A zero or negative `indent` produces an empty string.
inline std::string GenIndent(int indent) {
  std::string padding;
  for (int level = indent; level > 0; --level) {
    padding += tab;
  }
  return padding;
}
44
// Formats a shape as "[d0, d1, ...]"; an empty shape is rendered as "[]".
inline std::string DumpIntShape(const std::vector<int> &shape) {
  std::ostringstream oss;
  oss << "[";
  // The separator is empty before the first element and ", " before every later one.
  const char *separator = "";
  for (const int dim : shape) {
    oss << separator << dim;
    separator = ", ";
  }
  oss << "]";
  return oss.str();
}
57
DumpTensor(const InferTensor * tensor,int indent=0)58 inline std::string DumpTensor(const InferTensor *tensor, int indent = 0) {
59 std::ostringstream oss;
60 oss << GenIndent(indent) << "Tensor <name:" << tensor->tensor_name() << ", dtype:" << tensor->data_type()
61 << ", format:" << tensor->format() << ", shape:" << DumpIntShape(tensor->shape()) << ">";
62 return oss.str();
63 }
64 } // namespace
65
GetKernelAttr() const66 kernel::KernelAttr CompileNode::GetKernelAttr() const {
67 kernel::KernelAttr attr;
68 for (auto &input : inputs_) {
69 (void)attr.AddInputAttr(input->data_type(), FormatEnumToString(input->format()));
70 }
71 for (auto &output : outputs_) {
72 (void)attr.AddOutputAttr(output->data_type(), FormatEnumToString(output->format()));
73 }
74 return attr;
75 }
76
Create(CNodePtr cnode)77 CompileNodePtr CompileNode::Create(CNodePtr cnode) {
78 if (cnode == nullptr) {
79 return nullptr;
80 }
81 auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
82 if (primitive == nullptr) {
83 MS_LOG(ERROR) << "Node has no primitive, first input of cnode(" << cnode->fullname_with_scope()
84 << ") is : " << cnode->input(0);
85 return nullptr;
86 }
87 auto node = std::make_shared<CompileNode>(cnode->fullname_with_scope(), kernel::PrimitiveType(primitive->name()));
88 ops::PrimitiveCPtr primc{nullptr};
89 if (utils::isa<ops::PrimitiveCPtr>(primitive)) {
90 primc = utils::cast<ops::PrimitiveCPtr>(primitive);
91 } else {
92 static auto ops_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
93 auto primc_creator_iter = ops_primc_fns.find(node->type_.TypeName());
94 if (primc_creator_iter == ops_primc_fns.end()) {
95 MS_LOG(ERROR) << "Can not find primitive_c create function for: " << node->type_;
96 return nullptr;
97 }
98 primc = primc_creator_iter->second();
99 if (primc == nullptr) {
100 MS_LOG(ERROR) << "Create primitive_c failed, type: " << node->type_;
101 return nullptr;
102 }
103 (void)primc->SetAttrs(primitive->attrs());
104 }
105 static auto baseops_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
106 auto baseops_creator_iter = baseops_fns.find(node->type_.TypeName());
107 if (baseops_creator_iter == baseops_fns.end()) {
108 MS_LOG(ERROR) << "Can not find base-operator create function for: " << node->type_;
109 return nullptr;
110 }
111 auto baseops_creator = baseops_creator_iter->second;
112 node->base_operator_ = baseops_creator(primc);
113 if (node->base_operator_ == nullptr) {
114 MS_LOG(ERROR) << "Create base-operator failed, type: " << node->type_;
115 return nullptr;
116 }
117 node->cnode_ = std::move(cnode);
118 return node;
119 }
120
AppendInputTensor(InferTensor * tensor)121 void CompileNode::AppendInputTensor(InferTensor *tensor) { this->inputs_.emplace_back(tensor); }
122
AppendOutputTensor(InferTensor * tensor)123 void CompileNode::AppendOutputTensor(InferTensor *tensor) { this->outputs_.emplace_back(tensor); }
124
Dump(int indent) const125 std::string CompileNode::Dump(int indent) const {
126 constexpr int kNumberTwo = 2;
127 std::ostringstream oss;
128 oss << GenIndent(indent) << "CompileNode <name:" << name_ << ", type:" << type_ << "> {" << std::endl;
129 oss << GenIndent(indent + 1) << "inputs: [" << std::endl;
130 for (auto &input : inputs_) {
131 oss << DumpTensor(input, indent + kNumberTwo) << std::endl;
132 }
133 oss << GenIndent(indent + 1) << "]" << std::endl;
134 oss << GenIndent(indent + 1) << "outputs: [" << std::endl;
135 for (auto &output : outputs_) {
136 oss << DumpTensor(output, indent + kNumberTwo) << std::endl;
137 }
138 oss << GenIndent(indent + 1) << "]" << std::endl;
139 oss << GenIndent(indent) << "}";
140 return oss.str();
141 }
142
ReplaceInputTensor(InferTensor * dst,const InferTensor * src)143 void CompileNode::ReplaceInputTensor(InferTensor *dst, const InferTensor *src) {
144 std::replace_if(
145 inputs_.begin(), inputs_.end(), [&src](InferTensor *ele) { return ele == src; }, dst);
146 }
147
GetNode(const std::string & name)148 CompileNodePtr CompileResult::GetNode(const std::string &name) {
149 auto iter = node_map_.find(name);
150 if (iter == node_map_.end()) {
151 return nullptr;
152 } else {
153 return iter->second;
154 }
155 }
156
GetArgNode(const std::string & name)157 CompileNodePtr CompileResult::GetArgNode(const std::string &name) {
158 auto iter = arg_node_map_.find(name);
159 if (iter == arg_node_map_.end()) {
160 return nullptr;
161 } else {
162 return iter->second;
163 }
164 }
165
GetMutableNodes()166 std::vector<CompileNodePtr> &CompileResult::GetMutableNodes() {
167 if (assembled_) {
168 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
169 }
170 return nodes_;
171 }
GetMutableInputs()172 std::vector<InferTensor *> &CompileResult::GetMutableInputs() {
173 if (assembled_) {
174 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
175 }
176 return inputs_;
177 }
178
GetMutableOutputs()179 std::vector<InferTensor *> &CompileResult::GetMutableOutputs() {
180 if (assembled_) {
181 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
182 }
183 return outputs_;
184 }
185
AppendNode(CompileNodePtr node)186 StatusCode CompileResult::AppendNode(CompileNodePtr node) {
187 if (assembled_) {
188 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
189 }
190 if (node == nullptr) {
191 MS_LOG(ERROR) << "Input node is nullptr";
192 return kLiteInputParamInvalid;
193 }
194 const std::string &node_name = node->GetName();
195 auto iter = node_map_.find(node_name);
196 if (iter != node_map_.end()) {
197 MS_LOG(ERROR) << "Duplicated node name : " << node_name;
198 return kLiteError;
199 }
200 node_map_[node_name] = node;
201 nodes_.emplace_back(node);
202 return kSuccess;
203 }
204
AppendArgNode(CompileNodePtr node)205 StatusCode CompileResult::AppendArgNode(CompileNodePtr node) {
206 if (assembled_) {
207 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
208 }
209 if (node == nullptr) {
210 MS_LOG(ERROR) << "Input node is nullptr";
211 return kLiteInputParamInvalid;
212 }
213 const std::string &node_name = node->GetName();
214 auto iter = arg_node_map_.find(node_name);
215 if (iter != arg_node_map_.end()) {
216 MS_LOG(ERROR) << "Duplicated node name : " << node_name;
217 return kLiteError;
218 }
219 arg_node_map_[node_name] = node;
220 arg_nodes_.emplace_back(node);
221 return kSuccess;
222 }
223
AppendTensor(InferTensor * tensor)224 StatusCode CompileResult::AppendTensor(InferTensor *tensor) {
225 if (assembled_) {
226 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
227 }
228 if (tensor == nullptr) {
229 MS_LOG(ERROR) << "Input tensor is nullptr";
230 return kLiteInputParamInvalid;
231 }
232 tensors_.emplace_back(tensor);
233 return kSuccess;
234 }
235
AppendInputTensor(InferTensor * tensor,bool is_borrow)236 StatusCode CompileResult::AppendInputTensor(InferTensor *tensor, bool is_borrow) {
237 if (assembled_) {
238 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
239 }
240 if (tensor == nullptr) {
241 MS_LOG(ERROR) << "Input tensor is nullptr";
242 return kLiteInputParamInvalid;
243 }
244 inputs_.emplace_back(tensor);
245 if (!is_borrow) {
246 return AppendTensor(tensor);
247 }
248 return kSuccess;
249 }
250
AppendOutputTensor(InferTensor * tensor,bool is_borrow)251 StatusCode CompileResult::AppendOutputTensor(InferTensor *tensor, bool is_borrow) {
252 if (assembled_) {
253 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
254 }
255 if (tensor == nullptr) {
256 MS_LOG(ERROR) << "Input tensor is nullptr";
257 return kLiteInputParamInvalid;
258 }
259 if (tensor->tensor_name().empty()) {
260 tensor->set_tensor_name("graph_out_" + std::to_string(this->outputs_.size()));
261 }
262 if (!is_borrow) {
263 return AppendTensor(tensor);
264 }
265 outputs_.emplace_back(tensor);
266 return kSuccess;
267 }
268
AppendNodeInputTensor(const CompileNodePtr & compile_node,InferTensor * tensor,bool is_borrow)269 StatusCode CompileResult::AppendNodeInputTensor(const CompileNodePtr &compile_node, InferTensor *tensor,
270 bool is_borrow) {
271 if (compile_node == nullptr) {
272 MS_LOG(ERROR) << "Input compile_node is nullptr";
273 return kLiteInputParamInvalid;
274 }
275 return AppendNodeInputTensor(compile_node->GetName(), tensor, is_borrow);
276 }
277
AppendNodeInputTensor(const std::string & node_name,InferTensor * input_tensor,bool is_borrow)278 StatusCode CompileResult::AppendNodeInputTensor(const std::string &node_name, InferTensor *input_tensor,
279 bool is_borrow) {
280 if (assembled_) {
281 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
282 }
283 if (input_tensor == nullptr) {
284 MS_LOG(ERROR) << "`input_tensor` is nullptr";
285 return kLiteInputParamInvalid;
286 }
287
288 auto iter = node_map_.find(node_name);
289 if (iter == node_map_.end()) {
290 MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name;
291 return kLiteError;
292 }
293 iter->second->AppendInputTensor(input_tensor);
294 if (!is_borrow) {
295 return AppendTensor(input_tensor);
296 }
297 return kSuccess;
298 }
299
AppendNodeOutputTensor(const CompileNodePtr & compile_node,InferTensor * tensor,bool is_borrow)300 StatusCode CompileResult::AppendNodeOutputTensor(const CompileNodePtr &compile_node, InferTensor *tensor,
301 bool is_borrow) {
302 if (compile_node == nullptr) {
303 MS_LOG(ERROR) << "Input compile_node is nullptr";
304 return kLiteInputParamInvalid;
305 }
306 return AppendNodeOutputTensor(compile_node->GetName(), tensor, is_borrow);
307 }
308
AppendNodeOutputTensor(const std::string & node_name,InferTensor * output_tensor,bool is_borrow)309 StatusCode CompileResult::AppendNodeOutputTensor(const std::string &node_name, InferTensor *output_tensor,
310 bool is_borrow) {
311 if (assembled_) {
312 MS_LOG(EXCEPTION) << "CompileResult not mutable after build.";
313 }
314 if (output_tensor == nullptr) {
315 MS_LOG(ERROR) << "`output_tensor` is nullptr";
316 return kLiteInputParamInvalid;
317 }
318
319 auto iter = node_map_.find(node_name);
320 if (iter == node_map_.end()) {
321 MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name;
322 return kLiteError;
323 }
324 iter->second->AppendOutputTensor(output_tensor);
325 if (!is_borrow) {
326 return AppendTensor(output_tensor);
327 }
328 return kSuccess;
329 }
330
Dump(int indent) const331 std::string CompileResult::Dump(int indent) const {
332 constexpr int kNumTwo = 2;
333 std::ostringstream oss;
334 oss << GenIndent(indent) << "CompileResult {" << std::endl;
335 oss << GenIndent(indent + 1) << "nodes: [" << std::endl;
336 for (auto &node : nodes_) {
337 oss << node->Dump(indent + kNumTwo) << std::endl;
338 }
339 oss << GenIndent(indent + 1) << "]" << std::endl;
340 oss << GenIndent(indent + 1) << "inputs: [" << std::endl;
341 for (auto &input : inputs_) {
342 oss << DumpTensor(input, indent + kNumTwo) << std::endl;
343 }
344 oss << GenIndent(indent + 1) << "]" << std::endl;
345 oss << GenIndent(indent + 1) << "outputs: [" << std::endl;
346 for (auto &output : outputs_) {
347 oss << DumpTensor(output, indent + kNumTwo) << std::endl;
348 }
349 oss << GenIndent(indent + 1) << "]" << std::endl;
350 oss << GenIndent(indent + 1) << "tensors: [" << std::endl;
351 for (auto &tensor : tensors_) {
352 oss << DumpTensor(tensor, indent + kNumTwo) << std::endl;
353 }
354 oss << GenIndent(indent + 1) << "]" << std::endl;
355 oss << GenIndent(indent) << "}" << std::endl;
356 return oss.str();
357 }
358 } // namespace lite
359 } // namespace mindspore
360