• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_
17 #define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_
18 #include <vector>
19 #include <map>
20 #include <set>
21 #include "src/common/log_adapter.h"
22 #include "include/errorcode.h"
23 #include "core/base/base.h"
24 #include "src/extendrt/delegate/tensorrt/tensor_info.h"
25 
26 namespace mindspore::lite {
27 bool IsSubGraphInputTensor(const std::vector<TensorInfo> &inputs, const TensorInfo &input);
28 
/**
 * Collect the ops in all_ops that produce an input tensor of cur_op.
 *
 * For every input tensor of cur_op (in order), every op in all_ops (in order) whose outputs contain that
 * tensor is appended to the result. An op that produces several inputs of cur_op therefore appears once per
 * matching tensor; the result is not de-duplicated.
 *
 * T must provide inputs() and outputs() returning ranges of equality-comparable tensors.
 *
 * @param cur_op   op whose producers are searched for; must not be null.
 * @param all_ops  candidate ops; entries must not be null.
 * @return the producing ops, possibly empty.
 */
template <typename T>
std::vector<T *> FindPreOps(T *cur_op, std::vector<T *> all_ops) {
  std::vector<T *> in_ops;
  // Bind each accessor result exactly once. Calling outputs() separately for begin() and end() would pair
  // iterators of two different temporaries if the accessor returns by value.
  const auto &cur_inputs = cur_op->inputs();
  for (const auto &in_tensor : cur_inputs) {
    for (auto *op : all_ops) {
      const auto &op_outputs = op->outputs();
      if (std::find(op_outputs.begin(), op_outputs.end(), in_tensor) != op_outputs.end()) {
        in_ops.push_back(op);
      }
    }
  }
  return in_ops;
}
41 
/**
 * Collect the ops in all_ops that consume an output tensor of cur_op.
 *
 * For every output tensor of cur_op (in order), every op in all_ops (in order) whose inputs contain that
 * tensor is appended to the result. An op that consumes several outputs of cur_op therefore appears once per
 * matching tensor; the result is not de-duplicated.
 *
 * T must provide inputs() and outputs() returning ranges of equality-comparable tensors.
 *
 * @param cur_op   op whose consumers are searched for; must not be null.
 * @param all_ops  candidate ops; entries must not be null.
 * @return the consuming ops, possibly empty.
 */
template <typename T>
std::vector<T *> FindNextOps(T *cur_op, std::vector<T *> all_ops) {
  std::vector<T *> out_ops;
  // Bind each accessor result exactly once. Calling inputs() separately for begin() and end() would pair
  // iterators of two different temporaries if the accessor returns by value.
  const auto &cur_outputs = cur_op->outputs();
  for (const auto &out_tensor : cur_outputs) {
    for (auto *op : all_ops) {
      const auto &op_inputs = op->inputs();
      if (std::find(op_inputs.begin(), op_inputs.end(), out_tensor) != op_inputs.end()) {
        out_ops.push_back(op);
      }
    }
  }
  return out_ops;
}
54 
55 template <typename T>
FindPreNextOps(std::vector<T * > all_ops)56 void FindPreNextOps(std::vector<T *> all_ops) {
57   std::map<TensorInfo, std::set<T *>> in_tensor_op;
58   std::map<TensorInfo, std::set<T *>> out_tensor_op;
59   for (auto op : all_ops) {
60     for (auto in_tensor : op->inputs()) {
61       in_tensor_op[in_tensor].insert(op);
62     }
63     for (auto out_tensor : op->outputs()) {
64       out_tensor_op[out_tensor].insert(op);
65     }
66   }
67   for (auto op : all_ops) {
68     std::set<T *> in_ops_set;
69     for (auto in_tensor : op->inputs()) {
70       auto in_ops = out_tensor_op[in_tensor];
71       in_ops_set.insert(in_ops.begin(), in_ops.end());
72     }
73     std::vector<T *> in_ops_vec;
74     in_ops_vec.assign(in_ops_set.begin(), in_ops_set.end());
75     op->set_in_ops(in_ops_vec);
76 
77     std::set<T *> out_ops_set;
78     for (auto out_tensor : op->outputs()) {
79       auto out_ops = in_tensor_op[out_tensor];
80       out_ops_set.insert(out_ops.begin(), out_ops.end());
81     }
82     std::vector<T *> out_ops_vec;
83     out_ops_vec.assign(out_ops_set.begin(), out_ops_set.end());
84     op->set_out_ops(out_ops_vec);
85   }
86 }
87 
88 template <typename T>
GetGraphInOutOps(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,std::vector<T * > * in_ops,std::vector<T * > * out_ops,const std::vector<T * > & all_ops)89 int GetGraphInOutOps(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
90                      std::vector<T *> *in_ops, std::vector<T *> *out_ops, const std::vector<T *> &all_ops) {
91   for (auto in_tensor : inputs) {
92     for (auto op : all_ops) {
93       if (std::find(op->inputs().begin(), op->inputs().end(), in_tensor) != op->inputs().end() &&
94           std::find(in_ops->begin(), in_ops->end(), op) == in_ops->end()) {
95         in_ops->push_back(op);
96       }
97     }
98   }
99   if (in_ops->empty()) {
100     MS_LOG(ERROR) << "Can't find the input ops for npu sub graph.";
101     return RET_ERROR;
102   }
103 
104   for (auto out_tensor : outputs) {
105     for (auto op : all_ops) {
106       if (std::find(op->outputs().begin(), op->outputs().end(), out_tensor) != op->outputs().end() &&
107           std::find(out_ops->begin(), out_ops->end(), op) == out_ops->end()) {
108         out_ops->push_back(op);
109       }
110     }
111   }
112   if (out_ops->empty()) {
113     MS_LOG(ERROR) << "Can't find the output ops for npu sub graph.";
114     return RET_ERROR;
115   }
116   return RET_OK;
117 }
118 }  // namespace mindspore::lite
119 
120 #endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_
121