• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
17 #define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
18 
#include <NvInfer.h>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>
#include "include/api/kernel.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
27 
28 namespace mindspore::lite {
// Small integer constants shared by the TensorRT op implementations
// (NOTE(review): presumably expected input-tensor counts, judging by the names — confirm).
// `inline` (C++17) gives each constant exactly one definition program-wide, so an
// odr-use from an inline function or template in different translation units refers
// to the same object instead of a per-TU internal-linkage copy.
inline constexpr int INPUT_SIZE2 = 2;
inline constexpr int INPUT_SIZE3 = 3;
inline constexpr int INPUT_SIZE4 = 4;
32 
33 struct ITensorHelper {
34   nvinfer1::ITensor *trt_tensor_{nullptr};
35   mindspore::Format format_;
36 };
37 
38 class TensorRTOp {
39  public:
TensorRTOp(const schema::Primitive * primitive,std::vector<mindspore::MSTensor> in_tensors,std::vector<mindspore::MSTensor> out_tensors,std::string name)40   explicit TensorRTOp(const schema::Primitive *primitive, std::vector<mindspore::MSTensor> in_tensors,
41                       std::vector<mindspore::MSTensor> out_tensors, std::string name)
42       : op_primitive_(primitive),
43         in_tensors_(std::move(in_tensors)),
44         out_tensors_(std::move(out_tensors)),
45         op_name_(std::move(name)) {
46     if (primitive != nullptr) {
47       this->type_ = primitive->value_type();
48     }
49   }
50 
51   virtual ~TensorRTOp() = default;
52 
53   virtual int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
54                         const std::vector<mindspore::MSTensor> &out_tensors) = 0;
55 
56   virtual int AddInnerOp(nvinfer1::INetworkDefinition *network) = 0;
57 
58   const schema::Primitive *GetPrimitive();
59 
60   void AddInnerInTensors(ITensorHelper tensor);
61 
62   void AddInnerOutTensors(ITensorHelper tensor);
63 
64   std::vector<ITensorHelper> &GetInnerOutTensor();
65 
66   std::vector<ITensorHelper> &GetInnerInTensors();
67 
68   std::string GetOpName();
69 
70   std::vector<mindspore::MSTensor> &inputs();
71 
72   std::vector<mindspore::MSTensor> &outputs();
73 
74   schema::PrimitiveType type() const;
75 
76   void set_in_ops(const std::vector<TensorRTOp *> &in_ops);
77 
78   void set_out_ops(const std::vector<TensorRTOp *> &out_ops);
79 
80   const std::vector<TensorRTOp *> &in_ops() const;
81 
82   const std::vector<TensorRTOp *> &out_ops() const;
83 
84  protected:
85   bool IsShapeKnown();
86 
87   std::vector<nvinfer1::ILayer *> layers_;
88 
89   const schema::Primitive *op_primitive_;
90 
91   std::vector<mindspore::MSTensor> in_tensors_;
92 
93   std::vector<mindspore::MSTensor> out_tensors_;
94 
95   std::vector<ITensorHelper> tensorrt_in_tensors_;
96 
97   std::vector<ITensorHelper> tensorrt_out_tensors_;
98 
99   std::vector<TensorRTOp *> in_ops_;
100 
101   std::vector<TensorRTOp *> out_ops_;
102 
103   std::string op_name_;
104 
105   schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
106 };
107 
108 template <class T>
GetTensorRTOp(const schema::Primitive * primitive,const std::vector<mindspore::MSTensor> & in_tensors,const std::vector<mindspore::MSTensor> & out_tensors,const std::string & name)109 TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
110                           const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) {
111   auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name);
112   if (op == nullptr) {
113     MS_LOG(ERROR) << "TensorRT is nullptr.";
114     return nullptr;
115   }
116 
117   auto ret = op->IsSupport(primitive, in_tensors, out_tensors);
118   if (ret != RET_OK) {
119     MS_LOG(ERROR) << "TensorRT op is not supported.";
120     delete op;
121     return nullptr;
122   }
123   return op;
124 }
125 }  // namespace mindspore::lite
126 #endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
127