/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;

// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public core::RefCounted {
 public:
  AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; }
  ~AsyncHelper() override { done_(); }

 private:
  AsyncOpKernel::DoneCallback done_;
};

// This op constructs a TRT engine on the fly, and if engine construction
// fails, it executes the equivalent subgraph as a native TensorFlow function.
class TRTEngineOp : public AsyncOpKernel {
 public:
  explicit TRTEngineOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
  // Execute calibration
  void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);

  // Construct a function handle for executing native funcdef graph
  Status ConstructFunctionHandle(OpKernelContext* ctx);

  // Execute replaced native segment as function Op.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);

  // Execute the tensorrt engine. Returns whether we need to retry by running
  // the native segment.
  bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context);

  // Allocate necessary resources for calibration
  Status AllocateCalibrationResources(OpKernelContext* ctx,
                                      SerializableResourceBase** cr);

  // Get engine for the input shape
  EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
                           OpKernelContext* ctx);

  // Find the smallest batch size in cached_engine_batches_ that can serve the
  // actual input batch size, and return the corresponding engine input shapes.
  bool GetCompatibleCachedEngine(
      const std::vector<TensorShape>& actual_input_shapes,
      std::vector<TensorShape>* engine_input_shapes);

  std::vector<string> input_nodes_;
  std::vector<string> output_nodes_;

  // serialized protobuf segment or trt engine depending on static_engine_ flag.
  string serialized_segment_;

  // Name of the function for TF native execution of the segment. If empty, it
  // means TF native execution is not allowed, and if TRT engine fails to run
  // an error will be returned.
  string funcdef_name_;

  // GraphDef representation of the segment.
  GraphDef segment_graph_;

  // Engine Precision mode.
  TrtPrecisionMode precision_mode_;

  // Whether engine is constructed during the conversion or needs to be
  // constructed from protobuf segment.
  bool static_engine_;

  // Whether to calibrate INT8 engine.
  bool calibration_mode_;

  // Batches of the cached engines
  std::vector<int> cached_engine_batches_;

  // Maximum number of cached engines
  int max_cached_engines_;

  int64 workspace_size_;
  mutex engine_mutex_;
  FunctionLibraryRuntime::Handle native_func_;

  // The finalized calibrator for inference.
  std::unique_ptr<TRTInt8Calibrator> calibrator_;

  // If true, create calibration graph for INT8 mode. Otherwise, we are using
  // user-provided quantization ranges.
  bool use_calibration_;
};

#define TYPECASE(dt, X, Y)                                    \
  case dt: {                                                  \
    return (void*)X->flat<EnumToDataType<dt>::Type>().data(); \
  }

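// Returns the raw data pointer of the tensor for the supported dtypes
// (DT_FLOAT, DT_HALF, DT_INT8), or nullptr for unsupported types.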
void* GetTensorAddress(const Tensor* tensor_ptr) {
  auto tensor_type = tensor_ptr->dtype();
  switch (tensor_type) {
    TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
    TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
    TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
    default: {
      LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
      return nullptr;
    }
  }
}

Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
  VLOG(1) << "Constructing function handle";
  auto lib = ctx->function_library();
  if (lib == nullptr) {
    return errors::Internal("Context function library is null");
  }
  auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
  if (fdef == nullptr) {
    return errors::Internal("Native FunctionDef ", funcdef_name_,
                            " can't be found in function library");
  }
  FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.overlay_lib = nullptr;
  inst_ops.state_handle = "";
  inst_ops.target = ctx->device()->name();
  native_func_ = 0;
  auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
                                 inst_ops, &native_func_);
  if (!status.ok()) {
    LOG(ERROR) << " Instantiating native function " << funcdef_name_
               << " failed!";
  }
  return status;
}

TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  // Read the serialized segment: a GraphDef for dynamic engines, or a
  // serialized TRT engine for static engines.
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_segment", &serialized_segment_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("workspace_size_bytes", &workspace_size_));
  OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
  if (!static_engine_) {
    if (!segment_graph_.ParseFromString(serialized_segment_)) {
      LOG(ERROR) << "Parsing segment graph failed!";
      context->SetStatus(
          errors::InvalidArgument("Failed to parse segment graphdef!"));
      return;
    }
    VLOG(1) << "Size of serialized GraphDef: "
            << serialized_segment_.capacity();
    string tmp;
    // Swap with temporary empty string to deallocate the CPU memory.
    serialized_segment_.swap(tmp);
  }
  VLOG(1) << "Constructing " << name();
  string precision_string;
  OP_REQUIRES_OK(context,
                 context->GetAttr("precision_mode", &precision_string));
  string calibration_data;
  OP_REQUIRES_OK(context,
                 context->GetAttr("calibration_data", &calibration_data));
  OP_REQUIRES_OK(context,
                 context->GetAttr("segment_funcdef_name", &funcdef_name_));
  OP_REQUIRES_OK(context,
                 TrtPrecisionModeFromName(precision_string, &precision_mode_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("use_calibration", &use_calibration_));
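  // Calibration is only needed when running INT8 with calibration enabled and
  // no pre-computed calibration data has been provided.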
  calibration_mode_ =
      (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
       calibration_data.empty());
  if (!calibration_data.empty()) {
    calibrator_.reset(new TRTInt8Calibrator(calibration_data));
    calibration_data.resize(0);
  }
  native_func_ = kInvalidHandle;
  OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                           &max_cached_engines_));
  OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
                                           &cached_engine_batches_));
  std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
  if (VLOG_IS_ON(1)) {
    string s("Engine Batches= ");
    for (auto i : cached_engine_batches_) {
      StrAppend(&s, i, " ");
    }
    VLOG(1) << s;
  }
}

void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                       AsyncHelper* helper) {
  if (funcdef_name_.empty()) {
    const string err_msg = StrCat("Fallback path is disabled for ", name());
    LOG(WARNING) << err_msg;
    ctx->SetStatus(errors::Internal(err_msg));
    return;
  }
  std::vector<Tensor> inputs;
  std::vector<Tensor>* outputs = new std::vector<Tensor>();
  if (native_func_ == kInvalidHandle) {
    auto status = ConstructFunctionHandle(ctx);
    if (!status.ok()) {
      LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
      ctx->SetStatus(status);
      return;
    }
  }
  auto lib = ctx->function_library();
  FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();
  inputs.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); i++) {
    inputs.push_back(ctx->input(i));
  }
  helper->Ref();  // Keep done() from firing until the native segment finishes.
  VLOG(1) << "Executing native segment: " << name();
  lib->Run(opts, native_func_, inputs, outputs,
           [this, ctx, outputs, helper](const Status& s) {
             core::ScopedUnref sc(helper);
             if (!s.ok()) {
               LOG(ERROR) << "Failed to execute native segment " << this->name()
                          << ": " << s;
               ctx->SetStatus(s);
               return;
             }
             VLOG(1) << "Native Segment completed";
             for (size_t t = 0; t < outputs->size(); ++t) {
               ctx->set_output(t, outputs->at(t));
             }
             delete outputs;
           });
}

void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                     AsyncHelper* helper) {
  VLOG(1) << "Executing TRT calibration: " << name();
  helper->Ref();
  core::ScopedUnref sc(helper);
  auto res_mgr = ctx->resource_manager();
  TRTCalibrationResource* calib_res = nullptr;
  OP_REQUIRES_OK(ctx,
                 res_mgr->LookupOrCreate(
                     "TF_TRT_Calibration", name(),
                     reinterpret_cast<SerializableResourceBase**>(&calib_res),
                     {[ctx, this](SerializableResourceBase** cr) -> Status {
                       return this->AllocateCalibrationResources(ctx, cr);
                     }}));
  core::ScopedUnref calib_sc(calib_res);
  int num_inputs = ctx->num_inputs();
  // Pass input data to calibrator
  std::unordered_map<string, void*> input_data;
  for (int i = 0; i < num_inputs; i++) {
    const Tensor& t = ctx->input(i);
    void* data_address = GetTensorAddress(&t);
    if (data_address == nullptr) {
      ctx->SetStatus(errors::InvalidArgument(
          "Unsupported data type encountered in input ", i));
      return;
    }
    // Check the allocated buffer is sufficient for input
    const auto device_tensor =
        calib_res->device_tensors_.at(i).AccessTensor(ctx);
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    input_data.emplace(StrCat(kInputPHName, i), data_address);
  }
  VLOG(2) << "Filled map for sending";
  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  calib_res->calibrator_->setBatch(input_data, *stream);
  VLOG(2) << "Passed calibration data";
  ExecuteNativeSegment(ctx, helper);
}

bool TRTEngineOp::GetCompatibleCachedEngine(
    const std::vector<TensorShape>& actual_input_shapes,
    std::vector<TensorShape>* engine_input_shapes) {
  const int batch_size = actual_input_shapes[0].dim_size(0);
  int smallest_batch_size = -1;
  // Output shape will always be the same as the input but we will overwrite
  // the batch size.
  *engine_input_shapes = actual_input_shapes;
  for (const int cached_batch_size : cached_engine_batches_) {
    // Check if compatible: batch <= cached batch.
    //
    // TODO(laigd): here it only compares the first dim, a.k.a. the batch size;
    // we'll need to support non-batch dimensions as well. This will be done as
    // part of the offline conversion implementation.
    if (batch_size <= cached_batch_size) {
      // First case: first compatible engine found
      // Second case: smaller batch size engine found
      if ((smallest_batch_size == -1) ||
          (cached_batch_size < smallest_batch_size)) {
        smallest_batch_size = cached_batch_size;
        // Overwrite batch size for output
        for (int i = 0; i < engine_input_shapes->size(); i++) {
          (*engine_input_shapes)[i].set_dim(0, smallest_batch_size);
        }
      }
    }
  }
  return (smallest_batch_size != -1);
}

void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                               AsyncOpKernel::DoneCallback done) {
  auto helper = new AsyncHelper(done);
  core::ScopedUnref sc(helper);
  if (calibration_mode_) {
    ExecuteCalibration(ctx, helper);
    return;
  }
  // Get shapes of inputs to engine.
  std::vector<TensorShape> input_shapes;
  input_shapes.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    input_shapes.push_back(ctx->input(i).shape());
  }
  EngineContext* engine_context = GetEngine(input_shapes, ctx);
  if (!engine_context->cuda_engine) {
    VLOG(1) << "Engine retrieval for input shapes: "
            << TensorShapeUtils::ShapeListString(input_shapes)
            << " failed. Running native segment for " << name();
    ExecuteNativeSegment(ctx, helper);
    return;
  }
  const bool retry = ExecuteTrtEngine(ctx, engine_context);
  if (retry) {
    LOG(WARNING) << "Failed to execute engine, "
                 << "retrying with native segment for " << name();
    ExecuteNativeSegment(ctx, helper);
    return;
  }
}

bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
                                   EngineContext* engine_context) {
  VLOG(1) << "Executing TRT engine: " << name();
  auto& cuda_engine = engine_context->cuda_engine;
  const bool kRetry = true;
  // All inputs must have the same batch size, so just get it from the first
  // input.
  const int num_batch = ctx->input(0).shape().dim_size(0);
  const int num_binding = ctx->num_inputs() + ctx->num_outputs();
  std::vector<void*> buffers(num_binding);
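  // Bind each TF input tensor to its TRT engine binding slot, checking batch
  // size consistency and the binding data type along the way.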
  for (int i = 0; i < ctx->num_inputs(); i++) {
    const string input_name = StrCat(kInputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
    if (binding_index == -1) {
      const string msg =
          StrCat("Input node ", input_name, " not found, at ", name());
      LOG(ERROR) << msg;
      ctx->SetStatus(errors::NotFound(msg));
      return !kRetry;
    }

    const Tensor& input_tensor = ctx->input(i);
    const TensorShape& input_shape = input_tensor.shape();
    if (num_batch != input_shape.dim_size(0)) {
      LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
                 << " vs " << input_shape.dim_size(0);
      return kRetry;
    }
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(input_tensor.flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(ERROR) << "FP16 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(ERROR) << "INT8 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(input_tensor.flat<int32>().data());
        break;
      default:
        LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
        return kRetry;
    }
  }

  for (int i = 0; i < ctx->num_outputs(); i++) {
    // Create an output tensor
    const string output_name = StrCat(kOutputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
    Tensor* output_tensor = nullptr;

    TensorShape output_shape;
    if (binding_index != -1) {
      auto dims = cuda_engine->getBindingDimensions(binding_index);
      std::vector<int> trt_shape(dims.nbDims + 1);
      trt_shape[0] = num_batch;
      for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
      auto status = TensorShapeUtils::MakeShape(
          trt_shape.data(), trt_shape.size(), &output_shape);
      if (!status.ok()) {
        LOG(ERROR) << "Failed to get output shape: " << status;
        return kRetry;
      }
    } else {
      const string msg =
          StrCat("Output node ", output_name, " not found, at ", name());
      LOG(ERROR) << msg;
      ctx->SetStatus(errors::NotFound(msg));
      return !kRetry;
    }
    auto status = ctx->allocate_output(i, output_shape, &output_tensor);
    if (!status.ok()) {
      LOG(ERROR) << "Allocating output failed with " << status;
      ctx->SetStatus(status);
      // Do not retry since we cannot allocate the same output twice.
      // TODO(aaroey): ideally we should retry, fix this.
      return !kRetry;
    }
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(WARNING) << "FP16 outputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(WARNING) << "INT8 outputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(output_tensor->flat<int32>().data());
        break;
      default:
        LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
        return kRetry;
    }
  }
  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));

  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex
  // for it.
  mutex_lock lock(engine_context->mu);
  // TODO(jie): trt enqueue does not return error
  auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0],
                                                        *stream, nullptr);
  if (!ret) {
    LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
    return kRetry;
  }
  // Synchronization will be done by TF.
  return !kRetry;
}

EngineContext* TRTEngineOp::GetEngine(
    const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
  static EngineContext empty_context;
  mutex_lock lock(engine_mutex_);
  // TODO(tmorris): using first input to get batch size - is this reliable?
  const int batch_size = input_shapes[0].dim_size(0);

  // Get engine cache
  TRTEngineCacheResource* cache_res = nullptr;
  auto status = ctx->resource_manager()->LookupOrCreate(
      "TRTEngineCache", name(), &cache_res,
      {[this, ctx](TRTEngineCacheResource** cr) -> Status {
        *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
        return Status::OK();
      }});
  if (!status.ok()) {
    ctx->SetStatus(status);
    return &empty_context;
  }
  core::ScopedUnref sc(cache_res);
  auto& cache = cache_res->cache_;
  auto allocator = cache_res->allocator_.get();
  if (allocator == nullptr) {
    return &empty_context;
  }

  // Handle the static engine case. For static engines, the cache will have a
  // single element containing the only engine.
  if (static_engine_) {
    if (cache.size()) {
      // Batch size of engine must be >= the input batch size
      // TODO(tmorris): use match compatible function?
      if (cache.begin()->first[0].dim_size(0) >= batch_size) {
        return cache.begin()->second.get();
      }
      return &empty_context;
    }

    TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
    infer->setGpuAllocator(allocator);
    TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
        infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                     serialized_segment_.size(),
                                     PluginFactoryTensorRT::GetInstance()));
    auto raw_static_engine = static_engine.get();
    const auto max_batch_size = raw_static_engine->getMaxBatchSize();
    // Static engine will have max_batch_size for batch size so that all inputs
    // will map to this single engine.
    std::vector<TensorShape> engine_input_shapes(input_shapes);
    for (int i = 0; i < engine_input_shapes.size(); i++) {
      // TODO(tmorris): will all inputs have batch size as first dimension??
      engine_input_shapes[i].set_dim(0, max_batch_size);
    }
    // TODO(laigd): here we assume engine_input_shapes matches the actual input
    // shapes of the engine, we should verify that.
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(
                      std::move(static_engine),
                      TrtUniquePtrType<nvinfer1::IExecutionContext>(
                          raw_static_engine->createExecutionContext())));
    // Runtime is safe to delete after engine creation
    VLOG(1) << "Size of serialized TRT engine: "
            << serialized_segment_.capacity();
    string tmp;
    // Swap with temporary empty string to deallocate the CPU memory.
    serialized_segment_.swap(tmp);
    if (max_batch_size < batch_size) {
      return &empty_context;
    }
    return cache.at(engine_input_shapes).get();
  }  // static_engine_

  // Handle the dynamic engine case.
  // See if there is a compatible engine cached. The batch size should be <= the
  // cached batch size.
  std::vector<TensorShape> engine_input_shapes;
  const bool matched_successfully =
      GetCompatibleCachedEngine(input_shapes, &engine_input_shapes);
  // If matched, use that engine. Otherwise, we will look in cache for that
  // exact shape and possibly create a new engine if it is not in cache.
  if (!matched_successfully) {
    engine_input_shapes = input_shapes;
    if (!cached_engine_batches_.empty()) {
      // If user has explicitly defined cached_engine_batches, we should
      // warn them that their input was non-compatible (batch size too high)
      LOG(WARNING) << "No compatible cached engine was found for batch size: "
                   << batch_size << ". A new engine will be created.";
      cached_engine_batches_.push_back(batch_size);
    }
  }

  if (!cache.count(engine_input_shapes)) {
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    bool convert_successfully = false;
    LOG(INFO) << "Building a new TensorRT engine for " << name()
              << " input shapes: "
              << TensorShapeUtils::ShapeListString(engine_input_shapes);

    // Convert to partial shapes
    std::vector<PartialTensorShape> partial_shapes(engine_input_shapes.begin(),
                                                   engine_input_shapes.end());

    // Up to this point, calibrator_ can never be empty, since otherwise it
    // means calibration_mode_ is true and this path won't get executed.
    auto status = convert::ConvertGraphDefToEngine(
        segment_graph_, precision_mode_, batch_size, workspace_size_,
        partial_shapes, &logger, allocator, calibrator_.get(), &engine,
        use_calibration_, &convert_successfully);
    if (!status.ok()) {
      LOG(WARNING) << "Engine creation for " << name() << " failed. "
                   << "The native segment will be used instead. "
                   << "Reason: " << status;
      // Store an empty engine in the cache for these input shapes so we don't
      // try to build the same failing engine again.
      cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
      return &empty_context;
    }
    VLOG(1) << "Conversion is done";
    TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
        engine->createExecutionContext());
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(std::move(engine),
                                                   std::move(exec_context)));
  }
  return cache.at(engine_input_shapes).get();
}

Status TRTEngineOp::AllocateCalibrationResources(
    OpKernelContext* ctx, SerializableResourceBase** cr) {
  auto cres = new TRTCalibrationResource();
  *cr = cres;
  // Get the allocator.
  auto alloc = ctx->device()->GetAllocator(AllocatorAttributes());
  if (!alloc) {
    LOG(WARNING) << "Can't get device allocator; will not be able to "
                    "allocate memory from the TensorFlow memory pool";
    cres->allocator_.reset(new TRTCudaAllocator);
  } else {
    cres->allocator_.reset(new TRTDeviceAllocator(alloc));
  }
  // Get the input shapes.
  const int batch_size = ctx->input(0).dim_size(0);
  const int num_inputs = ctx->num_inputs();
  std::vector<PartialTensorShape> shapes;
  cres->device_tensors_.resize(num_inputs);
  VLOG(1) << " Constructing calibrator";
  for (int i = 0; i < num_inputs; i++) {
    // allocate workspace on device for inputs
    const Tensor& t = ctx->input(i);
    shapes.emplace_back(t.shape());
    Tensor* device_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor));
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    void* device_address = GetTensorAddress(device_tensor);
    if (device_address == nullptr) {
      return errors::InvalidArgument(
          "Unsupported data type encountered in input ", i);
    }
    cres->device_buffers_.emplace(
        StrCat(kInputPHName, i),
        std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
  }
  cres->calibrator_.reset(
      new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
  const string label(name());
  auto segment_graph = &segment_graph_;
  const int platform_gpu_id =
      ctx->device()->tensorflow_gpu_device_info()->gpu_id;
  if (platform_gpu_id < 0) {
    LOG(ERROR) << "Can't get gpu_device_info from context->device()";
    return errors::InvalidArgument(
        "Context->device doesn't contain device info!");
  }
  const int64 workspace_size_bytes = workspace_size_;
  cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
                                    platform_gpu_id, workspace_size_bytes]() {
    LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
              << ", Calibration Resource @ " << cres;
    auto err = cudaSetDevice(platform_gpu_id);
    if (err != cudaSuccess) {
      // TODO(aaroey): should return error here.
      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                 << " in calibration thread";
    }
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine(), consuming the calibration data
    // that is set by the TF op, and drive the builder until the calibrator
    // returns false. The engine is discarded after the calibration table is
    // generated.
    //
    // TODO(aaroey): maybe set the max batch size using the python
    // calibration wrapper class.
    auto s = convert::ConvertGraphDefToEngine(
        *segment_graph, TrtPrecisionMode::INT8,
        cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes,
        &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(),
        &cres->engine_,
        /*use_calibration=*/true,
        /*convert_successfully=*/nullptr);
    if (!s.ok()) {
      LOG(ERROR) << "Calibration failed: " << s;
      cres->calibrator_->setDone();  // Ignore further pushes
    }
    VLOG(1) << "Calibration loop terminated " << label;
  }));
  VLOG(1) << "initialized calibrator resource";
  return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA