• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h"
17 
18 #include <cassert>
19 #include <cstring>
20 
21 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h"
22 
23 #if GOOGLE_CUDA
24 #if GOOGLE_TENSORRT
25 
26 namespace tensorflow {
27 namespace tensorrt {
28 
PluginTensorRT(const void * serialized_data,size_t length)29 PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) {
30   const char* buffer = static_cast<const char*>(serialized_data);
31   size_t op_name_char_count = *reinterpret_cast<const size_t*>(buffer);
32   buffer += sizeof(size_t);
33   buffer += op_name_char_count;
34 
35   size_t count = *reinterpret_cast<const size_t*>(buffer);
36   buffer += sizeof(size_t);
37 
38   for (int i = 0; i < count; i++) {
39     nvinfer1::Dims dim;
40     std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims));
41     buffer += sizeof(dim.nbDims);
42     std::memcpy(dim.d, buffer, sizeof(dim.d));
43     buffer += sizeof(dim.d);
44     std::memcpy(dim.type, buffer, sizeof(dim.type));
45     buffer += sizeof(dim.type);
46     input_dim_list_.emplace_back(dim);
47   }
48 }
49 
configure(const nvinfer1::Dims * inputs,int num_inputs,const nvinfer1::Dims * outputs,int num_outputs,int max_batch_size)50 void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs,
51                                const nvinfer1::Dims* outputs, int num_outputs,
52                                int max_batch_size) {
53   for (int index = 0; index < num_inputs; index++) {
54     nvinfer1::Dims dim;
55     dim.nbDims = inputs[index].nbDims;
56     for (int i = 0; i < dim.nbDims; i++) {
57       dim.d[i] = inputs[index].d[i];
58       dim.type[i] = inputs[index].type[i];
59     }
60     input_dim_list_.emplace_back(dim);
61   }
62 }
63 
getSerializationSize()64 size_t PluginTensorRT::getSerializationSize() {
65   nvinfer1::Dims dim;
66   return sizeof(size_t) + GetPluginName().size() +
67          sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + sizeof(dim.d) +
68          sizeof(dim.type);
69 }
70 
serialize(void * serialized_data)71 void PluginTensorRT::serialize(void* serialized_data) {
72   size_t op_name_size = GetPluginName().size();
73   char* buffer = static_cast<char*>(serialized_data);
74   std::memcpy(buffer, &op_name_size, sizeof(size_t));
75   buffer += sizeof(size_t);
76 
77   std::memcpy(buffer, GetPluginName().data(), op_name_size);
78   buffer += op_name_size;
79 
80   auto list_size = input_dim_list_.size();
81   std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size()));
82   buffer += sizeof(input_dim_list_.size());
83 
84   for (int i = 0; i < input_dim_list_.size(); i++) {
85     auto dim = input_dim_list_[i];
86     std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims));
87     buffer += sizeof(dim.nbDims);
88     std::memcpy(buffer, dim.d, sizeof(dim.d));
89     buffer += sizeof(dim.d);
90     std::memcpy(buffer, dim.type, sizeof(dim.type));
91     buffer += sizeof(dim.type);
92   }
93 }
94 
StoreAttribute(const string & key,const void * ptr,const size_t size)95 bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr,
96                                     const size_t size) {
97   if (attr_map_.count(key) != 0) return false;
98 
99   attr_map_.emplace(key, std::vector<char>(size));
100   std::memcpy(attr_map_[key].data(), ptr, size);
101   return true;
102 }
103 
104 }  // namespace tensorrt
105 }  // namespace tensorflow
106 
107 #endif  // GOOGLE_TENSORRT
108 #endif  // GOOGLE_CUDA
109