/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_

#include <algorithm>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
30 
31 namespace tensorflow {
32 namespace tensorrt {
33 
34 template <class Key, class Value, class HashFunction>
35 class LRUCache {
36  public:
37   typedef Value value_type;
38   typedef Key key_type;
39   typedef HashFunction hasher;
40   typedef typename std::unordered_map<key_type, value_type, hasher> map_type;
41   typedef typename map_type::iterator iterator;
42   typedef typename map_type::const_iterator const_iterator;
43 
LRUCache()44   LRUCache() : capacity_(0) {}
LRUCache(size_t capacity)45   explicit LRUCache(size_t capacity) : capacity_(capacity) {}
46 
capacity()47   size_t capacity() const { return capacity_; }
48 
reserve(size_t capacity)49   void reserve(size_t capacity) {
50     capacity_ = capacity;
51     DiscardOld();
52   }
53 
size()54   size_t size() const { return objects_.size(); }
55 
count(const key_type & key)56   size_t count(const key_type& key) const { return objects_.count(key); }
57 
at(const key_type & key)58   value_type& at(const key_type& key) { return Touch(key); }
59 
begin()60   const_iterator begin() const { return objects_.begin(); }
end()61   const_iterator end() const { return objects_.end(); }
62 
begin()63   iterator begin() { return objects_.begin(); }
end()64   iterator end() { return objects_.end(); }
65 
66   template <typename... Args>
emplace(Args &&...args)67   std::pair<iterator, bool> emplace(Args&&... args) {
68     DiscardOld(1);
69     std::pair<iterator, bool> result =
70         objects_.emplace(std::forward<Args>(args)...);
71     key_type key = result.first->first;
72     if (result.second) {
73       keys_.push_front(key);
74     } else {
75       TouchNoCheck(key);  // The key must exist in this case.
76     }
77     return result;
78   }
79 
80  private:
81   std::unordered_map<key_type, value_type, hasher> objects_;
82   std::list<key_type> keys_;
83   size_t capacity_;
84   value_type not_found_value_;
85 
Touch(const key_type & key)86   value_type& Touch(const key_type& key) {
87     // Check that the key exists, and let it return std::out_of_range error if
88     // not.
89     value_type& value = objects_.at(key);
90     TouchNoCheck(key);
91     return value;
92   }
93 
TouchNoCheck(const key_type & key)94   void TouchNoCheck(const key_type& key) {
95     auto rank = std::find(keys_.begin(), keys_.end(), key);
96     if (rank != keys_.begin()) {
97       keys_.erase(rank);
98       keys_.push_front(key);
99     }
100   }
101 
102   // Creates n free positions in cache
103   Status DiscardOld(size_t n = 0) {
104     if (n > capacity_) {
105       return errors::Internal("Insufficient capacity in cache (capacity = ",
106                               capacity_, ", requested ", n, ")");
107     }
108     while (objects_.size() > (capacity_ - n)) {
109       key_type discard_key = keys_.back();
110       keys_.pop_back();
111       objects_.erase(discard_key);
112     }
113     return Status::OK();
114   }
115 };
116 
117 // Define a hash function for vector<TensorShape> because it is used as the key
118 // for the engine cache.
119 struct VectorTensorShapeHasher {
operatorVectorTensorShapeHasher120   std::size_t operator()(const std::vector<TensorShape>& key) const {
121     return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
122   }
123 };
124 
125 #if GOOGLE_CUDA
126 #if GOOGLE_TENSORRT
127 
128 struct EngineContext {
EngineContextEngineContext129   EngineContext() {}  // Creates an empty context.
EngineContextEngineContext130   EngineContext(
131       TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
132       TrtUniquePtrType<nvinfer1::IExecutionContext>&& input_execution_context)
133       : cuda_engine(std::move(input_cuda_engine)),
134         execution_context(std::move(input_execution_context)) {}
135 
136   mutex mu;
137   TrtUniquePtrType<nvinfer1::ICudaEngine> cuda_engine;
138   TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context
139       GUARDED_BY(mu);
140 };
141 
142 class TRTEngineCacheResource : public ResourceBase {
143  public:
TRTEngineCacheResource(OpKernelContext * ctx,size_t capacity)144   TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
145       : cache_(capacity) {
146     auto device = ctx->device();
147     auto alloc = device->GetAllocator(AllocatorAttributes());
148     if (!alloc) {
149       LOG(ERROR) << "Can't find device allocator for gpu device "
150                  << device->name();
151       allocator_ = nullptr;
152     } else {
153       allocator_.reset(new TRTDeviceAllocator(alloc));
154     }
155   }
156 
DebugString()157   string DebugString() const override {
158     std::stringstream oss;
159     using std::dec;
160     using std::endl;
161     using std::hex;
162     oss << "TRTEngineCacheResource: ";
163     oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
164     oss << "LRUCache = " << hex << &cache_ << dec << endl;
165     oss << "Containing " << cache_.size() << " entries: " << endl;
166     for (const auto& item : cache_) {
167       oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
168           << "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
169           << "IExecutionContext: " << item.second.get()->execution_context.get()
170           << dec << endl;
171     }
172     return oss.str();
173   }
174 
175   // Keep device allocator for TRT.
176   std::unique_ptr<TRTBaseAllocator> allocator_;
177 
178   // Declare cache after allocator so that it is destroyed before allocator is.
179   LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
180            VectorTensorShapeHasher>
181       cache_;
182 };
183 
184 #endif  // GOOGLE_TENSORRT
185 #endif  // GOOGLE_CUDA
186 
187 }  // namespace tensorrt
188 }  // namespace tensorflow
189 
190 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_
191