• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_
18 
19 #include <iostream>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "tensorflow/core/platform/types.h"
24 
25 #if GOOGLE_CUDA
26 #if GOOGLE_TENSORRT
27 #include "tensorrt/include/NvInfer.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 
32 // A wrapper class for TensorRT plugin
33 // User application should inherit from this class to write custom kernels.
34 // Allows user to insert custom op in TensorRT engine
35 // To register plugin in converter, user should also register custom
36 // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT
37 class PluginTensorRT : public nvinfer1::IPlugin {
38  public:
PluginTensorRT()39   PluginTensorRT() {}
40   PluginTensorRT(const void* serialized_data, size_t length);
41 
42   virtual const string& GetPluginName() const = 0;
43 
44   virtual bool Finalize() = 0;
45 
46   virtual bool SetAttribute(const string& key, const void* ptr,
47                             const size_t size) = 0;
48   virtual bool GetAttribute(const string& key, const void** ptr,
49                             size_t* size) const = 0;
50 
51   void configure(const nvinfer1::Dims* inputs, int num_inputs,
52                  const nvinfer1::Dims* outputs, int num_outputs,
53                  int max_batch_size) override;
54 
55   virtual bool StoreAttribute(const string& key, const void* ptr,
56                               const size_t size);
57 
58   size_t getSerializationSize() override;
59 
60   void serialize(void* buffer) override;
61 
62  protected:
63   std::unordered_map<string, std::vector<char> > attr_map_;
64 
65   std::vector<nvinfer1::Dims> input_dim_list_;
66 };
67 
68 }  // namespace tensorrt
69 }  // namespace tensorflow
70 
71 #endif  // GOOGLE_TENSORRT
72 #endif  // GOOGLE_CUDA
73 
74 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_
75