/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "absl/base/call_once.h"
#include "absl/strings/str_join.h"
#include "third_party/tensorrt/NvInferPlugin.h"
#endif

namespace tensorflow {
namespace tensorrt {

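// Returns the TensorRT version this build was compiled against, taken from
// the NV_TENSORRT_* macros, or {0, 0, 0} when TensorRT support is disabled.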
std::tuple<int, int, int> GetLinkedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
                                   NV_TENSORRT_PATCH};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

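// Returns the version of the TensorRT library loaded at runtime. The integer
// reported by getInferLibVersion() is decoded as
// major * 1000 + minor * 100 + patch; {0, 0, 0} is returned when TensorRT
// support is disabled.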
std::tuple<int, int, int> GetLoadedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  int ver = getInferLibVersion();
  int major = ver / 1000;
  ver = ver - major * 1000;
  int minor = ver / 100;
  int patch = ver - minor * 100;
  return std::tuple<int, int, int>{major, minor, patch};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

}  // namespace tensorrt
}  // namespace tensorflow

#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
namespace {

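// Logs the linked and loaded TensorRT versions, registers the standard
// TensorRT plugins via initLibNvInferPlugins(), and reports the plugin
// creators found in the global plugin registry.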
void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
  LOG(INFO) << "Linked TensorRT version: "
            << absl::StrJoin(GetLinkedTensorRTVersion(), ".");
  LOG(INFO) << "Loaded TensorRT version: "
            << absl::StrJoin(GetLoadedTensorRTVersion(), ".");

  bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
  if (!plugin_initialized) {
    LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
                  "fail later.";
  }

  int num_trt_plugins = 0;
  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
  if (!trt_plugin_creator_list) {
    LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry.";
  } else {
    VLOG(1) << "Found the following " << num_trt_plugins
            << " TensorRT plugins in registry:";
    for (int i = 0; i < num_trt_plugins; ++i) {
      if (!trt_plugin_creator_list[i]) {
        LOG_WARNING_WITH_PREFIX
            << "TensorRT plugin at index " << i
            << " is not accessible (null pointer returned by "
               "getPluginCreatorList for this plugin)";
      } else {
        VLOG(1) << "  " << trt_plugin_creator_list[i]->getPluginName();
      }
    }
  }
}

}  // namespace

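// Thread-safe wrapper that runs InitializeTrtPlugins() at most once per
// process, guarded by absl::call_once.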
void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
  static absl::once_flag once;
  absl::call_once(once, InitializeTrtPlugins, trt_logger);
}

}  // namespace tensorrt
}  // namespace tensorflow
#endif