1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ 17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ 18 #include <string> 19 #include <vector> 20 #include "src/common/log_adapter.h" 21 #include "include/errorcode.h" 22 #include "NvInferRuntimeCommon.h" 23 #include <NvInfer.h> 24 25 namespace mindspore::lite { 26 void SerializeValue(void **buffer, const void *value, size_t cpy_size); 27 void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size); 28 class TensorRTPlugin : public nvinfer1::IPluginV2DynamicExt { 29 public: 30 TensorRTPlugin(const std::string &layer_name, const std::string &plugin_name, uint32_t device_id = 0) layer_name_(layer_name)31 : layer_name_(layer_name), plugin_name_(plugin_name), device_id_(device_id) {} 32 33 // It doesn't make sense to make GeluPluginDynamic without arguments, so we delete 34 // default constructor. 35 TensorRTPlugin() = delete; 36 37 // IPluginV2DynamicExt Methods 38 nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, 39 nvinfer1::IExprBuilder &exprBuilder) noexcept override; 40 bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, 41 int nbOutputs) noexcept override; 42 void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, 43 const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; 44 size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 45 const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; 46 47 // IPluginV2Ext Methods 48 nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const 49 noexcept override; 50 51 // IPluginV2 Methods 52 const char *getPluginType() const noexcept override; 53 const char *getPluginVersion() const noexcept override; 54 int getNbOutputs() const noexcept override; 55 int initialize() noexcept override; 56 void terminate() noexcept override; 57 size_t getSerializationSize() const noexcept override; 58 void serialize(void *buffer) const noexcept override; 59 void destroy() noexcept override; 60 void setPluginNamespace(const char *pluginNamespace) noexcept override; 61 const char *getPluginNamespace() const noexcept override; 62 63 protected: 64 std::string layer_name_; 65 std::string name_space_; 66 std::string plugin_version_{"1"}; 67 std::string plugin_name_; 68 uint32_t device_id_{0}; 69 }; 70 71 template <class T> 72 class TensorRTPluginCreater : public nvinfer1::IPluginCreator { 73 public: TensorRTPluginCreater(const std::string & plugin_name)74 explicit TensorRTPluginCreater(const std::string &plugin_name) : plugin_name_(plugin_name) { 75 // Fill PluginFieldCollection with PluginField arguments metadata 76 field_collection_.nbFields = fields_.size(); 77 field_collection_.fields = fields_.data(); 78 } 79 getPluginName()80 const char *getPluginName() const noexcept override { return plugin_name_.c_str(); } 81 getPluginVersion()82 const char *getPluginVersion() const noexcept override { return plugin_version_.c_str(); } 83 getFieldNames()84 const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override { return &field_collection_; } 85 setPluginNamespace(const char * pluginNamespace)86 void setPluginNamespace(const char *pluginNamespace) noexcept override { name_space_ = std::string(pluginNamespace); } 87 getPluginNamespace()88 const char *getPluginNamespace() const noexcept override { return name_space_.c_str(); } 89 createPlugin(const char * name,const nvinfer1::PluginFieldCollection * fc)90 nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept override { 91 return new (std::nothrow) T(name, fc); 92 } 93 deserializePlugin(const char * name,const void * data,size_t len)94 nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *data, size_t len) noexcept override { 95 return new (std::nothrow) T(name, data, len); 96 } 97 98 protected: 99 static nvinfer1::PluginFieldCollection field_collection_; 100 static std::vector<nvinfer1::PluginField> fields_; 101 std::string name_space_; 102 std::string plugin_version_{"1"}; 103 std::string plugin_name_; 104 }; 105 } // namespace mindspore::lite 106 #endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ 107