/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"

#include <sstream>

#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

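// Marks calibration as done, joins the calibration thread `thr_`, and
// returns the serialized INT8 calibration table. Subsequent calls return
// the cached table without redoing any work.
//
// Illustrative use (hypothetical caller):
//   string table = calib_ctx->TerminateCalibration();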
string CalibrationContext::TerminateCalibration() {
  mutex_lock l(mu_);
  if (terminated_) return calibration_table_;

  TRTInt8Calibrator* raw_calibrator = calibrator_.get();
  raw_calibrator->waitAndSetDone();
  terminated_ = true;

  // At this point the calibration thread `thr_` is woken up and can
  // transfer the ownership of `calibrator_` and `engine_` at any time, so
  // it's not safe to use `calibrator_` below, but we can still access it
  // using the raw pointer.
  // TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
  // function of this class instead.

  thr_->join();
  calibration_table_ = raw_calibrator->getCalibrationTableAsString();
  return calibration_table_;
}

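// Name of the ResourceMgr container under which TF-TRT resources (such as
// this engine cache) are stored.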
const absl::string_view kTfTrtContainerName = "TF-TRT";

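// Returns the process-wide TensorRT logger. The Logger is allocated once
// and intentionally never deleted, which avoids problems with static
// destruction order at process exit.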
Logger& TRTEngineCacheResource::GetLogger() {
  static Logger* logger = new Logger();
  return *logger;
}

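// Builds an engine cache holding at most `capacity` entries, and wraps the
// device allocator of `ctx` in a TRTDeviceAllocator so that TensorRT
// allocates GPU memory through TensorFlow. `allocator_` is left null when
// the device reports no allocator.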
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
                                               size_t capacity)
    : cache_(capacity) {
  auto device = ctx->device();
  auto alloc = device->GetAllocator(AllocatorAttributes());
  if (!alloc) {
    LOG(ERROR) << "Can't find device allocator for GPU device "
               << device->name();
    allocator_ = nullptr;
  } else {
    allocator_.reset(new TRTDeviceAllocator(alloc));
  }
}

TRTEngineCacheResource::~TRTEngineCacheResource() {
  VLOG(1) << "Destroying TRTEngineCacheResource...";
}

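// Produces a human-readable dump of the cache for logging: the allocator
// and cache addresses, followed by one line per entry listing the cached
// input shapes and the addresses of the ICudaEngine and its
// IExecutionContexts.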
string TRTEngineCacheResource::DebugString() const {
  std::stringstream oss;
  using std::dec;
  using std::endl;
  using std::hex;
  oss << "TRTEngineCacheResource: ";
  oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
  oss << "LRUCache = " << hex << &cache_ << dec << endl;
  oss << "Containing " << cache_.size() << " entries: " << endl;
  for (const auto& item : cache_) {
    mutex_lock lock(item.second->mu);
    oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
        << "ICudaEngine: " << item.second->cuda_engine.get() << ", "
        << "IExecutionContext: ";
    for (auto& ctx : item.second->execution_context) {
      oss << ctx.GetIExecutionContext() << ", ";
    }
    oss << dec << endl;
  }
  return oss.str();
}

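// Returns the cached EngineContext whose input shapes are compatible with
// `input_shapes`, or nullptr if none is. When several entries match, the
// one with the smallest batch size (dimension 0 of the first input) wins,
// which keeps the amount of input padding needed at runtime minimal.
//
// Illustrative use (hypothetical caller):
//   EngineContext* ec = cache_res->GetEngineContext(input_shapes);
//   if (ec == nullptr) {
//     // No compatible engine is cached; build one for these shapes.
//   }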
EngineContext* TRTEngineCacheResource::GetEngineContext(
    const std::vector<TensorShape>& input_shapes) {
  EngineContext* engine_context = nullptr;
  int64 min_matched_batch_size = kint64max;
  for (const auto& pair : cache_) {
    const std::vector<TensorShape>& cached_input_shapes = pair.first;
    // This should not happen, but just for safety.
    if (input_shapes.size() != cached_input_shapes.size()) {
      LOG(ERROR) << "Input shape list size mismatch"
                 << ", cached size: " << cached_input_shapes.size()
                 << " vs. input size: " << input_shapes.size();
    }
    if (AreShapesCompatible(input_shapes, cached_input_shapes)) {
      const int cached_batch_size = cached_input_shapes[0].dim_size(0);
      if (min_matched_batch_size > cached_batch_size) {
        min_matched_batch_size = cached_batch_size;
        engine_context = pair.second.get();
      }
    }
  }
  return engine_context;
}

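// Returns the EngineContext of the single cached engine used in explicit
// batch mode, where one engine covers all input shapes via optimization
// profiles and `profile_id` selects among them. Returns nullptr if the id
// is out of range or no engine has been cached yet.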
EngineContext* TRTEngineCacheResource::GetEngineContext(const int profile_id) {
  if (profiles_.NeedProfiles() && profile_id >= profiles_.GetNumProfiles()) {
    LOG(ERROR) << "Out of range: profile_id " << profile_id
               << " must be smaller than the number of profiles "
               << profiles_.GetNumProfiles();
    return nullptr;
  }
  if (cache_.size() > 1) {
    LOG(ERROR) << "Cache is expected to have at most "
               << "1 engine in explicit batch mode where profiles are used.";
    return nullptr;
  }
  if (cache_.empty()) {
    return nullptr;
  }
  return cache_.begin()->second.get();
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT