• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
17 
18 #include <sstream>
19 
20 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
21 #include "tensorflow/core/framework/device_base.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/platform/mutex.h"
25 
26 #if GOOGLE_CUDA && GOOGLE_TENSORRT
27 #include "third_party/tensorrt/NvInfer.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 
TerminateCalibration()32 string CalibrationContext::TerminateCalibration() {
33   mutex_lock l(mu_);
34   if (terminated_) return calibration_table_;
35 
36   TRTInt8Calibrator* raw_calibrator = calibrator_.get();
37   raw_calibrator->waitAndSetDone();
38   terminated_ = true;
39 
40   // At this point the calibration thread `thr_` is woken up and can
41   // transfer the ownership of `calibrator_` and `engine_` at any time, so
42   // it's not safe to use `calibrator_` below, but we can still access it
43   // using raw pointer.
44   // TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
45   // function of this class instead.
46 
47   thr_->join();
48   calibration_table_ = raw_calibrator->getCalibrationTableAsString();
49   return calibration_table_;
50 }
51 
52 const absl::string_view kTfTrtContainerName = "TF-TRT";
53 
GetLogger()54 Logger& TRTEngineCacheResource::GetLogger() {
55   static Logger* logger = new Logger();
56   return *logger;
57 }
58 
TRTEngineCacheResource(OpKernelContext * ctx,size_t capacity)59 TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
60                                                size_t capacity)
61     : cache_(capacity) {
62   auto device = ctx->device();
63   auto alloc = device->GetAllocator(AllocatorAttributes());
64   if (!alloc) {
65     LOG(ERROR) << "Can't find device allocator for gpu device "
66                << device->name();
67     allocator_ = nullptr;
68   } else {
69     allocator_.reset(new TRTDeviceAllocator(alloc));
70   }
71 }
72 
~TRTEngineCacheResource()73 TRTEngineCacheResource::~TRTEngineCacheResource() {
74   VLOG(1) << "Destroying TRTEngineCacheResource...";
75 }
76 
DebugString() const77 string TRTEngineCacheResource::DebugString() const {
78   std::stringstream oss;
79   using std::dec;
80   using std::endl;
81   using std::hex;
82   oss << "TRTEngineCacheResource: ";
83   oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
84   oss << "LRUCache = " << hex << &cache_ << dec << endl;
85   oss << "Containing " << cache_.size() << " entries: " << endl;
86   for (const auto& item : cache_) {
87     mutex_lock lock(item.second->mu);
88     oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
89         << "ICudaEngine: " << item.second->cuda_engine.get() << ", "
90         << "IExecutionContext: ";
91     for (auto& ctx : item.second->execution_context) {
92       oss << ctx.GetIExecutionContext() << ", ";
93     }
94     oss << dec << endl;
95   }
96   return oss.str();
97 }
98 
GetEngineContext(const std::vector<TensorShape> & input_shapes)99 EngineContext* TRTEngineCacheResource::GetEngineContext(
100     const std::vector<TensorShape>& input_shapes) {
101   EngineContext* engine_context = nullptr;
102   int64 min_matched_batch_size = kint64max;
103   for (const auto& pair : cache_) {
104     const std::vector<TensorShape>& cached_input_shapes = pair.first;
105     // This should not happen, but just for safety.
106     if (input_shapes.size() != cached_input_shapes.size()) {
107       LOG(ERROR) << "Input shape list size mismatch"
108                  << ", cached size: " << cached_input_shapes.size()
109                  << " vs. input size: " << input_shapes.size();
110     }
111     if (AreShapesCompatible(input_shapes, cached_input_shapes)) {
112       const int cached_batch_size = cached_input_shapes[0].dim_size(0);
113       if (min_matched_batch_size > cached_batch_size) {
114         min_matched_batch_size = cached_batch_size;
115         engine_context = pair.second.get();
116       }
117     }
118   }
119   return engine_context;
120 }
121 
GetEngineContext(const int profile_id)122 EngineContext* TRTEngineCacheResource::GetEngineContext(const int profile_id) {
123   if (profiles_.NeedProfiles() && profile_id >= profiles_.GetNumProfiles()) {
124     LOG(ERROR) << "Out of range: profile_id " << profile_id
125                << " is larger than number of profiles "
126                << profiles_.GetNumProfiles();
127     return nullptr;
128   }
129   if (cache_.size() > 1) {
130     LOG(ERROR) << "Cache is expected to have at most "
131                << "1 engine in explicit batch mode where profiles are used.";
132     return nullptr;
133   }
134   if (cache_.size() == 0) {
135     return nullptr;
136   }
137   return cache_.begin()->second.get();
138 }
139 
140 }  // namespace tensorrt
141 }  // namespace tensorflow
142 
143 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
144