/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;

// A helper class to call done() when destructed, for asynchronous execution.
// Its reference count allows the native segment and the TRT engine to execute
// concurrently while sharing a single done callback.
class AsyncHelper : public core::RefCounted {
 public:
  AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; }
  ~AsyncHelper() override { done_(); }

 private:
  AsyncOpKernel::DoneCallback done_;
};

// This op can construct a TRT engine on the fly, and if construction of the
// engine fails, it executes the equivalent subgraph as a native TensorFlow
// function.
class TRTEngineOp : public AsyncOpKernel {
 public:
  explicit TRTEngineOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
  // Executes calibration.
  void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);

  // Constructs a function handle for executing the native funcdef graph.
  Status ConstructFunctionHandle(OpKernelContext* ctx);

  // Executes the replaced native segment as a function op.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);

  // Executes the TensorRT engine. Returns whether we need to retry by running
  // the native segment.
  bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context);

  // Allocates the necessary resources for calibration.
  Status AllocateCalibrationResources(OpKernelContext* ctx,
                                      SerializableResourceBase** cr);

  // Gets the engine for the given input shapes.
  EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
                           OpKernelContext* ctx);

  // Finds the smallest batch size in cached_engine_batches_ that is >= the
  // actual input batch size, and returns the corresponding engine input
  // shapes. Returns false if no compatible cached batch size exists.
  bool GetCompatibleCachedEngine(
      const std::vector<TensorShape>& actual_input_shapes,
      std::vector<TensorShape>* engine_input_shapes);

  std::vector<string> input_nodes_;
  std::vector<string> output_nodes_;

  // Serialized protobuf segment or TRT engine, depending on static_engine_.
  string serialized_segment_;

  // Name of the function for TF native execution of the segment. If empty, it
  // means TF native execution is not allowed, and if the TRT engine fails to
  // run, an error will be returned.
  string funcdef_name_;

  // GraphDef representation of the segment.
  GraphDef segment_graph_;

  // Engine precision mode.
  TrtPrecisionMode precision_mode_;

  // Whether the engine was constructed during the conversion or needs to be
  // constructed from the protobuf segment.
  bool static_engine_;

  // Whether to calibrate the INT8 engine.
  bool calibration_mode_;

  // Batch sizes of the cached engines.
  std::vector<int> cached_engine_batches_;

  // Maximum number of cached engines.
  int max_cached_engines_;

  int64 workspace_size_;
  mutex engine_mutex_;
  FunctionLibraryRuntime::Handle native_func_;

  // The finalized calibrator for inference.
  std::unique_ptr<TRTInt8Calibrator> calibrator_;

  // If true, create a calibration graph for INT8 mode. Otherwise, use
  // user-provided quantization ranges.
  bool use_calibration_;
};

#define TYPECASE(dt, X)                                       \
  case dt: {                                                  \
    return (void*)X->flat<EnumToDataType<dt>::Type>().data(); \
  }

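// Returns the underlying data pointer of `tensor_ptr` for the dtypes that can
// be bound to TensorRT buffers (float, half, int8), or nullptr otherwise.
// For example, TYPECASE(DT_FLOAT, tensor_ptr) expands to
//   case DT_FLOAT: { return (void*)tensor_ptr->flat<float>().data(); }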
void* GetTensorAddress(const Tensor* tensor_ptr) {
  auto tensor_type = tensor_ptr->dtype();
  switch (tensor_type) {
    TYPECASE(DT_FLOAT, tensor_ptr);
    TYPECASE(DT_HALF, tensor_ptr);
    TYPECASE(DT_INT8, tensor_ptr);
    default: {
      LOG(ERROR) << "Unsupported data type " << DataTypeString(tensor_type);
      return nullptr;
    }
  }
}

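// Instantiates the native TensorFlow function that corresponds to the segment
// so it can be invoked through the FunctionLibraryRuntime as a fallback. The
// resulting handle is stored in native_func_.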
Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
  VLOG(1) << "Constructing function handle";
  auto lib = ctx->function_library();
  if (lib == nullptr) {
    return errors::Internal("Context function library is null");
  }
  auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
  if (fdef == nullptr) {
    return errors::Internal("Native FunctionDef ", funcdef_name_,
                            " can't be found in function library");
  }
  FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.overlay_lib = nullptr;
  inst_ops.state_handle = "";
  inst_ops.target = ctx->device()->name();
  native_func_ = 0;
  auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
                                 inst_ops, &native_func_);
  if (!status.ok()) {
    LOG(ERROR) << "Instantiating native function " << funcdef_name_
               << " failed!";
  }
  return status;
}

TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  // Read the serialized segment (a GraphDef or a TRT engine, depending on
  // static_engine).
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_segment", &serialized_segment_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("workspace_size_bytes", &workspace_size_));
  OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
  if (!static_engine_) {
    if (!segment_graph_.ParseFromString(serialized_segment_)) {
      LOG(ERROR) << "Parsing segment graph failed!";
      context->SetStatus(
          errors::InvalidArgument("Failed to parse segment graphdef!"));
      return;
    }
    VLOG(1) << "Size of serialized GraphDef: "
            << serialized_segment_.capacity();
    string tmp;
    // Swap with a temporary empty string to deallocate the CPU memory.
    serialized_segment_.swap(tmp);
  }
  VLOG(1) << "Constructing " << name();
  string precision_string;
  OP_REQUIRES_OK(context,
                 context->GetAttr("precision_mode", &precision_string));
  string calibration_data;
  OP_REQUIRES_OK(context,
                 context->GetAttr("calibration_data", &calibration_data));
  OP_REQUIRES_OK(context,
                 context->GetAttr("segment_funcdef_name", &funcdef_name_));
  OP_REQUIRES_OK(context,
                 TrtPrecisionModeFromName(precision_string, &precision_mode_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("use_calibration", &use_calibration_));
  calibration_mode_ =
      (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
       calibration_data.empty());
  if (!calibration_data.empty()) {
    calibrator_.reset(new TRTInt8Calibrator(calibration_data));
    calibration_data.resize(0);
  }
  native_func_ = kInvalidHandle;
  OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                           &max_cached_engines_));
  OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
                                           &cached_engine_batches_));
  std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
  if (VLOG_IS_ON(1)) {
    string s("Engine Batches= ");
    for (auto i : cached_engine_batches_) {
      StrAppend(&s, i, " ");
    }
    VLOG(1) << s;
  }
}

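// Runs the original TensorFlow subgraph (the "native segment") through the
// FunctionLibraryRuntime. Outputs are forwarded to the op context from the
// completion callback, so this path is fully asynchronous.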
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                       AsyncHelper* helper) {
  if (funcdef_name_.empty()) {
    const string err_msg = StrCat("Fallback path is disabled for ", name());
    LOG(WARNING) << err_msg;
    ctx->SetStatus(errors::Internal(err_msg));
    return;
  }
  std::vector<Tensor> inputs;
  std::vector<Tensor>* outputs = new std::vector<Tensor>();
  if (native_func_ == kInvalidHandle) {
    auto status = ConstructFunctionHandle(ctx);
    if (!status.ok()) {
      LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
      ctx->SetStatus(status);
      delete outputs;
      return;
    }
  }
  auto lib = ctx->function_library();
  FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();
  inputs.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); i++) {
    inputs.push_back(ctx->input(i));
  }
  // Hold a reference on the helper until the native segment finishes.
  helper->Ref();
  VLOG(1) << "Executing native segment: " << name();
  lib->Run(opts, native_func_, inputs, outputs,
           [this, ctx, outputs, helper](const Status& s) {
             core::ScopedUnref sc(helper);
             if (!s.ok()) {
               LOG(ERROR) << "Failed to execute native segment " << this->name()
                          << ": " << s;
               ctx->SetStatus(s);
               delete outputs;
               return;
             }
             VLOG(1) << "Native Segment completed";
             for (size_t t = 0; t < outputs->size(); ++t) {
               ctx->set_output(t, outputs->at(t));
             }
             delete outputs;
           });
}

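// Passes the current inputs to the INT8 calibrator, whose data is consumed by
// the engine-building thread created in AllocateCalibrationResources, and then
// runs the native segment to produce this step's outputs.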
void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                     AsyncHelper* helper) {
  VLOG(1) << "Executing TRT calibration: " << name();
  helper->Ref();
  core::ScopedUnref sc(helper);
  auto res_mgr = ctx->resource_manager();
  TRTCalibrationResource* calib_res = nullptr;
  OP_REQUIRES_OK(ctx,
                 res_mgr->LookupOrCreate(
                     "TF_TRT_Calibration", name(),
                     reinterpret_cast<SerializableResourceBase**>(&calib_res),
                     {[ctx, this](SerializableResourceBase** cr) -> Status {
                       return this->AllocateCalibrationResources(ctx, cr);
                     }}));
  core::ScopedUnref calib_sc(calib_res);
  int num_inputs = ctx->num_inputs();
  // Pass input data to the calibrator.
  std::unordered_map<string, void*> input_data;
  for (int i = 0; i < num_inputs; i++) {
    const Tensor& t = ctx->input(i);
    void* data_address = GetTensorAddress(&t);
    if (data_address == nullptr) {
      ctx->SetStatus(errors::InvalidArgument(
          "Unsupported data type encountered in input ", i));
      return;
    }
    // Check that the allocated buffer is large enough for the input.
    const auto device_tensor =
        calib_res->device_tensors_.at(i).AccessTensor(ctx);
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    input_data.emplace(StrCat(kInputPHName, i), data_address);
  }
  VLOG(2) << "Filled map for sending";
  // Copied from cuda_kernel_helper since it seems to be valid only in *.cu.cc
  // files.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  calib_res->calibrator_->setBatch(input_data, *stream);
  VLOG(2) << "Passed calibration data";
  ExecuteNativeSegment(ctx, helper);
}

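// Example: with cached_engine_batches_ = {1, 4, 16} and an actual input batch
// size of 3, the engine cached for batch size 4 is selected and the batch dim
// of every returned engine input shape is set to 4. A batch size of 32 would
// return false since no cached batch size is large enough.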
bool TRTEngineOp::GetCompatibleCachedEngine(
    const std::vector<TensorShape>& actual_input_shapes,
    std::vector<TensorShape>* engine_input_shapes) {
  const int batch_size = actual_input_shapes[0].dim_size(0);
  int smallest_batch_size = -1;
  // The engine input shapes are the same as the actual input shapes, except
  // that the batch size is overwritten below.
  *engine_input_shapes = actual_input_shapes;
  for (const int cached_batch_size : cached_engine_batches_) {
    // Check if compatible: batch <= cached batch.
    //
    // TODO(laigd): here we only compare the first dim, a.k.a. the batch size;
    // we'll need to support non-batch dimensions as well. This will be done
    // as part of the offline conversion implementation.
    if (batch_size <= cached_batch_size) {
      // Use this engine if it is the first compatible one found, or if its
      // batch size is smaller than the best match so far.
      if ((smallest_batch_size == -1) ||
          (cached_batch_size < smallest_batch_size)) {
        smallest_batch_size = cached_batch_size;
        // Overwrite the batch size in the output shapes.
        for (int i = 0; i < engine_input_shapes->size(); i++) {
          (*engine_input_shapes)[i].set_dim(0, smallest_batch_size);
        }
      }
    }
  }
  return (smallest_batch_size != -1);
}

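// Entry point of the op: runs calibration if requested, otherwise looks up or
// builds a TRT engine for the current input shapes and executes it, falling
// back to the native segment when no engine is available or execution fails.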
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                               AsyncOpKernel::DoneCallback done) {
  auto helper = new AsyncHelper(done);
  core::ScopedUnref sc(helper);
  if (calibration_mode_) {
    ExecuteCalibration(ctx, helper);
    return;
  }
  // Get the shapes of the inputs to the engine.
  std::vector<TensorShape> input_shapes;
  input_shapes.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    input_shapes.push_back(ctx->input(i).shape());
  }
  EngineContext* engine_context = GetEngine(input_shapes, ctx);
  if (!engine_context->cuda_engine) {
    VLOG(1) << "Engine retrieval for input shapes: "
            << TensorShapeUtils::ShapeListString(input_shapes)
            << " failed. Running native segment for " << name();
    ExecuteNativeSegment(ctx, helper);
    return;
  }
  const bool retry = ExecuteTrtEngine(ctx, engine_context);
  if (retry) {
    LOG(WARNING) << "Failed to execute engine, "
                 << "retrying with native segment for " << name();
    ExecuteNativeSegment(ctx, helper);
    return;
  }
}

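// TensorRT addresses each input and output by a binding index, looked up from
// the placeholder names (kInputPHName/kOutputPHName plus the argument index).
// The buffers vector below holds one device pointer per binding, which is what
// IExecutionContext::enqueue consumes.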
bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
                                   EngineContext* engine_context) {
  VLOG(1) << "Executing TRT engine: " << name();
  auto& cuda_engine = engine_context->cuda_engine;
  const bool kRetry = true;
  // All inputs must have the same batch size, so just get it from the first
  // input.
  const int num_batch = ctx->input(0).shape().dim_size(0);
  const int num_binding = ctx->num_inputs() + ctx->num_outputs();
  std::vector<void*> buffers(num_binding);
  for (int i = 0; i < ctx->num_inputs(); i++) {
    const string input_name = StrCat(kInputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
    if (binding_index == -1) {
      const string msg =
          StrCat("Input node ", input_name, " not found, at ", name());
      LOG(ERROR) << msg;
      ctx->SetStatus(errors::NotFound(msg));
      return !kRetry;
    }

    const Tensor& input_tensor = ctx->input(i);
    const TensorShape& input_shape = input_tensor.shape();
    if (num_batch != input_shape.dim_size(0)) {
      LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
                 << " vs " << input_shape.dim_size(0);
      return kRetry;
    }
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(input_tensor.flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(ERROR) << "FP16 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(ERROR) << "INT8 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(input_tensor.flat<int32>().data());
        break;
      default:
        LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
        return kRetry;
    }
  }

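  // Allocate the op outputs using the shapes reported by the engine bindings
  // (with the batch dimension prepended), then record their device pointers
  // in the same buffers vector.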
  for (int i = 0; i < ctx->num_outputs(); i++) {
    // Create an output tensor.
    const string output_name = StrCat(kOutputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
    Tensor* output_tensor = nullptr;

    TensorShape output_shape;
    if (binding_index != -1) {
      auto dims = cuda_engine->getBindingDimensions(binding_index);
      std::vector<int> trt_shape(dims.nbDims + 1);
      trt_shape[0] = num_batch;
      for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
      auto status = TensorShapeUtils::MakeShape(
          trt_shape.data(), trt_shape.size(), &output_shape);
      if (!status.ok()) {
        LOG(ERROR) << "Failed to get output shape: " << status;
        return kRetry;
      }
    } else {
      const string msg =
          StrCat("Output node ", output_name, " not found, at ", name());
      LOG(ERROR) << msg;
      ctx->SetStatus(errors::NotFound(msg));
      return !kRetry;
    }
    auto status = ctx->allocate_output(i, output_shape, &output_tensor);
    if (!status.ok()) {
      LOG(ERROR) << "Allocating output failed with " << status;
      ctx->SetStatus(status);
      // Do not retry since we cannot allocate the same output twice.
      // TODO(aaroey): ideally we should retry, fix this.
      return !kRetry;
    }
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(WARNING) << "FP16 outputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(WARNING) << "INT8 outputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(output_tensor->flat<int32>().data());
        break;
      default:
        LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
        return kRetry;
    }
  }
  // Copied from cuda_kernel_helper since it seems to be valid only in *.cu.cc
  // files.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));

  // nvinfer1::IExecutionContext::enqueue is not thread-safe, so we need a
  // mutex for it.
  mutex_lock lock(engine_context->mu);
  // TODO(jie): TRT enqueue does not report a detailed error.
  auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0],
                                                         *stream, nullptr);
  if (!ret) {
    LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
    return kRetry;
  }
  // Synchronization will be done by TF.
  return !kRetry;
}

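// Looks up (or builds and caches) an engine compatible with the given input
// shapes. Returns a pointer to a static empty EngineContext when no engine can
// be provided; callers treat that as a signal to run the native segment.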
EngineContext* TRTEngineOp::GetEngine(
    const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
  static EngineContext empty_context;
  mutex_lock lock(engine_mutex_);
  // TODO(tmorris): using first input to get batch size - is this reliable?
  const int batch_size = input_shapes[0].dim_size(0);

  // Get the engine cache.
  TRTEngineCacheResource* cache_res = nullptr;
  auto status = ctx->resource_manager()->LookupOrCreate(
      "TRTEngineCache", name(), &cache_res,
      {[this, ctx](TRTEngineCacheResource** cr) -> Status {
        *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
        return Status::OK();
      }});
  if (!status.ok()) {
    ctx->SetStatus(status);
    return &empty_context;
  }
  core::ScopedUnref sc(cache_res);
  auto& cache = cache_res->cache_;
  auto allocator = cache_res->allocator_.get();
  if (allocator == nullptr) {
    return &empty_context;
  }

  // Handle the static engine case. For static engines, the cache will have a
  // single element containing the only engine.
  if (static_engine_) {
    if (cache.size()) {
      // The batch size of the engine must be >= the input batch size.
      // TODO(tmorris): use match compatible function?
      if (cache.begin()->first[0].dim_size(0) >= batch_size) {
        return cache.begin()->second.get();
      }
      return &empty_context;
    }

    TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
    infer->setGpuAllocator(allocator);
    TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
        infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                     serialized_segment_.size(),
                                     PluginFactoryTensorRT::GetInstance()));
    auto raw_static_engine = static_engine.get();
    const auto max_batch_size = raw_static_engine->getMaxBatchSize();
    // The static engine is keyed with max_batch_size as its batch size so
    // that all inputs map to this single engine.
    std::vector<TensorShape> engine_input_shapes(input_shapes);
    for (int i = 0; i < engine_input_shapes.size(); i++) {
      // TODO(tmorris): will all inputs have batch size as first dimension?
      engine_input_shapes[i].set_dim(0, max_batch_size);
    }
    // TODO(laigd): here we assume engine_input_shapes matches the actual input
    // shapes of the engine, we should verify that.
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(
                      std::move(static_engine),
                      TrtUniquePtrType<nvinfer1::IExecutionContext>(
                          raw_static_engine->createExecutionContext())));
    // The runtime is safe to delete after engine creation.
    VLOG(1) << "Size of serialized TRT engine: "
            << serialized_segment_.capacity();
    string tmp;
    // Swap with a temporary empty string to deallocate the CPU memory.
    serialized_segment_.swap(tmp);
    if (max_batch_size < batch_size) {
      return &empty_context;
    }
    return cache.at(engine_input_shapes).get();
  }  // static_engine_

  // Handle the dynamic engine case. See if there is a compatible cached
  // engine; the input batch size should be <= the cached batch size.
  std::vector<TensorShape> engine_input_shapes;
  const bool matched_successfully =
      GetCompatibleCachedEngine(input_shapes, &engine_input_shapes);
  // If matched, use that engine. Otherwise, look in the cache for the exact
  // shapes and possibly create a new engine if it is not in the cache.
  if (!matched_successfully) {
    engine_input_shapes = input_shapes;
    if (!cached_engine_batches_.empty()) {
      // If the user has explicitly defined cached_engine_batches, warn them
      // that their input batch size was not compatible (too large).
      LOG(WARNING) << "No compatible cached engine was found for batch size: "
                   << batch_size << ". A new engine will be created.";
      cached_engine_batches_.push_back(batch_size);
    }
  }

  if (!cache.count(engine_input_shapes)) {
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    bool convert_successfully = false;
    LOG(INFO) << "Building a new TensorRT engine for " << name()
              << " input shapes: "
              << TensorShapeUtils::ShapeListString(engine_input_shapes);

    // Convert to partial shapes.
    std::vector<PartialTensorShape> partial_shapes(engine_input_shapes.begin(),
                                                   engine_input_shapes.end());

    // Up to this point, calibrator_ can never be empty, since otherwise it
    // means calibration_mode_ is true and this path won't get executed.
    auto status = convert::ConvertGraphDefToEngine(
        segment_graph_, precision_mode_, batch_size, workspace_size_,
        partial_shapes, &logger, allocator, calibrator_.get(), &engine,
        use_calibration_, &convert_successfully);
    if (!status.ok()) {
      LOG(WARNING) << "Engine creation for " << name() << " failed. "
                   << "The native segment will be used instead. "
                   << "Reason: " << status;
      // Store an empty engine in the cache for these input shapes so we don't
      // try to build the same failing engine again.
      cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
      return &empty_context;
    }
    VLOG(1) << "Conversion is done";
    TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
        engine->createExecutionContext());
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(std::move(engine),
                                                   std::move(exec_context)));
  }
  return cache.at(engine_input_shapes).get();
}

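// Creates the calibration resource: per-input device buffers that mirror the
// op inputs, a TRTInt8Calibrator fed from those buffers, and a background
// thread that runs ConvertGraphDefToEngine() in INT8 mode to drive the
// calibration loop.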
Status TRTEngineOp::AllocateCalibrationResources(
    OpKernelContext* ctx, SerializableResourceBase** cr) {
  auto cres = new TRTCalibrationResource();
  *cr = cres;
  // Get the allocator.
  auto alloc = ctx->device()->GetAllocator(AllocatorAttributes());
  if (!alloc) {
    LOG(WARNING) << "Can't get device allocator; will not be able to "
                    "allocate memory from the TensorFlow memory pool";
    cres->allocator_.reset(new TRTCudaAllocator);
  } else {
    cres->allocator_.reset(new TRTDeviceAllocator(alloc));
  }
  // Get the input shapes.
  const int batch_size = ctx->input(0).dim_size(0);
  const int num_inputs = ctx->num_inputs();
  std::vector<PartialTensorShape> shapes;
  cres->device_tensors_.resize(num_inputs);
  VLOG(1) << "Constructing calibrator";
  for (int i = 0; i < num_inputs; i++) {
    // Allocate workspace on device for the inputs.
    const Tensor& t = ctx->input(i);
    shapes.emplace_back(t.shape());
    Tensor* device_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor));
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    void* device_address = GetTensorAddress(device_tensor);
    if (device_address == nullptr) {
      return errors::InvalidArgument(
          "Unsupported data type encountered in input ", i);
    }
    cres->device_buffers_.emplace(
        StrCat(kInputPHName, i),
        std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
  }
  cres->calibrator_.reset(
      new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
  const string label(name());
  auto segment_graph = &segment_graph_;
  const int platform_gpu_id =
      ctx->device()->tensorflow_gpu_device_info()->gpu_id;
  if (platform_gpu_id < 0) {
    LOG(ERROR) << "Can't get gpu_device_info from context->device()";
    return errors::InvalidArgument(
        "Context->device doesn't contain device info!");
  }
  const int64 workspace_size_bytes = workspace_size_;
  cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
                                    platform_gpu_id, workspace_size_bytes]() {
    LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
              << ", Calibration Resource @ " << cres;
    auto err = cudaSetDevice(platform_gpu_id);
    if (err != cudaSuccess) {
      // TODO(aaroey): should return an error here.
      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                 << " in calibration thread";
    }
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine(), consuming the calibration data
    // that is set by the TF op, and drive the builder until the calibrator
    // returns false. The engine is discarded after the calibration table is
    // generated.
    //
    // TODO(aaroey): maybe set the max batch size using the Python
    // calibration wrapper class.
    auto s = convert::ConvertGraphDefToEngine(
        *segment_graph, TrtPrecisionMode::INT8,
        cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes,
        &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(),
        &cres->engine_,
        /*use_calibration=*/true,
        /*convert_successfully=*/nullptr);
    if (!s.ok()) {
      LOG(ERROR) << "Calibration failed: " << s;
      cres->calibrator_->setDone();  // Ignore further pushes.
    }
    VLOG(1) << "Calibration loop terminated " << label;
  }));
  VLOG(1) << "Initialized calibrator resource";
  return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA