• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
18 
19 #include "absl/memory/memory.h"
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
24 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
29 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
30 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/graph_optimizer.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/graph_to_functiondef.h"
36 #include "tensorflow/core/framework/node_def_builder.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/lib/core/refcount.h"
41 #include "tensorflow/core/lib/strings/str_util.h"
42 #include "tensorflow/core/lib/strings/strcat.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/stream_executor.h"
46 #include "tensorflow/core/platform/thread_annotations.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/profiler/lib/traceme.h"
49 #include "tensorflow/core/util/env_var.h"
50 
51 #if GOOGLE_CUDA && GOOGLE_TENSORRT
52 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
53 #include "third_party/tensorrt/NvInfer.h"
54 
55 namespace tensorflow {
56 namespace tensorrt {
57 static Logger logger;
58 using absl::StrAppend;
59 using absl::StrCat;
60 using ::nvinfer1::IRuntime;
61 
62 #define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
63   LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "
64 
65 // A helper class to call done() when destructed for asynchronous execution.
66 // Helps simultaneous execution of native and TRT engines.
67 
68 class AsyncHelper : public core::RefCounted {
69  public:
AsyncHelper(AsyncOpKernel::DoneCallback done)70   AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
71 
~AsyncHelper()72   ~AsyncHelper() override { this->operator()(); }
73 
operator ()()74   void operator()() {
75     if (!called_) {
76       done_();
77       called_ = true;
78     }
79   }
80 
81  private:
82   AsyncOpKernel::DoneCallback done_;
83   bool called_ = false;  // Has `done_` been called?
84 };
85 
86 //  This OP can construct TRTEngine on the fly and if construction of engine
87 //  fails, executes equivalent subgraph as a TensorFlow function.
88 class TRTEngineOp : public AsyncOpKernel {
89  public:
90   explicit TRTEngineOp(OpKernelConstruction* context);
91 
92   void ComputeAsync(OpKernelContext* context,
93                     AsyncOpKernel::DoneCallback done) override;
94 
95  private:
96   using CacheType =
97       LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
98                VectorTensorShapeHasher>;
99 
100   // Executes calibration.
101   void ExecuteCalibration(OpKernelContext* ctx,
102                           TRTEngineCacheResource* cache_res,
103                           AsyncHelper* helper);
104 
105   // Constructs a function handle for the segment of the TRTEngineOp.
106   StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
107       FunctionLibraryRuntime* lib, const string& device_name,
108       bool allow_soft_placement = false, size_t num_inputs = 0,
109       size_t num_outputs = 0);
110 
111   // Imports the GraphDef for the segment of the TRTEngineOp to
112   // segment_graph_def_.
113   Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
114                                const string& device_name);
115 
116   // Executes replaced native segment as function Op.
117   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
118 
119   // Executes the tensorrt engine. Returns whether we need to retry by running
120   // the native segment.
121   Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
122                           int trt_context_idx);
123 
124   // Allocates necessary resources for calibration.
125   Status AllocateCalibrationResources(OpKernelContext* ctx,
126                                       TRTEngineCacheResource* cache_res);
127 
128   Status GetEngineCacheResource(OpKernelContext* ctx,
129                                 TRTEngineCacheResource** cache_res);
130 
131   // Returns a pair of 1) An EngineContext object that is compatible with the
132   // input and 2) The index of the IExecutionContext compatible with the input.
133   // If a cuda engine for the given input shapes can't be found, returns
134   // (nullptr, 0) to allow native engine execution. Returns an error code for
135   // any problem that would prevent both TensorRT engine exceution and native
136   // segment execution.
137   StatusOr<std::pair<EngineContext*, int>> GetEngine(
138       const std::vector<TensorShape>& input_concrete_shapes,
139       OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);
140 
141   // Builds and returns a cuda engine for the input shapes. If building the
142   // engine fails, enters a dummy entry into the cache_resource cache so we
143   // don't continually try to build the same failing engine.
144   StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> BuildEngine(
145       const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
146       bool use_calibration, TRTInt8Calibrator* calibrator,
147       TRTEngineCacheResource* cache_resource);
148 
149   // Verify that the input shapes are consistent and can be handled by this op.
150   Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
151 
152   std::vector<string> input_nodes_;
153   std::vector<string> output_nodes_;
154 
155   // serialized protobuf segment or trt engine depending on static_engine_ flag.
156   string serialized_segment_;
157 
158   // The function for TF native execution of the segment.
159   NameAttrList func_;
160 
161   // GraphDef representation of the segment.
162   GraphDef segment_graph_def_;
163 
164   // Engine Precision mode.
165   TrtPrecisionMode precision_mode_;
166 
167   // Whether engine is constructed during the conversion or needs to be
168   // constructed from protobuf segment.
169   bool static_engine_;
170 
171   // Whether to calibrate INT8 engine.
172   bool calibration_mode_;
173 
174   // Whether to use implicit batch dimension for TensorRT.
175   bool use_implicit_batch_;
176 
177   // Whether to collect optimization profiles for TensorRT, only used when
178   // use_implicit_batch_=false.
179   bool profile_generation_mode_;
180 
181   // Whether the TRTEngineOp has any input with unknown dimensions.
182   bool has_dynamic_shape_input_;
183 
184   // Whether to build TensorRT engines at runtime.
185   bool allow_build_at_runtime_;
186 
187   // Whether to allow soft placement when the graph is executed with native
188   // TensorFlow.
189   bool allow_soft_placement_;
190 
191   // Maximum number of cached engines.
192   int max_cached_engines_;
193 
194   int64 workspace_size_;
195   mutex engine_mutex_;
196   FunctionLibraryRuntime::Handle native_execution_func_handle_;
197 
198   // The finalized calibrator for inference.
199   std::unique_ptr<TRTInt8Calibrator> calibrator_;
200 
201   // If true, create calibration graph for INT8 mode. Otherwise, we are using
202   // user-provided quantization ranges.
203   bool use_calibration_;
204 
205   // Array of all input shapes, collected from the input_shapes attribute when
206   // constructing the TRTEngineOp. The input_shapes attribute is set during
207   // graph conversion time. This data is used to retrieve which input dimensions
208   // could be unknown. During inference time this information is not available
209   // otherwise (all shapes are known (concrete) shapes when we run inference).
210   std::vector<PartialTensorShape> input_partial_shapes_;
211 };
212 
// Expands to a `case` label for DataType `dt` that returns the address of the
// flat buffer of the tensor pointer `X` as a mutable void*. `Y` is unused.
#define TYPECASE(dt, X, Y)                                   \
  case dt: {                                                 \
    return const_cast<void*>(static_cast<const void*>(       \
        X->flat<EnumToDataType<dt>::Type>().data()));        \
  }
