1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
17 #define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
18
19 #include <utility>
20 #include <NvInfer.h>
21 #include <string>
22 #include <vector>
23 #include "include/api/kernel.h"
24 #include "src/common/log_adapter.h"
25 #include "include/errorcode.h"
26 #include "src/delegate/tensorrt/tensorrt_utils.h"
27
28 namespace mindspore::lite {
// Named integer constants for the values 2, 3 and 4.
// NOTE(review): presumably compared against op input-tensor counts by users of this header — confirm.
constexpr int INPUT_SIZE2{2};
constexpr int INPUT_SIZE3{3};
constexpr int INPUT_SIZE4{4};
32
33 struct ITensorHelper {
34 nvinfer1::ITensor *trt_tensor_{nullptr};
35 mindspore::Format format_;
36 };
37
38 class TensorRTOp {
39 public:
TensorRTOp(const schema::Primitive * primitive,std::vector<mindspore::MSTensor> in_tensors,std::vector<mindspore::MSTensor> out_tensors,std::string name)40 explicit TensorRTOp(const schema::Primitive *primitive, std::vector<mindspore::MSTensor> in_tensors,
41 std::vector<mindspore::MSTensor> out_tensors, std::string name)
42 : op_primitive_(primitive),
43 in_tensors_(std::move(in_tensors)),
44 out_tensors_(std::move(out_tensors)),
45 op_name_(std::move(name)) {
46 if (primitive != nullptr) {
47 this->type_ = primitive->value_type();
48 }
49 }
50
51 virtual ~TensorRTOp() = default;
52
53 virtual int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
54 const std::vector<mindspore::MSTensor> &out_tensors) = 0;
55
56 virtual int AddInnerOp(nvinfer1::INetworkDefinition *network) = 0;
57
58 const schema::Primitive *GetPrimitive();
59
60 void AddInnerInTensors(ITensorHelper tensor);
61
62 void AddInnerOutTensors(ITensorHelper tensor);
63
64 std::vector<ITensorHelper> &GetInnerOutTensor();
65
66 std::vector<ITensorHelper> &GetInnerInTensors();
67
68 std::string GetOpName();
69
70 std::vector<mindspore::MSTensor> &inputs();
71
72 std::vector<mindspore::MSTensor> &outputs();
73
74 schema::PrimitiveType type() const;
75
76 void set_in_ops(const std::vector<TensorRTOp *> &in_ops);
77
78 void set_out_ops(const std::vector<TensorRTOp *> &out_ops);
79
80 const std::vector<TensorRTOp *> &in_ops() const;
81
82 const std::vector<TensorRTOp *> &out_ops() const;
83
84 protected:
85 bool IsShapeKnown();
86
87 std::vector<nvinfer1::ILayer *> layers_;
88
89 const schema::Primitive *op_primitive_;
90
91 std::vector<mindspore::MSTensor> in_tensors_;
92
93 std::vector<mindspore::MSTensor> out_tensors_;
94
95 std::vector<ITensorHelper> tensorrt_in_tensors_;
96
97 std::vector<ITensorHelper> tensorrt_out_tensors_;
98
99 std::vector<TensorRTOp *> in_ops_;
100
101 std::vector<TensorRTOp *> out_ops_;
102
103 std::string op_name_;
104
105 schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
106 };
107
108 template <class T>
GetTensorRTOp(const schema::Primitive * primitive,const std::vector<mindspore::MSTensor> & in_tensors,const std::vector<mindspore::MSTensor> & out_tensors,const std::string & name)109 TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
110 const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) {
111 auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name);
112 if (op == nullptr) {
113 MS_LOG(ERROR) << "TensorRT is nullptr.";
114 return nullptr;
115 }
116
117 auto ret = op->IsSupport(primitive, in_tensors, out_tensors);
118 if (ret != RET_OK) {
119 MS_LOG(ERROR) << "TensorRT op is not supported.";
120 delete op;
121 return nullptr;
122 }
123 return op;
124 }
125 } // namespace mindspore::lite
126 #endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
127