• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
18 #include <string>
19 #include <vector>
20 #include "src/common/log_adapter.h"
21 #include "include/errorcode.h"
22 #include "NvInferRuntimeCommon.h"
23 #include <NvInfer.h>
24 
25 namespace mindspore::lite {
// Raw-byte (de)serialization helpers for plugin state; definitions are not in this header.
//
// NOTE(review): judging only by the signatures, SerializeValue presumably copies
// cpy_size bytes from value into *buffer and advances *buffer, and DeserializeValue
// presumably copies cpy_size bytes from *buffer into value while advancing *buffer
// and shrinking *buffer_size — confirm against the out-of-line definitions.
void SerializeValue(void **buffer, const void *value, size_t cpy_size);
void DeserializeValue(const void **buffer, size_t *buffer_size, void *value, size_t cpy_size);
28 class TensorRTPlugin : public nvinfer1::IPluginV2DynamicExt {
29  public:
30   TensorRTPlugin(const std::string &layer_name, const std::string &plugin_name, uint32_t device_id = 0)
layer_name_(layer_name)31       : layer_name_(layer_name), plugin_name_(plugin_name), device_id_(device_id) {}
32 
33   // It doesn't make sense to make GeluPluginDynamic without arguments, so we delete
34   // default constructor.
35   TensorRTPlugin() = delete;
36 
37   // IPluginV2DynamicExt Methods
38   nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
39                                           nvinfer1::IExprBuilder &exprBuilder) noexcept override;
40   bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
41                                  int nbOutputs) noexcept override;
42   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
43                        const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
44   size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
45                           const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
46 
47   // IPluginV2Ext Methods
48   nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
49     noexcept override;
50 
51   // IPluginV2 Methods
52   const char *getPluginType() const noexcept override;
53   const char *getPluginVersion() const noexcept override;
54   int getNbOutputs() const noexcept override;
55   int initialize() noexcept override;
56   void terminate() noexcept override;
57   size_t getSerializationSize() const noexcept override;
58   void serialize(void *buffer) const noexcept override;
59   void destroy() noexcept override;
60   void setPluginNamespace(const char *pluginNamespace) noexcept override;
61   const char *getPluginNamespace() const noexcept override;
62 
63  protected:
64   std::string layer_name_;
65   std::string name_space_;
66   std::string plugin_version_{"1"};
67   std::string plugin_name_;
68   uint32_t device_id_{0};
69 };
70 
71 template <class T>
72 class TensorRTPluginCreater : public nvinfer1::IPluginCreator {
73  public:
TensorRTPluginCreater(const std::string & plugin_name)74   explicit TensorRTPluginCreater(const std::string &plugin_name) : plugin_name_(plugin_name) {
75     // Fill PluginFieldCollection with PluginField arguments metadata
76     field_collection_.nbFields = fields_.size();
77     field_collection_.fields = fields_.data();
78   }
79 
getPluginName()80   const char *getPluginName() const noexcept override { return plugin_name_.c_str(); }
81 
getPluginVersion()82   const char *getPluginVersion() const noexcept override { return plugin_version_.c_str(); }
83 
getFieldNames()84   const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override { return &field_collection_; }
85 
setPluginNamespace(const char * pluginNamespace)86   void setPluginNamespace(const char *pluginNamespace) noexcept override { name_space_ = std::string(pluginNamespace); }
87 
getPluginNamespace()88   const char *getPluginNamespace() const noexcept override { return name_space_.c_str(); }
89 
createPlugin(const char * name,const nvinfer1::PluginFieldCollection * fc)90   nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept override {
91     return new (std::nothrow) T(name, fc);
92   }
93 
deserializePlugin(const char * name,const void * data,size_t len)94   nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *data, size_t len) noexcept override {
95     return new (std::nothrow) T(name, data, len);
96   }
97 
98  protected:
99   static nvinfer1::PluginFieldCollection field_collection_;
100   static std::vector<nvinfer1::PluginField> fields_;
101   std::string name_space_;
102   std::string plugin_version_{"1"};
103   std::string plugin_name_;
104 };
105 }  // namespace mindspore::lite
106 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
107