217 
GetTensorAddress(const Tensor * tensor_ptr)218 void* GetTensorAddress(const Tensor* tensor_ptr) {
219   auto tensor_type = tensor_ptr->dtype();
220   switch (tensor_type) {
221     TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
222     TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
223     TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
224     TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
225     default: {
226       LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
227       return nullptr;
228     }
229   }
230 }
231 
FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,FunctionLibraryRuntime * flib_runtime,GraphDef * graph_def)232 static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
233                                     FunctionLibraryRuntime* flib_runtime,
234                                     GraphDef* graph_def) {
235   const FunctionLibraryDefinition* flib_def =
236       flib_runtime->GetFunctionLibraryDefinition();
237   const FunctionBody* fbody;
238   fbody = flib_runtime->GetFunctionBody(handle);
239   if (!fbody) {
240     return errors::Internal(
241         "Function body is null when converting from FuncDef to GraphDef.");
242   }
243   std::unique_ptr<Graph> graph(new Graph(flib_def));
244   CopyGraph(*fbody->graph, graph.get());
245 
246   auto replace_name = [](const char* const prefix, string* name) {
247     if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
248       name->replace(0, strlen(prefix), prefix);
249       return true;
250     }
251     return false;
252   };
253   graph->ToGraphDef(graph_def);
254   // GraphToFunctionDef() will convert all the node names to lowercase.
255   for (auto& node : *graph_def->mutable_node()) {
256     if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
257       if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
258         // Instantiation of the function will append _RetVal to the node name,
259         // need to remove it for backward compatibility.
260         const char* const suffix_to_remove = "_RetVal";
261         if (absl::EndsWith(node.name(), suffix_to_remove)) {
262           node.mutable_name()->erase(node.name().size() -
263                                      strlen(suffix_to_remove));
264         }
265       }
266     }
267     for (auto& input : *node.mutable_input()) {
268       if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
269         replace_name(IONamePrefixes::kOutputPHName, &input);
270       }
271     }
272   }
273   return Status::OK();
274 }
275 
ConstructFunctionHandle(FunctionLibraryRuntime * lib,const string & device_name,bool allow_soft_placement,size_t num_inputs,size_t num_outputs)276 StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
277     FunctionLibraryRuntime* lib, const string& device_name,
278     bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
279   VLOG(1) << "Constructing function handle";
280   if (lib == nullptr) {
281     return errors::Internal("Context function library is null");
282   }
283   FunctionLibraryRuntime::InstantiateOptions inst_ops;
284   inst_ops.state_handle = "";
285   inst_ops.target = device_name;
286   if (allow_soft_placement) {
287     const FunctionDef* fdef =
288         lib->GetFunctionLibraryDefinition()->Find(func_.name());
289     if (!fdef) {
290       return errors::Internal(
291           StrCat("Cann't find FunctionDef for", func_.name()));
292     }
293     bool ints_on_device =
294         fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
295         fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
296     // kIntsOnDeviceAttr is not compatible with is_multi_device_function which
297     // is needed to support allow_soft_placement.
298     if (ints_on_device) {
299       LOG_FIRST_FEW_WARNING_WITH_PREFIX
300           << "Function " << name()
301           << " has attribute kIntsOnDeviceAttr=true "
302              "and will be executed natively with allow_soft_placement=false. "
303              "If this is a problem, please re-generate your SavedModel with "
304              "the TF-TRT runtime you are using.";
305     } else {
306       inst_ops.is_multi_device_function = true;
307       inst_ops.input_devices.resize(num_inputs, device_name);
308       inst_ops.output_devices.resize(num_outputs, device_name);
309       inst_ops.config_proto.set_allow_soft_placement(true);
310     }
311   }
312   FunctionLibraryRuntime::Handle func_handle;
313   Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
314                                    inst_ops, &func_handle);
315   if (status.ok()) {
316     return func_handle;
317   }
318   return status;
319 }
320 
ImportSegmentGraphDef(FunctionLibraryRuntime * lib,const string & device_name)321 Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
322                                           const string& device_name) {
323   TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
324                       ConstructFunctionHandle(lib, device_name));
325   return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
326 }
327 
TRTEngineOp(OpKernelConstruction * context)328 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
329     : AsyncOpKernel(context) {
330   // read serialized_engine
331   OP_REQUIRES_OK(context,
332                  context->GetAttr("serialized_segment", &serialized_segment_));
333   OP_REQUIRES_OK(context,
334                  context->GetAttr("workspace_size_bytes", &workspace_size_));
335   OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
336 
337   VLOG(1) << "Constructing " << name();
338   string precision_string;
339   OP_REQUIRES_OK(context,
340                  context->GetAttr("precision_mode", &precision_string));
341   string calibration_data;
342   OP_REQUIRES_OK(context,
343                  context->GetAttr("calibration_data", &calibration_data));
344   OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
345   OP_REQUIRES(context, !func_.name().empty(),
346               errors::InvalidArgument(
347                   "The TF function for the TRT segment could not be empty"));
348   OP_REQUIRES_OK(context,
349                  TrtPrecisionModeFromName(precision_string, &precision_mode_));
350   OP_REQUIRES_OK(context,
351                  context->GetAttr("use_calibration", &use_calibration_));
352   OP_REQUIRES_OK(context,
353                  context->GetAttr("input_shapes", &input_partial_shapes_));
354   auto status =
355       context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
356   if (status.code() == tensorflow::error::NOT_FOUND) {
357     VLOG(2) << "Not found _allow_build_at_runtime in "
358             << context->device()->name()
359             << ", thus setting _allow_build_at_runtime=true";
360     allow_build_at_runtime_ = true;
361   } else {
362     OP_REQUIRES_OK(context, status);
363   }
364 
365   status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
366   if (status.code() == tensorflow::error::NOT_FOUND) {
367     allow_soft_placement_ = true;
368   } else {
369     OP_REQUIRES_OK(context, status);
370   }
371 
372   native_execution_func_handle_ = kInvalidHandle;
373   if (!static_engine_) {
374     OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
375                                                   context->device()->name()));
376   }
377   // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
378   // backward compatibility reasons. Remove it once all known users switch to
379   // 2.0.
380   calibration_mode_ =
381       (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
382        calibration_data.empty());
383   if (!calibration_data.empty()) {
384     calibrator_.reset(new TRTInt8Calibrator(calibration_data));
385     calibration_data.resize(0);
386   }
387   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
388                                            &max_cached_engines_));
389 
390   status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
391   if (status.code() == tensorflow::error::NOT_FOUND) {
392     VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
393             << ", thus setting _use_implicit_batch=true";
394     use_implicit_batch_ = true;
395   }
396 #if !IS_TRT_VERSION_GE(6, 0, 0, 0)
397   if (!use_implicit_batch_) {
398     VLOG(2) << "Need at least TensorRT 6.0 for explicit batch mode. Setting "
399             << "_use_implicit_batch=true";
400     use_implicit_batch_ = true;
401   }
402 #endif
403   status =
404       context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
405   if (status.code() == tensorflow::error::NOT_FOUND) {
406     VLOG(2) << "Not found _profile_generation_mode in "
407             << context->device()->name()
408             << ", thus setting _profile_generation_mode=false";
409     profile_generation_mode_ = false;
410   }
411   if (use_implicit_batch_) {
412     OP_REQUIRES(context, !profile_generation_mode_,
413                 errors::InvalidArgument(
414                     "profile_generation_mode_=true is only supported if "
415                     "use_implicit_batch=false"));
416     if (input_partial_shapes_.empty()) {
417       VLOG(1) << "Attribute input_shapes is not set. This happens probably "
418               << "because you are using a model that is already converted "
419               << "to TensorRT with a previous version of TF-TRT (i.e. includes "
420               << "TRTEngineOp in graph). This is not an error. If you convert "
421               << "the original model again to TensorRT, the attributes "
422               << "input_shapes will be set automatically.";
423     }
424   } else {
425     OP_REQUIRES(
426         context, !input_partial_shapes_.empty(),
427         errors::InvalidArgument(
428             "Explicit batch mode requires attribute input_shapes to be set."
429             "If you are using a model that was converted to TensorRT by a "
430             "previous version of TF-TRT, (i.e. includes TRTEngineOp in graph "
431             "without the input_shapes attribute), then you need to convert the "
432             "original model again to TensorRT in order to set the attribute "
433             "input_shapes."));
434     OP_REQUIRES(context, !calibration_mode_,
435                 errors::InvalidArgument(
436                     "Explicit batch mode does not support calibration"));
437   }
438   has_dynamic_shape_input_ = absl::c_any_of(
439       input_partial_shapes_,
440       [](PartialTensorShape shape) { return !shape.IsFullyDefined(); });
441   VLOG(2) << "TRTEngineOp has_dynamic_shape_input_: "
442           << has_dynamic_shape_input_;
443 }
444 
ExecuteNativeSegment(OpKernelContext * ctx,AsyncHelper * helper)445 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
446                                        AsyncHelper* helper) {
447   tensorflow::profiler::TraceMe activity(
448       "TRTEngineOp::ExecuteNativeSegment",
449       tensorflow::profiler::TraceMeLevel::kInfo);
450   std::vector<Tensor> inputs;
451   std::vector<Tensor>* outputs = new std::vector<Tensor>();
452   if (native_execution_func_handle_ == kInvalidHandle) {
453     StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
454         ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
455                                 allow_soft_placement_, ctx->num_inputs(),
456                                 ctx->num_outputs());
457     OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
458     native_execution_func_handle_ = status_or_handle.ValueOrDie();
459   }
460   auto lib = ctx->function_library();
461   FunctionLibraryRuntime::Options opts;
462   opts.rendezvous = ctx->rendezvous();
463   opts.cancellation_manager = ctx->cancellation_manager();
464   opts.runner = ctx->runner();
465   inputs.reserve(ctx->num_inputs());
466   for (int i = 0; i < ctx->num_inputs(); i++) {
467     inputs.push_back(ctx->input(i));
468   }
469   helper->Ref();  // Increment count for calculating native graph
470   VLOG(1) << "Executing native segment: " << name();
471   lib->Run(opts, native_execution_func_handle_, inputs, outputs,
472            [this, ctx, outputs, helper](const Status& s) {
473              core::ScopedUnref sc(helper);
474              std::unique_ptr<std::vector<Tensor>> outputs_wrapper(outputs);
475              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
476              VLOG(1) << "Native Segment completed";
477              for (size_t t = 0; t < outputs->size(); ++t) {
478                ctx->set_output(t, outputs->at(t));
479              }
480            });
481 }
482 
ExecuteCalibration(OpKernelContext * ctx,TRTEngineCacheResource * cache_res,AsyncHelper * helper)483 void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
484                                      TRTEngineCacheResource* cache_res,
485                                      AsyncHelper* helper) {
486   tensorflow::profiler::TraceMe activity(
487       "TRTEngineOp::ExecuteCalibration",
488       tensorflow::profiler::TraceMeLevel::kInfo);
489   VLOG(1) << "Executing TRT calibration: " << name();
490   helper->Ref();
491   core::ScopedUnref sc(helper);
492 
493   CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
494   const int num_inputs = ctx->num_inputs();
495   // TODO(laigd): need to check that input shape matches.
496   // Pass input data to calibrator
497   std::unordered_map<string, void*> input_data;
498   for (int i = 0; i < num_inputs; i++) {
499     const Tensor& t = ctx->input(i);
500     void* data_address = GetTensorAddress(&t);
501     OP_REQUIRES_ASYNC(ctx, data_address,
502                       errors::InvalidArgument(
503                           "Unsupported data type encountered in input ", i),
504                       *helper);
505     // Check the allocated buffer is sufficient for input
506     const auto device_tensor =
507         calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
508     CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
509     input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
510   }
511   VLOG(2) << "Filled map for sending";
512   // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
513   const cudaStream_t* stream = CHECK_NOTNULL(
514       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
515                                                 ->stream()
516                                                 ->implementation()
517                                                 ->GpuStreamMemberHack()));
518   // If calibrator is terminated before, it means an error has occurred.
519   //
520   // Note: setBatch() will wait until TRTInt8Calibrator::getBatch() is called
521   // the first time before proceeding, so if buildCudaEngine() returns an error,
522   // it means getBatch() is never called, and the setBatch() here will hang
523   // until setDone() is called later by the calibration thread in
524   // AllocateCalibrationResources(). In that case, this setBatch() will always
525   // be able to detect the error and return false.
526   OP_REQUIRES_ASYNC(ctx, calib_ctx->calibrator_->setBatch(input_data, *stream),
527                     errors::Internal("Failed to feed calibration data"),
528                     *helper);
529   VLOG(2) << "Passed calibration data";
530   ExecuteNativeSegment(ctx, helper);
531 }
532 
VerifyInputShapes(const std::vector<TensorShape> & input_concrete_shapes)533 Status TRTEngineOp::VerifyInputShapes(
534     const std::vector<TensorShape>& input_concrete_shapes) {
535   if (input_concrete_shapes.empty()) {
536     return errors::InvalidArgument("Input shapes are empty, for ", name());
537   }
538 
539   if (input_partial_shapes_.empty()) {
540     if (!use_implicit_batch_) {
541       return errors::InvalidArgument(
542           "Explicit batch mode requires input_partial_shapes_ ",
543           "to contain the dynamic input shapes to TRTEngineOp");
544     }
545     // If the graph was converted with an earlier version of TF-TRT, it can
546     // happen that the input_partial_shapes_ vector is not set (see
547     // input_shapes attribute handling in the TRTEngineOp constructor).
548     // In implicit batch mode it is allowed to have empty input_partial_shapes_,
549     // since it is only required in explicit batch mode (see the input_shapes
550     // attribute of ConvertGraphDefToEngine in TRTEngineOp::GetEngine.
551   } else {
552     // Additional consistency checks if input_partial_shapes_ is present.
553     const string error_msg = StrCat(
554         "Input shapes do not match input partial shapes stored in graph, for ",
555         name(), ": ", DebugString(input_concrete_shapes),
556         " != ", DebugString(input_partial_shapes_));
557     if (input_concrete_shapes.size() != input_partial_shapes_.size()) {
558       return errors::InvalidArgument(error_msg);
559     }
560     for (int i = 0; i < input_concrete_shapes.size(); i++) {
561       if (input_concrete_shapes[i].dims() != input_partial_shapes_[i].dims()) {
562         return errors::InvalidArgument(error_msg);
563       }
564     }
565     for (int i = 0; i < input_concrete_shapes.size(); i++) {
566       for (int d = 0; d < input_concrete_shapes[i].dims(); d++) {
567         if (input_partial_shapes_[i].dim_size(d) != -1) {
568           if (input_concrete_shapes[i].dim_size(d) !=
569               input_partial_shapes_[i].dim_size(d)) {
570             return errors::InvalidArgument(error_msg);
571           }
572         }
573       }
574     }
575   }
576 
577   if (use_implicit_batch_) {
578     if (input_concrete_shapes[0].dims() < 1) {
579       return errors::InvalidArgument(
580           "Input shapes contain scalar, for ", name(), ": ",
581           TensorShapeUtils::ShapeListString(input_concrete_shapes));
582     }
583     const int batch_size = input_concrete_shapes[0].dim_size(0);
584     if (batch_size < 1) {
585       return errors::InvalidArgument(
586           "Incorrect batch dimension, for ", name(), ": ",
587           TensorShapeUtils::ShapeListString(input_concrete_shapes));
588     }
589     for (const TensorShape& shape : input_concrete_shapes) {
590       if (batch_size != shape.dim_size(0)) {
591         return errors::InvalidArgument(
592             "Input shapes are inconsistent on the batch dimension, for ",
593             name(), ": ",
594             TensorShapeUtils::ShapeListString(input_concrete_shapes));
595       }
596     }
597   }
598   return Status::OK();
599 }
600 
AllowEngineNativeSegmentExecution()601 static bool AllowEngineNativeSegmentExecution() {
602   bool value;
603   Status status =
604       ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
605                          /*default_value=*/true, &value);
606   if (!status.ok()) {
607     LOG(ERROR) << status;
608   }
609   return value;
610 }
611 
ComputeAsync(OpKernelContext * ctx,AsyncOpKernel::DoneCallback done)612 void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
613                                AsyncOpKernel::DoneCallback done) {
614   tensorflow::profiler::TraceMe activity(
615       "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
616   auto helper = new AsyncHelper(done);
617   core::ScopedUnref sc(helper);
618 
619   // Get TRT resource.
620   TRTEngineCacheResource* cache_res = nullptr;
621   OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper);
622   core::ScopedUnref unref_cache_res(cache_res);
623 
624   // Run calibration if in int8+calibration mode.
625   // * Logic in TF 1.x:
626   //   - During conversion: calibration_mode_ is true and cache size is 0, so it
627   //     will run calibration.
628   //   - During inference: calibration_data will be set, so calibration_mode_ is
629   //     false and it won't trigger calibration.
630   // * Logic in TF 2.0:
631   //   - During conversion: similar to 1.x.
632   //   - During inference: calibration_data will still be empty, but cache will
633   //     contain the the calibrated engine, so it won't trigger calibration.
634   //
635   // TODO(laigd): consider the following alternatives:
636   // 1. Serialize the state (calibration or inference) using
637   //    TRTEngineInstance proto (or a new proto), so we know which mode we're
638   //    in and don't run calibration during inference (which is invalid).
639   // 2. Reuse the calibration_data attribute or use a new attribute in the
640   //    NodeDef to indicate whether it's in calibration mode.
641   if (calibration_mode_ && cache_res->cache_.size() == 0) {
642     if (!cache_res->calib_ctx_) {
643       // TODO(laigd): better encapsulation.
644       mutex_lock lock(engine_mutex_);
645       if (!cache_res->calib_ctx_) {
646         OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
647                              *helper);
648       }
649     }
650     // TODO(laigd): check that the input shapes match the shapes of the
651     // persistent tensor in the calibration resource.
652     ExecuteCalibration(ctx, cache_res, helper);
653     return;
654   }
655 
656   // Get shapes of inputs to engine.
657   std::vector<TensorShape> input_concrete_shapes;
658   input_concrete_shapes.reserve(ctx->num_inputs());
659   for (int i = 0; i < ctx->num_inputs(); ++i) {
660     input_concrete_shapes.push_back(ctx->input(i).shape());
661   }
662 
663   Status verify_input_shape_status = VerifyInputShapes(input_concrete_shapes);
664   // TODO(bixia): Fix the segmentation.
665   if (!verify_input_shape_status.ok()) {
666     LOG_FIRST_FEW_WARNING_WITH_PREFIX
667         << "Running native segment for" << name()
668         << " due to failure in verifying input shapes: "
669         << verify_input_shape_status.error_message();
670     ExecuteNativeSegment(ctx, helper);
671     return;
672   }
673 
674   if (!use_implicit_batch_ && has_dynamic_shape_input_) {
675     if (profile_generation_mode_) {
676       // Collecting new shapes for profiles can be only done once. After the
677       // shapes are converted to TRT profiles, no shapes can be collected
678       // anymore.
679       OP_REQUIRES(ctx, cache_res->profiles_.GetNumProfiles() == 0,
680                   errors::Unimplemented("Cannot collect new shapes when "
681                                         "profiles are already created."));
682       // Just collect the input shape info and return. The shapes are used to
683       // generate optimization profiles during engine creation.
684       cache_res->profiles_.AddShape(input_concrete_shapes);
685       VLOG(1) << "Native segment is used during collecting shapes for profiles";
686       ExecuteNativeSegment(ctx, helper);
687       return;
688     } else if (cache_res->profiles_.GetNumProfiles() == 0) {
689       // Add current shape if we did not collect any shapes so far.
690       if (!cache_res->profiles_.HasShape()) {
691         cache_res->profiles_.AddShape(input_concrete_shapes);
692       }
693       // Create profiles out of collected shapes during profile generation.
694       cache_res->profiles_.InitProfiles(input_partial_shapes_);
695     }
696   }
697   StatusOr<std::pair<EngineContext*, int>> status =
698       GetEngine(input_concrete_shapes, ctx, cache_res);
699   OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
700 
701   EngineContext* engine_context = status.ValueOrDie().first;
702   int trt_context_idx = status.ValueOrDie().second;
703   auto may_execute_native_segment = [&] {
704     if (!AllowEngineNativeSegmentExecution()) {
705       ctx->CtxFailure(
706           errors::Aborted("User disallowed engine native segment execution"));
707       return false;
708     }
709     return true;
710   };
711   if (!engine_context->cuda_engine) {
712     LOG_FIRST_FEW_WARNING_WITH_PREFIX
713         << "Engine retrieval for input shapes: "
714         << TensorShapeUtils::ShapeListString(input_concrete_shapes)
715         << " failed. Running native segment for " << name();
716     if (may_execute_native_segment()) {
717       ExecuteNativeSegment(ctx, helper);
718     }
719     return;
720   }
721   Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
722   if (!stat.ok()) {
723     LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
724                                       << " Retrying with native segment for "
725                                       << name();
726     if (!may_execute_native_segment()) {
727       return;
728     }
729     // Release any outputs that are allocated, ExecuteNativeSegment will
730     // re-allocate them and fail if they are currently allocated.
731     // The Tensor pointer in the returned TensorValue must be explicitly
732     // deleted.
733     for (int i = 0; i < ctx->num_outputs(); i++) {
734       delete ctx->release_output(i).tensor;
735     }
736     ExecuteNativeSegment(ctx, helper);
737     return;
738   }
739 }
740 
ExecuteTrtEngine(OpKernelContext * ctx,EngineContext * engine_context,int trt_context_idx)741 Status TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
742                                      EngineContext* engine_context,
743                                      int trt_context_idx) {
744   tensorflow::profiler::TraceMe activity(
745       "TRTEngineOp::ExecuteTrtEngine",
746       tensorflow::profiler::TraceMeLevel::kInfo);
747   VLOG(1) << "Executing TRT engine: " << name();
748   auto& cuda_engine = engine_context->cuda_engine;
749 
750   if (VLOG_IS_ON(2)) {
751 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
752     VLOG(2) << "  Network name: " << cuda_engine->getName();
753 #endif  // #if IS_TRT_VERSION_GE(6, 0, 0, 0)
754     VLOG(2) << "  Activation size: " << cuda_engine->getDeviceMemorySize()
755             << " bytes";
756     VLOG(2) << "  Workspace size: " << cuda_engine->getWorkspaceSize()
757             << " bytes";
758     VLOG(2) << "  Datatype of " << cuda_engine->getNbBindings()
759             << " inputs/outputs";
760     string binding_types = "";
761     for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
762       binding_types += "    " + string(cuda_engine->getBindingName(i)) + ": " +
763                        DebugString(cuda_engine->getBindingDataType(i)) + "\n";
764     }
765     VLOG(2) << binding_types;
766   }
767 
768   const int num_binding = cuda_engine->getNbBindings();
769   std::vector<void*> buffers(num_binding);
770 
771   // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex
772   // for it.
773   mutex_lock lock(engine_context->mu);
774   nvinfer1::IExecutionContext* execution_context;
775   TF_RETURN_IF_ERROR(
776       engine_context->GetExecutionContext(trt_context_idx, &execution_context));
777 
778   if (VLOG_IS_ON(2)) {
779     VLOG(2) << "Selected execution context: " << trt_context_idx;
780   }
781   const int num_batch =
782       use_implicit_batch_ ? ctx->input(0).shape().dim_size(0) : 0;
783 
784   TF_RETURN_IF_ERROR(SetTrtEngineInputs(cuda_engine.get(), execution_context,
785                                         trt_context_idx, buffers,
786                                         use_implicit_batch_, num_batch, ctx));
787 
788   TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine.get(), execution_context,
789                                          trt_context_idx, buffers,
790                                          use_implicit_batch_, num_batch, ctx));
791 
792   // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
793   const cudaStream_t* stream = CHECK_NOTNULL(
794       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
795                                                 ->stream()
796                                                 ->implementation()
797                                                 ->GpuStreamMemberHack()));
798 
799   TF_RETURN_IF_ERROR(TrtEnqueue(execution_context, buffers, *stream,
800                                 use_implicit_batch_, num_batch));
801   return Status::OK();
802 }
803 
GetEngineCacheResource(OpKernelContext * ctx,TRTEngineCacheResource ** cache_res)804 Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
805                                            TRTEngineCacheResource** cache_res) {
806   // Canonicalize the op name by removing the scopes if any. This is mainly
807   // because in TFv2, the function graph can be instantiated in various ways and
808   // it'll insert scope names to the name of the TRTEngineOps, which will result
809   // in many different engine caches if we use the instantiated op name
810   // directly, but we still want all of them share the same cache (if they were
811   // representing the same subgraph).
812   absl::string_view resource_name = name();
813   size_t last_slash = resource_name.find_last_of('/');
814   if (last_slash != absl::string_view::npos) {
815     resource_name.remove_prefix(last_slash + 1);
816   }
817 
818   // Get engine cache.
819   return ctx->resource_manager()->LookupOrCreate(
820       std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
821       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
822         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
823         return Status::OK();
824       }});
825 }
826 
BuildEngine(const std::vector<TensorShape> & input_concrete_shapes,int batch_size,bool use_calibration,TRTInt8Calibrator * calibrator,TRTEngineCacheResource * cache_resource)827 StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> TRTEngineOp::BuildEngine(
828     const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
829     bool use_calibration, TRTInt8Calibrator* calibrator,
830     TRTEngineCacheResource* cache_resource) {
831   VLOG(1) << "Building a new TensorRT engine for " << name()
832           << " with input shapes: "
833           << TensorShapeUtils::ShapeListString(input_concrete_shapes);
834 
835   // Use concrete shapes for implicit batch mode and partial shapes for
836   // explicit batch mode.
837   const std::vector<PartialTensorShape>& conversion_input_shapes =
838       use_implicit_batch_
839           ? std::vector<PartialTensorShape>(input_concrete_shapes.begin(),
840                                             input_concrete_shapes.end())
841           : input_partial_shapes_;
842   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
843   auto status = convert::ConvertGraphDefToEngine(
844       segment_graph_def_, precision_mode_, batch_size, workspace_size_,
845       conversion_input_shapes, &logger, cache_resource->allocator_.get(),
846       calibrator, &engine, use_calibration, use_implicit_batch_, nullptr,
847       &cache_resource->profiles_, name());
848   if (!status.ok()) {
849     LOG_FIRST_FEW_WARNING_WITH_PREFIX
850         << "Engine creation for " << name() << " failed. "
851         << "The native segment will be used instead. "
852         << "Reason: " << status;
853     // Store an empty engine in the cache for these input shapes so we don't try
854     // to build the same failing engine again.
855     cache_resource->cache_.emplace(input_concrete_shapes,
856                                    absl::make_unique<EngineContext>());
857     return status;
858   }
859   return engine;
860 }
861 
GetEngine(const std::vector<TensorShape> & input_concrete_shapes,OpKernelContext * ctx,TRTEngineCacheResource * cache_res)862 StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
863     const std::vector<TensorShape>& input_concrete_shapes, OpKernelContext* ctx,
864     TRTEngineCacheResource* cache_res) {
865   static EngineContext empty_context;
866 
867   mutex_lock lock(engine_mutex_);
868   // Using first input to get batch size is reliable - VerifyInputShapes()
869   // guarantees that the first input is not a scalar. As such we can always use
870   // the first input to get the batch size for implicit batch mode. For explicit
871   // batch mode, this value is not used.
872   const int batch_size = input_concrete_shapes[0].dim_size(0);
873   // TODO(Tamas): remove the need for batch_size in explicit_batch mode
874   auto& cache = cache_res->cache_;
875   auto allocator = cache_res->allocator_.get();
876   if (allocator == nullptr) {
877     return std::pair<EngineContext*, int>(&empty_context, 0);
878   }
879 
880   // Handle the static engine case. For static engines, the cache will have a
881   // single element containing the only engine.
882   if (static_engine_) {
883     if (cache.size()) {
884       // TODO(laigd): need a better shape compatibility check for the case where
885       // implicit batch is disabled.
886       if (!use_implicit_batch_ ||
887           AreShapesCompatible(input_concrete_shapes, cache.begin()->first)) {
888         return std::pair<EngineContext*, int>(cache.begin()->second.get(), 0);
889       }
890       return std::pair<EngineContext*, int>(&empty_context, 0);
891     }
892 
893     TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
894     infer->setGpuAllocator(allocator);
895     // Need to initialize plugins in order to deserialize engines that contain
896     // plugins.
897     MaybeInitializeTrtPlugins(&logger);
898     TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
899         infer->deserializeCudaEngine(serialized_segment_.c_str(),
900                                      serialized_segment_.size(), nullptr));
901     if (!static_engine) {
902       if (!allow_build_at_runtime_) {
903         // Store an empty engine in the cache so we don't try to load the same
904         // failing engine again.
905         cache.emplace(input_concrete_shapes,
906                       absl::make_unique<EngineContext>());
907         return std::pair<EngineContext*, int>(&empty_context, 0);
908       }
909       if (segment_graph_def_.node().empty()) {
910         Status status = ImportSegmentGraphDef(ctx->function_library(),
911                                               ctx->device()->name());
912         if (!status.ok()) {
913           LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
914                                             << name() << " failed. "
915                                             << "Reason: " << status;
916         }
917       }
918       auto result = BuildEngine(input_concrete_shapes, batch_size,
919                                 /*use_calibration=*/false,
920                                 /*calibrator=*/nullptr, cache_res);
921       if (!result.ok()) {
922         return std::pair<EngineContext*, int>(&empty_context, 0);
923       }
924       static_engine = std::move(result.ValueOrDie());
925     }
926     auto raw_static_engine = static_engine.get();
927     const auto max_batch_size = raw_static_engine->getMaxBatchSize();
928     // Static engine will have max_batch_size for batch size so that all inputs
929     // will map to this single engine.
930     std::vector<TensorShape> engine_input_shapes(input_concrete_shapes);
931     for (int i = 0; i < engine_input_shapes.size(); i++) {
932       engine_input_shapes[i].set_dim(0, max_batch_size);
933     }
934     auto exec_context_status =
935         ExecutionContext::Create(raw_static_engine, allocator);
936     if (!exec_context_status.ok()) {
937       return std::pair<EngineContext*, int>(&empty_context, 0);
938     }
939 
940     // TODO(laigd): here we assume engine_input_shapes matches the actual input
941     // shapes of the engine, we should verify that.
942     cache.emplace(engine_input_shapes,
943                   absl::make_unique<EngineContext>(
944                       std::move(static_engine),
945                       std::move(exec_context_status.ValueOrDie())));
946     // Runtime is safe to delete after engine creation
947     VLOG(1) << "Size of serialized TRT engine: "
948             << serialized_segment_.capacity();
949     string tmp;
950     // Swap with temporary empty string to deallocate the CPU memory.
951     serialized_segment_.swap(tmp);
952     if (use_implicit_batch_ && (max_batch_size < batch_size)) {
953       return std::pair<EngineContext*, int>(&empty_context, 0);
954     }
955     return std::pair<EngineContext*, int>(cache.at(engine_input_shapes).get(),
956                                           0);
957   }  // static_engine_
958 
959   int profile_id = -1;
960   if (!use_implicit_batch_) {
961     profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
962     // Since all profiles are already created at this point, finding no
963     // compatible profiles results in falling back to native TF.
964     if (profile_id == -1) {
965       return std::pair<EngineContext*, int>(&empty_context, 0);
966     }
967   }
968 
969   EngineContext* engine_contexts;
970   if (use_implicit_batch_) {
971     engine_contexts = cache_res->GetEngineContext(input_concrete_shapes);
972   } else {
973     engine_contexts = cache_res->GetEngineContext(profile_id);
974   }
975 
976   // If cache does not have a compatible engine then create a new engine.
977   if (engine_contexts == nullptr) {
978     if (!allow_build_at_runtime_) {
979       LOG_FIRST_FEW_WARNING_WITH_PREFIX
980           << "Found no engine in cache matching input shapes. "
981           << "Not building a new engine because "
982           << "allow_build_at_runtime=False. "
983           << "The native segment will be used instead.";
984       // Store an empty engine in the cache for these input shapes so we don't
985       // try to build the same failing engine again.
986       cache.emplace(input_concrete_shapes, absl::make_unique<EngineContext>());
987       return std::pair<EngineContext*, int>(&empty_context, 0);
988     }
989 
990     // Up to this point, calibrator_ can never be empty, since otherwise it
991     // means calibration_mode_ is true and this path won't get executed.
992     auto result = BuildEngine(input_concrete_shapes, batch_size,
993                               use_calibration_, calibrator_.get(), cache_res);
994     if (!result.ok()) {
995       return std::pair<EngineContext*, int>(&empty_context, 0);
996     }
997     TrtUniquePtrType<nvinfer1::ICudaEngine> engine =
998         std::move(result.ValueOrDie());
999     std::vector<ExecutionContext> exec_contexts;
1000     TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
1001         engine.get(), exec_contexts, allocator));
1002     cache.emplace(input_concrete_shapes,
1003                   absl::make_unique<EngineContext>(std::move(engine),
1004                                                    std::move(exec_contexts)));
1005     VLOG(1) << "Added new engine to cache of " << name()
1006             << ". Cache size: " << cache.size();
1007     engine_contexts = cache.at(input_concrete_shapes).get();
1008     // Query which profile of the new engine matches the actual input.
1009     profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1010   }
1011   return std::pair<EngineContext*, int>(engine_contexts,
1012                                         use_implicit_batch_ ? 0 : profile_id);
1013 }
1014 
1015 // TODO(hinsu): Move this allocation to CalibrationContext constructor, if
1016 // possible.
AllocateCalibrationResources(OpKernelContext * ctx,TRTEngineCacheResource * cache_res)1017 Status TRTEngineOp::AllocateCalibrationResources(
1018     OpKernelContext* ctx, TRTEngineCacheResource* cache_res) {
1019   cache_res->calib_ctx_ = absl::make_unique<CalibrationContext>();
1020   auto* cres = cache_res->calib_ctx_.get();
1021 
1022   // Get the input shapes.
1023   const int batch_size = ctx->input(0).dim_size(0);
1024   const int num_inputs = ctx->num_inputs();
1025   std::vector<TensorShape> shapes;
1026   cres->device_tensors_.resize(num_inputs);
1027   VLOG(1) << "Constructing calibrator";
1028   for (int i = 0; i < num_inputs; i++) {
1029     // allocate workspace on device for inputs
1030     const Tensor& t = ctx->input(i);
1031     shapes.emplace_back(t.shape());
1032     Tensor* device_tensor;
1033     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
1034         t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor));
1035     CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
1036     void* device_address = GetTensorAddress(device_tensor);
1037     if (device_address == nullptr) {
1038       return errors::InvalidArgument(
1039           "Unsupported data type encountered in input ", i);
1040     }
1041     cres->device_buffers_.emplace(
1042         StrCat(IONamePrefixes::kInputPHName, i),
1043         std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
1044   }
1045   cres->calibrator_.reset(
1046       new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
1047   const int platform_gpu_id =
1048       ctx->device()->tensorflow_gpu_device_info()->gpu_id;
1049   if (platform_gpu_id < 0) {
1050     LOG(ERROR) << "Can't get gpu_device_info from context->device()";
1051     return errors::InvalidArgument(
1052         "Context->device doesn't contain device info!");
1053   }
1054 
1055   cache_res->Ref();
1056   cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
1057                                     cache_res]() {
1058     core::ScopedUnref sc(cache_res);
1059 
1060     VLOG(1) << "Starting calibration thread on device " << platform_gpu_id
1061             << ", Calibration Resource @ " << cres;
1062     auto err = cudaSetDevice(platform_gpu_id);
1063     if (err != cudaSuccess) {
1064       // TODO(aaroey): should return error here.
1065       LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
1066                  << " in calibration thread";
1067     }
1068     std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
1069                                                    shapes.end());
1070     // ConvertGraphDefToEngine() will try to build the engine. This thread
1071     // will loop inside buildCudaEngine() consuming the calibration data
1072     // that is set by the TF op, and drive the builder until calibrator
1073     // returns false. Engine is discarded after calibration table is
1074     // generated
1075     //
1076     // TODO(aaroey): maybe setting the max batch size using the python
1077     // calibration wrapper class.
1078     auto s = convert::ConvertGraphDefToEngine(
1079         this->segment_graph_def_, TrtPrecisionMode::INT8,
1080         cres->calibrator_->getBatchSize(), this->workspace_size_,
1081         partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
1082         cres->calibrator_.get(), &cres->engine_, /*use_calibration=*/true,
1083         this->use_implicit_batch_, /*convert_successfully=*/nullptr,
1084         /*profiles=*/nullptr, name());
1085     if (!s.ok()) {
1086       LOG(ERROR) << "Calibration failed: " << s;
1087       cres->calibrator_->setDone();  // Ignore further pushes
1088     } else {
1089       // Transfer the ownership of the engine to the engine cache, so we can
1090       // dump it out during conversion for TF 2.0.
1091       mutex_lock lock(this->engine_mutex_);
1092       this->calibrator_ = std::move(cres->calibrator_);
1093       auto exec_context_status = ExecutionContext::Create(
1094           cres->engine_.get(), cache_res->allocator_.get());
1095       if (!exec_context_status.ok()) {
1096         LOG(ERROR) << "Calibration failed: " << s;
1097         cres->calibrator_->setDone();  // Ignore further pushes
1098       } else {
1099         cache_res->cache_.emplace(
1100             shapes, absl::make_unique<EngineContext>(
1101                         std::move(cres->engine_),
1102                         std::move(exec_context_status.ValueOrDie())));
1103       }
1104     }
1105 
1106     VLOG(1) << "Calibration loop terminated " << this->name();
1107   }));
1108   VLOG(1) << "initialized calibrator resource";
1109   return Status::OK();
1110 }
1111 
// Registers TRTEngineOp as the kernel implementation of the "TRTEngineOp" op
// on GPU devices.
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
1113 
1114 }  // namespace tensorrt
1115 }  // namespace tensorflow
1116 
1117 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
1118