1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
18 #include <cstring>
19
20 namespace mindspore::lite {
// Appends `cpy_size` raw bytes taken from `value` at the write cursor `*buffer`,
// then moves `*buffer` forward by `cpy_size` bytes so that the next call writes
// directly behind the bytes stored here. No capacity check is performed; the
// caller is responsible for the destination being large enough.
void SerializeValue(void **buffer, const void *value, size_t cpy_size) {
  char *const write_pos = static_cast<char *>(*buffer);
  std::memcpy(write_pos, value, cpy_size);
  *buffer = write_pos + cpy_size;
}
25
DeserializeValue(void const ** buffer,size_t * buffer_size,void * value,size_t cpy_size)26 void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size) {
27 if (cpy_size > *buffer_size) {
28 MS_LOG(ERROR) << "invalid desirialize size, buffer size: " << *buffer_size << ", value size: " << cpy_size;
29 return;
30 }
31 memcpy(value, *buffer, cpy_size);
32 *buffer = static_cast<const char *>(*buffer) + cpy_size;
33 *buffer_size -= cpy_size;
34 }
35
// Default output-shape rule: the output takes the dimension expressions of the
// first input (`inputs[0]`) unchanged. `outputIndex`, `nbInputs` and
// `exprBuilder` are not consulted.
nvinfer1::DimsExprs TensorRTPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                                        int nbInputs, nvinfer1::IExprBuilder &exprBuilder) noexcept {
  const nvinfer1::DimsExprs &first_input_dims = *inputs;
  return first_input_dims;
}
40
supportsFormatCombination(int pos,const nvinfer1::PluginTensorDesc * tensorsDesc,int nbInputs,int nbOutputs)41 bool TensorRTPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
42 int nbOutputs) noexcept {
43 return true;
44 }
45
configurePlugin(const nvinfer1::DynamicPluginTensorDesc * in,int nbInputs,const nvinfer1::DynamicPluginTensorDesc * out,int nbOutputs)46 void TensorRTPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
47 const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
48
getWorkspaceSize(const nvinfer1::PluginTensorDesc * inputs,int nbInputs,const nvinfer1::PluginTensorDesc * outputs,int nbOutputs) const49 size_t TensorRTPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
50 const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
51 return 0;
52 }
53
// Default output-type rule: the output uses the data type of the first input
// (`inputTypes[0]`). `index` and `nbInputs` are not consulted.
nvinfer1::DataType TensorRTPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                     int nbInputs) const noexcept {
  const nvinfer1::DataType first_input_type = *inputTypes;
  return first_input_type;
}
58
getPluginType() const59 const char *TensorRTPlugin::getPluginType() const noexcept { return plugin_name_.c_str(); }
60
getPluginVersion() const61 const char *TensorRTPlugin::getPluginVersion() const noexcept { return plugin_version_.c_str(); }
62
getNbOutputs() const63 int TensorRTPlugin::getNbOutputs() const noexcept { return 1; }
64
initialize()65 int TensorRTPlugin::initialize() noexcept { return 0; }
66
terminate()67 void TensorRTPlugin::terminate() noexcept {}
68
getSerializationSize() const69 size_t TensorRTPlugin::getSerializationSize() const noexcept { return 0; }
70
serialize(void * buffer) const71 void TensorRTPlugin::serialize(void *buffer) const noexcept {}
72
destroy()73 void TensorRTPlugin::destroy() noexcept {
74 // This gets called when the network containing plugin is destroyed
75 delete this;
76 }
77
setPluginNamespace(const char * libNamespace)78 void TensorRTPlugin::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
79
getPluginNamespace() const80 const char *TensorRTPlugin::getPluginNamespace() const noexcept { return name_space_.c_str(); }
81 } // namespace mindspore::lite
82