/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

template <class Key, class Value, class HashFunction>
class LRUCache {
 public:
  typedef Value value_type;
  typedef Key key_type;
  typedef HashFunction hasher;
  typedef typename std::unordered_map<key_type, value_type, hasher> map_type;
  typedef typename map_type::iterator iterator;
  typedef typename map_type::const_iterator const_iterator;

  LRUCache() : capacity_(0) {}
  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  size_t capacity() const { return capacity_; }

  void reserve(size_t capacity) {
    capacity_ = capacity;
    DiscardOld();
  }

  size_t size() const { return objects_.size(); }

  size_t count(const key_type& key) const { return objects_.count(key); }

  value_type& at(const key_type& key) { return Touch(key); }

  const_iterator begin() const { return objects_.begin(); }
  const_iterator end() const { return objects_.end(); }

  iterator begin() { return objects_.begin(); }
  iterator end() { return objects_.end(); }

  template <typename... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    DiscardOld(1);
    std::pair<iterator, bool> result =
        objects_.emplace(std::forward<Args>(args)...);
    key_type key = result.first->first;
    if (result.second) {
      keys_.push_front(key);
    } else {
      TouchNoCheck(key);  // The key must exist in this case.
    }
    return result;
  }

 private:
  std::unordered_map<key_type, value_type, hasher> objects_;
  std::list<key_type> keys_;
  size_t capacity_;
  value_type not_found_value_;

  value_type& Touch(const key_type& key) {
    // Check that the key exists, and let at() throw std::out_of_range if it
    // does not.
    value_type& value = objects_.at(key);
    TouchNoCheck(key);
    return value;
  }

  void TouchNoCheck(const key_type& key) {
    auto rank = std::find(keys_.begin(), keys_.end(), key);
    if (rank != keys_.begin()) {
      keys_.erase(rank);
      keys_.push_front(key);
    }
  }

  // Creates n free positions in the cache by evicting the least recently used
  // entries.
  void DiscardOld(size_t n = 0) {
    DCHECK(capacity_ >= n) << "Insufficient capacity in cache (capacity = "
                           << capacity_ << ", requested " << n << ")";
    while (objects_.size() > (capacity_ - n)) {
      key_type discard_key = keys_.back();
      keys_.pop_back();
      objects_.erase(discard_key);
    }
  }
};
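
// Example usage (an illustrative sketch only; `MyKey`, `MyValue`, and `MyHash`
// are hypothetical placeholder types, not part of this header):
//
//   LRUCache<MyKey, MyValue, MyHash> cache(/*capacity=*/4);
//   cache.emplace(key, value);         // Inserts; marks `key` most recent.
//   if (cache.count(key) > 0) {
//     MyValue& found = cache.at(key);  // Refreshes recency; throws if absent.
//   }
//   cache.reserve(2);                  // Shrinks capacity; evicts LRU entries.
//
// Note that a default-constructed cache has capacity 0, so reserve() must be
// called before emplace().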

#if GOOGLE_CUDA && GOOGLE_TENSORRT

struct EngineContext {
  EngineContext() {}  // Creates an empty context.
  EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
                ExecutionContext&& input_execution_context)
      : cuda_engine(std::move(input_cuda_engine)) {
    execution_context.push_back(std::move(input_execution_context));
  }
  EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
                std::vector<ExecutionContext>&& input_execution_context)
      : cuda_engine(std::move(input_cuda_engine)),
        execution_context(std::move(input_execution_context)) {}

  mutex mu;
  TrtUniquePtrType<nvinfer1::ICudaEngine> cuda_engine;

  Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
    if (idx >= execution_context.size()) {
      return errors::Internal("Requested engine context with index ", idx,
                              ", but only ", execution_context.size(),
                              " contexts are present.");
    }
    *exec_ctx = execution_context[idx];
    return Status::OK();
  }

  // In explicit batch mode, we maintain a vector of contexts for each engine,
  // where each context is created for a specific profile. This is because it
  // is either not possible or non-trivial to change the profile of a context
  // for the following reasons:
  // - In TRT 6 it is not possible to switch a profile after it is set:
  //   https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-601/tensorrt-api/c_api/classnvinfer1_1_1_i_execution_context.html#aba0731b9fbc926c477010df818650b0a
  // - To switch profiles (from TRT 7), one must first ensure that all
  //   inference calls in that context are finished. This would require an
  //   additional synchronization before we call setOptimizationProfile. To
  //   avoid this extra sync call, we maintain a separate execution context for
  //   each profile.
  // An IExecutionContext object is not thread safe: only one thread should use
  // it for inference at a time, therefore we need a mutex. More details at
  // https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#thread-safety
  // Additional discussion about execution context management and thread safety
  // is at https://github.com/tensorflow/tensorflow/issues/36959
  std::vector<ExecutionContext> execution_context TF_GUARDED_BY(mu);
};
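
// An illustrative sketch only (not part of this header): since
// GetExecutionContext is annotated with TF_EXCLUSIVE_LOCKS_REQUIRED(mu), a
// caller is expected to hold `mu` while fetching and using a context, e.g.
//
//   EngineContext* engine_context = ...;  // Hypothetical lookup.
//   nvinfer1::IExecutionContext* exec_ctx = nullptr;
//   {
//     mutex_lock lock(engine_context->mu);
//     TF_RETURN_IF_ERROR(
//         engine_context->GetExecutionContext(profile_id, &exec_ctx));
//     // Enqueue inference on exec_ctx while the lock is held.
//   }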

// Contains the context required to build the calibration data.
class CalibrationContext {
 public:
  string TerminateCalibration();

  // Lookup table for temporary staging areas of input tensors for calibration.
  std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;

  // Temporary staging areas for calibration inputs.
  std::vector<PersistentTensor> device_tensors_;

  std::unique_ptr<TRTInt8Calibrator> calibrator_;
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
  // TODO(sami): Use threadpool threads!
  std::unique_ptr<std::thread> thr_;

 private:
  mutex mu_;
  bool terminated_ TF_GUARDED_BY(mu_) = false;
  std::string calibration_table_ TF_GUARDED_BY(mu_);
};

ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;

class TRTEngineCacheResource : public ResourceBase {
 public:
  // According to the TensorRT API, the logger is considered a singleton by the
  // TensorRT library, and multiple instances of IRuntime and/or IBuilder must
  // all use the same logger. So here we make it a singleton.
  //
  // TODO(laigd): use this logger in all places where conversion happens.
  static Logger& GetLogger();

  TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);

  ~TRTEngineCacheResource() override;

  string DebugString() const override;

  // Returns the EngineContext that is compatible with input_shapes.
  // Returns nullptr if no compatible EngineContext is found in the cache.
  EngineContext* GetEngineContext(const std::vector<TensorShape>& input_shapes);

  // Returns the EngineContext that is compatible with profile_id.
  // This function should only be called in explicit batch mode, where the
  // cache size is expected to be at most one.
  // Returns nullptr if no compatible EngineContext is found in the cache.
  EngineContext* GetEngineContext(const int profile_id);

  // Keep device allocator for TRT.
  std::unique_ptr<TRTBaseAllocator> allocator_;

  // Declare the cache after the allocator so that it is destroyed before the
  // allocator is.
  LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
           VectorTensorShapeHasher>
      cache_;

  // TODO(hinsu): Use different calibration context for the available shapes
  // and attach it to each item of the cache.
  std::unique_ptr<CalibrationContext> calib_ctx_;

  // This object maintains all the optimization profiles during profile
  // generation and engine build. At runtime the list of profiles is used to
  // look up a matching profile for the input data.
  TrtShapeOptimizationProfile profiles_;
};
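
// An illustrative sketch only (not part of this header; `ctx`, `resource_name`,
// and `input_shapes` are hypothetical): a TF-TRT op kernel would typically look
// up the cache resource from the resource manager and then query it for a
// compatible engine:
//
//   TRTEngineCacheResource* cache_res = nullptr;
//   TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(
//       std::string(kTfTrtContainerName), resource_name, &cache_res));
//   core::ScopedUnref unref(cache_res);
//   EngineContext* engine_context = cache_res->GetEngineContext(input_shapes);
//   if (engine_context == nullptr) {
//     // No compatible engine in the cache; fall back or build a new engine.
//   }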

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_