• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
18 
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
21 
22 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h"
23 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/platform/mutex.h"
27 
28 #if GOOGLE_CUDA
29 #if GOOGLE_TENSORRT
30 #include "tensorrt/include/NvInfer.h"
31 
32 namespace tensorflow {
33 namespace tensorrt {
34 
35 class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
36  public:
37   // TODO(aaroey): this static method has to be inlined to make the singleton a
38   // unique global symbol. Find a way to fix it.
GetInstance()39   static PluginFactoryTensorRT* GetInstance() {
40     static PluginFactoryTensorRT* factory_instance =
41         new PluginFactoryTensorRT();
42     return factory_instance;
43   }
44 
45   // Deserialization method
46   PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
47                                size_t serial_length) override;
48 
49   // Plugin construction, PluginFactoryTensorRT owns the plugin.
50   PluginTensorRT* CreatePlugin(const string& op_name);
51 
52   bool RegisterPlugin(const string& op_name,
53                       PluginDeserializeFunc deserialize_func,
54                       PluginConstructFunc construct_func);
55 
IsPlugin(const string & op_name)56   bool IsPlugin(const string& op_name) {
57     return plugin_registry_.find(op_name) != plugin_registry_.end();
58   }
59 
CountOwnedPlugins()60   size_t CountOwnedPlugins() { return owned_plugins_.size(); }
61 
62   void DestroyPlugins();
63 
64  protected:
65   std::unordered_map<string,
66                      std::pair<PluginDeserializeFunc, PluginConstructFunc>>
67       plugin_registry_;
68 
69   // TODO(jie): Owned plugin should be associated with different sessions;
70   //            should really hand ownership of plugins to resource management;
71   std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
72   mutex instance_m_;
73 };
74 
75 class TrtPluginRegistrar {
76  public:
TrtPluginRegistrar(const string & name,PluginDeserializeFunc deserialize_func,PluginConstructFunc construct_func)77   TrtPluginRegistrar(const string& name, PluginDeserializeFunc deserialize_func,
78                      PluginConstructFunc construct_func) {
79     auto factory = PluginFactoryTensorRT::GetInstance();
80     QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func))
81         << "Failed to register plugin: " << name;
82   }
83 };
84 
// Registers a TensorRT plugin op with PluginFactoryTensorRT when the enclosing
// translation unit is statically initialized, e.g.
//   REGISTER_TRT_PLUGIN("MyPluginOp", MyDeserializeFn, MyConstructFn);
// Each use defines a `static` TrtPluginRegistrar whose constructor performs
// the registration (and QCHECK-fails if the registration is rejected).
//
// The _UNIQ_HELPER level exists so that __COUNTER__ is macro-expanded before
// it reaches the ## paste in REGISTER_TRT_PLUGIN_UNIQ; pasting it directly
// would produce the literal identifier `trt_plugin_registrar__COUNTER__`.
// The registrar type is fully qualified because macros ignore namespaces:
// this lets the macro be used from any namespace, not only from inside
// tensorflow::tensorrt.
#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func)    \
  REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
                                  construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
                                        construct_func)              \
  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
  static ::tensorflow::tensorrt::TrtPluginRegistrar                           \
      trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED =                         \
          ::tensorflow::tensorrt::TrtPluginRegistrar(name, deserialize_func,  \
                                                     construct_func)
94 
95 }  // namespace tensorrt
96 }  // namespace tensorflow
97 
98 #endif  // GOOGLE_TENSORRT
99 #endif  // GOOGLE_CUDA
100 
101 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
102