• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
18 #include <cstring>
19 
20 namespace mindspore::lite {
SerializeValue(void ** buffer,const void * value,size_t cpy_size)21 void SerializeValue(void **buffer, const void *value, size_t cpy_size) {
22   memcpy(*buffer, value, cpy_size);
23   *buffer = static_cast<char *>(*buffer) + cpy_size;
24 }
25 
DeserializeValue(void const ** buffer,size_t * buffer_size,void * value,size_t cpy_size)26 void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size) {
27   if (cpy_size > *buffer_size) {
28     MS_LOG(ERROR) << "invalid desirialize size, buffer size: " << *buffer_size << ", value size: " << cpy_size;
29     return;
30   }
31   memcpy(value, *buffer, cpy_size);
32   *buffer = static_cast<const char *>(*buffer) + cpy_size;
33   *buffer_size -= cpy_size;
34 }
35 
getOutputDimensions(int outputIndex,const nvinfer1::DimsExprs * inputs,int nbInputs,nvinfer1::IExprBuilder & exprBuilder)36 nvinfer1::DimsExprs TensorRTPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
37                                                         int nbInputs, nvinfer1::IExprBuilder &exprBuilder) noexcept {
38   return inputs[0];
39 }
40 
supportsFormatCombination(int pos,const nvinfer1::PluginTensorDesc * tensorsDesc,int nbInputs,int nbOutputs)41 bool TensorRTPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
42                                                int nbOutputs) noexcept {
43   return true;
44 }
45 
configurePlugin(const nvinfer1::DynamicPluginTensorDesc * in,int nbInputs,const nvinfer1::DynamicPluginTensorDesc * out,int nbOutputs)46 void TensorRTPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
47                                      const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
48 
getWorkspaceSize(const nvinfer1::PluginTensorDesc * inputs,int nbInputs,const nvinfer1::PluginTensorDesc * outputs,int nbOutputs) const49 size_t TensorRTPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
50                                         const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
51   return 0;
52 }
53 
getOutputDataType(int index,const nvinfer1::DataType * inputTypes,int nbInputs) const54 nvinfer1::DataType TensorRTPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
55                                                      int nbInputs) const noexcept {
56   return inputTypes[0];
57 }
58 
getPluginType() const59 const char *TensorRTPlugin::getPluginType() const noexcept { return plugin_name_.c_str(); }
60 
getPluginVersion() const61 const char *TensorRTPlugin::getPluginVersion() const noexcept { return plugin_version_.c_str(); }
62 
getNbOutputs() const63 int TensorRTPlugin::getNbOutputs() const noexcept { return 1; }
64 
initialize()65 int TensorRTPlugin::initialize() noexcept { return 0; }
66 
terminate()67 void TensorRTPlugin::terminate() noexcept {}
68 
getSerializationSize() const69 size_t TensorRTPlugin::getSerializationSize() const noexcept { return 0; }
70 
serialize(void * buffer) const71 void TensorRTPlugin::serialize(void *buffer) const noexcept {}
72 
destroy()73 void TensorRTPlugin::destroy() noexcept {
74   // This gets called when the network containing plugin is destroyed
75   delete this;
76 }
77 
setPluginNamespace(const char * libNamespace)78 void TensorRTPlugin::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
79 
getPluginNamespace() const80 const char *TensorRTPlugin::getPluginNamespace() const noexcept { return name_space_.c_str(); }
81 }  // namespace mindspore::lite
82