1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ 18 19 #include <iostream> 20 #include <unordered_map> 21 #include <vector> 22 23 #include "tensorflow/core/platform/types.h" 24 25 #if GOOGLE_CUDA 26 #if GOOGLE_TENSORRT 27 #include "tensorrt/include/NvInfer.h" 28 29 namespace tensorflow { 30 namespace tensorrt { 31 32 // A wrapper class for TensorRT plugin 33 // User application should inherit from this class to write custom kernels. 34 // Allows user to insert custom op in TensorRT engine 35 // To register plugin in converter, user should also register custom 36 // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT 37 class PluginTensorRT : public nvinfer1::IPlugin { 38 public: PluginTensorRT()39 PluginTensorRT() {} 40 PluginTensorRT(const void* serialized_data, size_t length); 41 42 virtual const string& GetPluginName() const = 0; 43 44 virtual bool Finalize() = 0; 45 46 virtual bool SetAttribute(const string& key, const void* ptr, 47 const size_t size) = 0; 48 virtual bool GetAttribute(const string& key, const void** ptr, 49 size_t* size) const = 0; 50 51 void configure(const nvinfer1::Dims* inputs, int num_inputs, 52 const nvinfer1::Dims* outputs, int num_outputs, 53 int max_batch_size) override; 54 55 virtual bool StoreAttribute(const string& key, const void* ptr, 56 const size_t size); 57 58 size_t getSerializationSize() override; 59 60 void serialize(void* buffer) override; 61 62 protected: 63 std::unordered_map<string, std::vector<char> > attr_map_; 64 65 std::vector<nvinfer1::Dims> input_dim_list_; 66 }; 67 68 } // namespace tensorrt 69 } // namespace tensorflow 70 71 #endif // GOOGLE_TENSORRT 72 #endif // GOOGLE_CUDA 73 74 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ 75