1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ 18 19 #include <memory> 20 #include <unordered_map> 21 22 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" 23 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" 24 #include "tensorflow/core/platform/logging.h" 25 #include "tensorflow/core/platform/macros.h" 26 #include "tensorflow/core/platform/mutex.h" 27 28 #if GOOGLE_CUDA 29 #if GOOGLE_TENSORRT 30 #include "tensorrt/include/NvInfer.h" 31 32 namespace tensorflow { 33 namespace tensorrt { 34 35 class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { 36 public: 37 // TODO(aaroey): this static method has to be inlined to make the singleton a 38 // unique global symbol. Find a way to fix it. GetInstance()39 static PluginFactoryTensorRT* GetInstance() { 40 static PluginFactoryTensorRT* factory_instance = 41 new PluginFactoryTensorRT(); 42 return factory_instance; 43 } 44 45 // Deserialization method 46 PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, 47 size_t serial_length) override; 48 49 // Plugin construction, PluginFactoryTensorRT owns the plugin. 50 PluginTensorRT* CreatePlugin(const string& op_name); 51 52 bool RegisterPlugin(const string& op_name, 53 PluginDeserializeFunc deserialize_func, 54 PluginConstructFunc construct_func); 55 IsPlugin(const string & op_name)56 bool IsPlugin(const string& op_name) { 57 return plugin_registry_.find(op_name) != plugin_registry_.end(); 58 } 59 CountOwnedPlugins()60 size_t CountOwnedPlugins() { return owned_plugins_.size(); } 61 62 void DestroyPlugins(); 63 64 protected: 65 std::unordered_map<string, 66 std::pair<PluginDeserializeFunc, PluginConstructFunc>> 67 plugin_registry_; 68 69 // TODO(jie): Owned plugin should be associated with different sessions; 70 // should really hand ownership of plugins to resource management; 71 std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_; 72 mutex instance_m_; 73 }; 74 75 class TrtPluginRegistrar { 76 public: TrtPluginRegistrar(const string & name,PluginDeserializeFunc deserialize_func,PluginConstructFunc construct_func)77 TrtPluginRegistrar(const string& name, PluginDeserializeFunc deserialize_func, 78 PluginConstructFunc construct_func) { 79 auto factory = PluginFactoryTensorRT::GetInstance(); 80 QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func)) 81 << "Failed to register plugin: " << name; 82 } 83 }; 84 85 #define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ 86 REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \ 87 construct_func) 88 #define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \ 89 construct_func) \ 90 REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) 91 #define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ 92 static TrtPluginRegistrar trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED = \ 93 TrtPluginRegistrar(name, deserialize_func, construct_func) 94 95 } // namespace tensorrt 96 } // namespace tensorflow 97 98 #endif // GOOGLE_TENSORRT 99 #endif // GOOGLE_CUDA 100 101 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ 102