/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"  // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;

43 class CreateTRTResourceHandle : public OpKernel {
44  public:
CreateTRTResourceHandle(OpKernelConstruction * ctx)45   explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
46     OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
47   }
48 
Compute(OpKernelContext * ctx)49   void Compute(OpKernelContext* ctx) override {
50     {
51       mutex_lock l(mutex_);
52       if (!initialized_) {
53         AllocatorAttributes attr;
54         attr.set_on_host(true);
55         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
56                                                &handle_, attr));
57 
58         VLOG(1) << "Creating TRT engine cache resource handle for op "
59                 << resource_name_ << " on device " << ctx->device()->name();
60         handle_.scalar<ResourceHandle>()() =
61             MakeResourceHandle<TRTEngineCacheResource>(
62                 ctx, std::string(kTfTrtContainerName), resource_name_);
63         initialized_ = true;
64       }
65     }
66     ctx->set_output(0, handle_);
67   }
68 
69  private:
70   string resource_name_;
71   Tensor handle_;
72   mutex mutex_;
73   bool initialized_ TF_GUARDED_BY(mutex_) = false;
74 
75   TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
76 };
77 
78 REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
79                             .Device(DEVICE_GPU)
80                             .HostMemory("resource_handle"),
81                         CreateTRTResourceHandle);
82 
83 class InitializeTRTResource : public OpKernel {
84  public:
InitializeTRTResource(OpKernelConstruction * ctx)85   explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
86     OP_REQUIRES_OK(
87         ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
88   }
89 
Compute(OpKernelContext * ctx)90   void Compute(OpKernelContext* ctx) override {
91     ResourceHandle handle = HandleFromInput(ctx, 0);
92     core::RefCountPtr<TRTEngineCacheResource> resource;
93     OP_REQUIRES_OK(
94         ctx, LookupOrCreateResource<TRTEngineCacheResource>(
95                  ctx, handle, &resource,
96                  [this, ctx](TRTEngineCacheResource** resource) -> Status {
97                    *resource = new TRTEngineCacheResource(
98                        ctx, this->max_cached_engines_);
99                    return Status::OK();
100                  }));
101 
102     auto allocator = resource->allocator_.get();
103     OP_REQUIRES(ctx, allocator != nullptr,
104                 errors::Internal("Not able to initialize TRT engine cache when "
105                                  "GPU allocator is empty."));
106     OP_REQUIRES(ctx, resource->cache_.size() == 0,
107                 errors::Internal("Expect engine cache to be empty, but got ",
108                                  resource->cache_.size(), " entries."));
109 
110     // Get the file name.
111     const string& filename = ctx->input(1).scalar<tstring>()();
112     OP_REQUIRES(ctx, !filename.empty(),
113                 errors::InvalidArgument("filename cannot be empty."));
114 
115     // Parse the serialized engines and add them to the cache.
116     std::unique_ptr<RandomAccessFile> file;
117     OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
118     auto reader = std::make_unique<io::RecordReader>(file.get());
119 
120     uint64 offset = 0;
121     int num_loaded_engine = 0;
122     do {
123       tstring record;
124       Status status = reader->ReadRecord(&offset, &record);
125       if (errors::IsOutOfRange(status)) break;
126 
127       TRTEngineInstance engine_instance;
128       engine_instance.ParseFromString(record);
129       std::vector<TensorShape> engine_input_shapes;
130       const auto& input_shapes = engine_instance.input_shapes();
131       engine_input_shapes.reserve(input_shapes.size());
132       for (const TensorShapeProto& shape : input_shapes) {
133         engine_input_shapes.emplace_back(shape);
134       }
135 
136       TrtUniquePtrType<IRuntime> infer(
137           nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
138       infer->setGpuAllocator(allocator);
139       TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
140           infer->deserializeCudaEngine(
141               engine_instance.serialized_engine().c_str(),
142               engine_instance.serialized_engine().size(), nullptr));
143       auto raw_engine = engine.get();
144       std::vector<ExecutionContext> ctx_vec;
145       if (num_loaded_engine == 0) {
146         // Restore profiles if there are any. Currently only 1 engine is allowed
147         // in dynamic mode therefore we call this only for the 0th engine.
148         // it is a no-op in implicit batch mode.
149         OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(
150                                 raw_engine, engine_input_shapes.size()));
151         OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
152                                 raw_engine, &ctx_vec));
153       } else {
154         // Multiple engines are only available in static mode. For each engine
155         // we have only a single execution context.
156         ctx_vec.push_back(ExecutionContext::Create(raw_engine));
157       }
158       resource->cache_.emplace(engine_input_shapes,
159                                std::make_unique<EngineContext>(
160                                    std::move(engine), std::move(ctx_vec)));
161       ++num_loaded_engine;
162     } while (1);
163     VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "
164             << handle.name() << " on device " << ctx->device()->name()
165             << " from file " << filename;
166   }
167 
168  private:
169   // Maximum number of cached engines
170   int max_cached_engines_;
171 
172   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
173 };
174 
175 REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
176                             .Device(DEVICE_GPU)
177                             .HostMemory("resource_handle"),
178                         InitializeTRTResource);
179 
180 class SerializeTRTResource : public OpKernel {
181  public:
SerializeTRTResource(OpKernelConstruction * ctx)182   explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
183     OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
184     OP_REQUIRES_OK(ctx, ctx->GetAttr("save_gpu_specific_engines",
185                                      &save_gpu_specific_engines_));
186   }
187 
Compute(OpKernelContext * ctx)188   void Compute(OpKernelContext* ctx) override {
189     const string& resource_name = ctx->input(0).scalar<tstring>()();
190     const string& filename = ctx->input(1).scalar<tstring>()();
191     OP_REQUIRES(ctx, !filename.empty(),
192                 errors::InvalidArgument("filename cannot be empty."));
193 
194     // Lookup engine cache resource.
195     TRTEngineCacheResource* resource = nullptr;
196     OP_REQUIRES(
197         ctx,
198         ctx->resource_manager()
199             ->Lookup(std::string(kTfTrtContainerName), resource_name, &resource)
200             .ok(),
201         errors::NotFound("TRTEngineCacheResource not yet created"));
202     core::ScopedUnref unref_me(resource);
203 
204     // Terminate the calibration if any.
205     if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
206 
207     // Serialize the engines and write them to file.
208     std::unique_ptr<WritableFile> file;
209     OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
210     auto writer = std::make_unique<io::RecordWriter>(file.get());
211 
212     int num_serialized_engines = 0;
213     if (save_gpu_specific_engines_) {
214       // If user requests TRT engines export, recursively create
215       // requisite directories.
216       const char* export_trt_engines_env =
217           getenv("TF_TRT_EXPORT_TRT_ENGINES_PATH");
218       if (export_trt_engines_env) {
219         VLOG(1) << "Exporting TRT engines to directory: "
220                 << export_trt_engines_env;
221         OP_REQUIRES_OK(
222             ctx, ctx->env()->RecursivelyCreateDir(export_trt_engines_env));
223       }
224 
225       for (const auto& pair : resource->cache_) {
226         // Ignore engines that failed to build.
227         const std::unique_ptr<EngineContext>& engine = pair.second;
228         if (!engine || !engine->GetCudaEngine()) continue;
229 
230         TRTEngineInstance engine_instance;
231         // Add input shapes.
232         const std::vector<TensorShape>& engine_input_shapes = pair.first;
233         for (const TensorShape& shape : engine_input_shapes) {
234           shape.AsProto(engine_instance.add_input_shapes());
235         }
236         // Add the serialized engine.
237         TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
238             engine->GetCudaEngine()->serialize());
239         engine_instance.set_serialized_engine(engine_data->data(),
240                                               engine_data->size());
241 
242         if (export_trt_engines_env) {
243           const std::string engine_filename =
244               std::string(export_trt_engines_env) + "/" + resource_name;
245           std::unique_ptr<WritableFile> engine_file;
246           OP_REQUIRES_OK(
247               ctx, ctx->env()->NewWritableFile(engine_filename, &engine_file));
248           OP_REQUIRES_OK(ctx, engine_file->Append(StringPiece(
249                                   static_cast<char*>(engine_data->data()),
250                                   engine_data->size())));
251 
252           const std::string dims_filename =
253               std::string(export_trt_engines_env) + "/dims-" + resource_name;
254           std::unique_ptr<WritableFile> dims_file;
255           OP_REQUIRES_OK(
256               ctx, ctx->env()->NewWritableFile(dims_filename, &dims_file));
257 
258           for (const TensorShape& shape : engine_input_shapes) {
259             OP_REQUIRES_OK(ctx,
260                            dims_file->Append(StringPiece(shape.DebugString())));
261           }
262         }
263 
264         OP_REQUIRES_OK(
265             ctx, writer->WriteRecord(engine_instance.SerializeAsString()));
266         ++num_serialized_engines;
267       }
268     } else {
269       VLOG(1) << "TRT Engines are not serialized for op: " << resource_name;
270     }
271     VLOG(1) << "Serialized " << num_serialized_engines << " TRT engines for op "
272             << resource_name << " on device " << ctx->device()->name()
273             << " to file " << filename;
274 
275     if (delete_resource_) {
276       VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name
277               << " on device " << ctx->device()->name();
278       OP_REQUIRES_OK(ctx,
279                      ctx->resource_manager()->Delete<TRTEngineCacheResource>(
280                          std::string(kTfTrtContainerName), resource_name));
281     }
282   }
283 
284  private:
285   bool delete_resource_ = false;
286   bool save_gpu_specific_engines_ = true;
287 
288   TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
289 };
290 
291 REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
292                         SerializeTRTResource);
293 
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT