/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;

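// Logs at most the first five occurrences per call site, each prefixed with
// "TF-TRT Warning: ", to avoid flooding the logs with repeated warnings.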
#define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
  LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "

// A helper class that calls done() when destructed, for asynchronous
// execution. Helps with simultaneous execution of native and TRT engines.
class AsyncHelper : public core::RefCounted {
 public:
  AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}

  ~AsyncHelper() override { this->operator()(); }

  void operator()() {
    if (!called_) {
      done_();
      called_ = true;
    }
  }

 private:
  AsyncOpKernel::DoneCallback done_;
  bool called_ = false;  // Has `done_` been called?
};

// This op builds a TRT engine on the fly when needed, and if engine
// construction fails, it executes the equivalent subgraph as a native
// TensorFlow function.
class TRTEngineOp : public AsyncOpKernel {
 public:
  explicit TRTEngineOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
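  // Maps a vector of concrete input shapes to the EngineContext (the engine
  // plus its execution contexts) built for those shapes.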
  using CacheType =
      LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
               VectorTensorShapeHasher>;

  // Executes calibration.
  void ExecuteCalibration(OpKernelContext* ctx,
                          TRTEngineCacheResource* cache_res,
                          AsyncHelper* helper);

  // Constructs a function handle for the segment of the TRTEngineOp.
  StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
      FunctionLibraryRuntime* lib, const string& device_name,
      bool allow_soft_placement = false, size_t num_inputs = 0,
      size_t num_outputs = 0);

  // Imports the GraphDef for the segment of the TRTEngineOp into
  // segment_graph_def_.
  Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
                               const string& device_name);

  // Executes the replaced native segment as a function op.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);

  // Executes the TensorRT engine. Returns a non-OK status if execution fails
  // and the caller should retry with the native segment.
  Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
                          int trt_context_idx);

  // Allocates necessary resources for calibration.
  Status AllocateCalibrationResources(OpKernelContext* ctx,
                                      TRTEngineCacheResource* cache_res);

  Status GetEngineCacheResource(OpKernelContext* ctx,
                                TRTEngineCacheResource** cache_res);

  // Returns a pair of 1) an EngineContext object that is compatible with the
  // input and 2) the index of the IExecutionContext compatible with the input.
  // If a cuda engine for the given input shapes can't be found, returns an
  // EngineContext with a null cuda engine (and index 0) to allow native
  // segment execution. Returns an error code for any problem that would
  // prevent both TensorRT engine execution and native segment execution.
  StatusOr<std::pair<EngineContext*, int>> GetEngine(
      const std::vector<TensorShape>& input_concrete_shapes,
      OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);

  // Builds and returns a cuda engine for the input shapes. If building the
  // engine fails, enters a dummy entry into the cache_resource cache so we
  // don't continually try to build the same failing engine.
  StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> BuildEngine(
      const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
      bool use_calibration, TRTInt8Calibrator* calibrator,
      TRTEngineCacheResource* cache_resource);

  // Verifies that the input shapes are consistent and can be handled by this
  // op.
  Status VerifyInputShapes(const std::vector<TensorShape>& shapes);

  std::vector<string> input_nodes_;
  std::vector<string> output_nodes_;

  // Serialized protobuf segment or TRT engine, depending on the static_engine_
  // flag.
  string serialized_segment_;

  // The function for TF native execution of the segment.
  NameAttrList func_;

  // GraphDef representation of the segment.
  GraphDef segment_graph_def_;

  // Engine precision mode.
  TrtPrecisionMode precision_mode_;

  // Whether the engine was constructed during the conversion or needs to be
  // constructed from the protobuf segment.
  bool static_engine_;

  // Whether to calibrate the INT8 engine.
  bool calibration_mode_;

  // Whether to use the implicit batch dimension for TensorRT.
  bool use_implicit_batch_;

  // Whether to collect optimization profiles for TensorRT; only used when
  // use_implicit_batch_=false.
  bool profile_generation_mode_;

  // Whether the TRTEngineOp has any input with unknown dimensions.
  bool has_dynamic_shape_input_;

  // Whether to build TensorRT engines at runtime.
  bool allow_build_at_runtime_;

  // Whether to allow soft placement when the graph is executed with native
  // TensorFlow.
  bool allow_soft_placement_;

  // Maximum number of cached engines.
  int max_cached_engines_;

  int64 workspace_size_;
  mutex engine_mutex_;
  FunctionLibraryRuntime::Handle native_execution_func_handle_;

  // The finalized calibrator for inference.
  std::unique_ptr<TRTInt8Calibrator> calibrator_;

  // If true, create a calibration graph for INT8 mode. Otherwise, use
  // user-provided quantization ranges.
  bool use_calibration_;

  // Array of all input shapes, collected from the input_shapes attribute when
  // constructing the TRTEngineOp. The input_shapes attribute is set at graph
  // conversion time. It is used to determine which input dimensions could be
  // unknown; at inference time all shapes are concrete, so this information is
  // not available otherwise.
  std::vector<PartialTensorShape> input_partial_shapes_;
};

#define TYPECASE(dt, X)                                       \
  case dt: {                                                  \
    return (void*)X->flat<EnumToDataType<dt>::Type>().data(); \
  }

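// Returns the address of the tensor's underlying data buffer, or nullptr if
// the data type is not supported.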
void* GetTensorAddress(const Tensor* tensor_ptr) {
  auto tensor_type = tensor_ptr->dtype();
  switch (tensor_type) {
    TYPECASE(DT_FLOAT, tensor_ptr);
    TYPECASE(DT_HALF, tensor_ptr);
    TYPECASE(DT_INT8, tensor_ptr);
    TYPECASE(DT_INT32, tensor_ptr);
    default: {
      LOG(ERROR) << "Unsupported data type " << DataTypeString(tensor_type);
      return nullptr;
    }
  }
}

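// Converts the function body identified by `handle` back into a GraphDef,
// restoring the original casing of the IO placeholder node names that function
// instantiation lower-cased.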
static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
                                    FunctionLibraryRuntime* flib_runtime,
                                    GraphDef* graph_def) {
  const FunctionLibraryDefinition* flib_def =
      flib_runtime->GetFunctionLibraryDefinition();
  const FunctionBody* fbody;
  fbody = flib_runtime->GetFunctionBody(handle);
  if (!fbody) {
    return errors::Internal(
        "Function body is null when converting from FuncDef to GraphDef.");
  }
  std::unique_ptr<Graph> graph(new Graph(flib_def));
  CopyGraph(*fbody->graph, graph.get());

  auto replace_name = [](const char* const prefix, string* name) {
    if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
      name->replace(0, strlen(prefix), prefix);
      return true;
    }
    return false;
  };
  graph->ToGraphDef(graph_def);
  // GraphToFunctionDef() will convert all the node names to lowercase.
  for (auto& node : *graph_def->mutable_node()) {
    if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
      if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
        // Instantiation of the function will append _RetVal to the node name;
        // remove it for backward compatibility.
        const char* const suffix_to_remove = "_RetVal";
        if (absl::EndsWith(node.name(), suffix_to_remove)) {
          node.mutable_name()->erase(node.name().size() -
                                     strlen(suffix_to_remove));
        }
      }
    }
    for (auto& input : *node.mutable_input()) {
      if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
        replace_name(IONamePrefixes::kOutputPHName, &input);
      }
    }
  }
  return Status::OK();
}

StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
    FunctionLibraryRuntime* lib, const string& device_name,
    bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
  VLOG(1) << "Constructing function handle";
  if (lib == nullptr) {
    return errors::Internal("Context function library is null");
  }
  FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.state_handle = "";
  inst_ops.target = device_name;
  if (allow_soft_placement) {
    const FunctionDef* fdef =
        lib->GetFunctionLibraryDefinition()->Find(func_.name());
    if (!fdef) {
      return errors::Internal(
          StrCat("Can't find FunctionDef for ", func_.name()));
    }
    bool ints_on_device =
        fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
        fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
    // kIntsOnDeviceAttr is not compatible with is_multi_device_function, which
    // is needed to support allow_soft_placement.
    if (ints_on_device) {
      LOG_FIRST_FEW_WARNING_WITH_PREFIX
          << "Function " << name()
          << " has attribute kIntsOnDeviceAttr=true "
             "and will be executed natively with allow_soft_placement=false. "
             "If this is a problem, please re-generate your SavedModel with "
             "the TF-TRT runtime you are using.";
    } else {
      inst_ops.is_multi_device_function = true;
      inst_ops.input_devices.resize(num_inputs, device_name);
      inst_ops.output_devices.resize(num_outputs, device_name);
      inst_ops.config_proto.set_allow_soft_placement(true);
    }
  }
  FunctionLibraryRuntime::Handle func_handle;
  Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
                                   inst_ops, &func_handle);
  if (status.ok()) {
    return func_handle;
  }
  return status;
}

Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
                                          const string& device_name) {
  TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
                      ConstructFunctionHandle(lib, device_name));
  return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
}

TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  // Read the serialized segment (a GraphDef, or a TRT engine when
  // static_engine_ is true).
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_segment", &serialized_segment_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("workspace_size_bytes", &workspace_size_));
  OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));

  VLOG(1) << "Constructing " << name();
  string precision_string;
  OP_REQUIRES_OK(context,
                 context->GetAttr("precision_mode", &precision_string));
  string calibration_data;
  OP_REQUIRES_OK(context,
                 context->GetAttr("calibration_data", &calibration_data));
  OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
  OP_REQUIRES(context, !func_.name().empty(),
              errors::InvalidArgument(
                  "The TF function for the TRT segment cannot be empty"));
  OP_REQUIRES_OK(context,
                 TrtPrecisionModeFromName(precision_string, &precision_mode_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("use_calibration", &use_calibration_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("input_shapes", &input_partial_shapes_));
  auto status =
      context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Attribute _allow_build_at_runtime not found in "
            << context->device()->name()
            << ", thus setting _allow_build_at_runtime=true";
    allow_build_at_runtime_ = true;
  } else {
    OP_REQUIRES_OK(context, status);
  }

  status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    allow_soft_placement_ = true;
  } else {
    OP_REQUIRES_OK(context, status);
  }

  native_execution_func_handle_ = kInvalidHandle;
  if (!static_engine_) {
    OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
                                                  context->device()->name()));
  }
  // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
  // backward compatibility reasons. Remove it once all known users switch to
  // 2.0.
  calibration_mode_ =
      (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
       calibration_data.empty());
  if (!calibration_data.empty()) {
    calibrator_.reset(new TRTInt8Calibrator(calibration_data));
    calibration_data.resize(0);
  }
  OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                           &max_cached_engines_));

  status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Attribute _use_implicit_batch not found in "
            << context->device()->name()
            << ", thus setting _use_implicit_batch=true";
    use_implicit_batch_ = true;
  }
#if !IS_TRT_VERSION_GE(6, 0, 0, 0)
  if (!use_implicit_batch_) {
    VLOG(2) << "Need at least TensorRT 6.0 for explicit batch mode. Setting "
            << "_use_implicit_batch=true";
    use_implicit_batch_ = true;
  }
#endif
  status =
      context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Attribute _profile_generation_mode not found in "
            << context->device()->name()
            << ", thus setting _profile_generation_mode=false";
    profile_generation_mode_ = false;
  }
  if (use_implicit_batch_) {
    OP_REQUIRES(context, !profile_generation_mode_,
                errors::InvalidArgument(
                    "profile_generation_mode_=true is only supported if "
                    "use_implicit_batch=false"));
    if (input_partial_shapes_.empty()) {
      VLOG(1) << "Attribute input_shapes is not set. This happens probably "
              << "because you are using a model that is already converted "
              << "to TensorRT with a previous version of TF-TRT (i.e. includes "
              << "TRTEngineOp in graph). This is not an error. If you convert "
              << "the original model again to TensorRT, the input_shapes "
              << "attribute will be set automatically.";
    }
  } else {
    OP_REQUIRES(
        context, !input_partial_shapes_.empty(),
        errors::InvalidArgument(
            "Explicit batch mode requires attribute input_shapes to be set. "
            "If you are using a model that was converted to TensorRT by a "
            "previous version of TF-TRT (i.e. it includes TRTEngineOp in the "
            "graph without the input_shapes attribute), then you need to "
            "convert the original model again to TensorRT in order to set "
            "the attribute input_shapes."));
    OP_REQUIRES(context, !calibration_mode_,
                errors::InvalidArgument(
                    "Explicit batch mode does not support calibration"));
  }
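  // Any input with an unknown dimension puts ComputeAsync() on the
  // dynamic-shape path, where optimization profiles are collected and used.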
  has_dynamic_shape_input_ = absl::c_any_of(
      input_partial_shapes_,
      [](PartialTensorShape shape) { return !shape.IsFullyDefined(); });
  VLOG(2) << "TRTEngineOp has_dynamic_shape_input_: "
          << has_dynamic_shape_input_;
}

void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                       AsyncHelper* helper) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ExecuteNativeSegment",
      tensorflow::profiler::TraceMeLevel::kInfo);
  std::vector<Tensor> inputs;
  std::vector<Tensor>* outputs = new std::vector<Tensor>();
  if (native_execution_func_handle_ == kInvalidHandle) {
    StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
                                allow_soft_placement_, ctx->num_inputs(),
                                ctx->num_outputs());
    OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
    native_execution_func_handle_ = status_or_handle.ValueOrDie();
  }
  auto lib = ctx->function_library();
  FunctionLibraryRuntime::Options opts;
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();
  inputs.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); i++) {
    inputs.push_back(ctx->input(i));
  }
  // Hold an extra reference on `helper` until the native segment finishes, so
  // that done() is not called before the callback below has run.
  helper->Ref();
  VLOG(1) << "Executing native segment: " << name();
  lib->Run(opts, native_execution_func_handle_, inputs, outputs,
           [this, ctx, outputs, helper](const Status& s) {
             core::ScopedUnref sc(helper);
             std::unique_ptr<std::vector<Tensor>> outputs_wrapper(outputs);
             OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
             VLOG(1) << "Native segment completed";
             for (size_t t = 0; t < outputs->size(); ++t) {
               ctx->set_output(t, outputs->at(t));
             }
           });
}

void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                     TRTEngineCacheResource* cache_res,
                                     AsyncHelper* helper) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ExecuteCalibration",
      tensorflow::profiler::TraceMeLevel::kInfo);
  VLOG(1) << "Executing TRT calibration: " << name();
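  // Ref()/ScopedUnref hold an extra reference on `helper` for the duration of
  // this function, so done() cannot fire before the calibration data below has
  // been handed to the calibrator.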
  helper->Ref();
  core::ScopedUnref sc(helper);

  CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
  const int num_inputs = ctx->num_inputs();
  // TODO(laigd): need to check that input shape matches.
  // Pass input data to the calibrator.
  std::unordered_map<string, void*> input_data;
  for (int i = 0; i < num_inputs; i++) {
    const Tensor& t = ctx->input(i);
    void* data_address = GetTensorAddress(&t);
    OP_REQUIRES_ASYNC(ctx, data_address,
                      errors::InvalidArgument(
                          "Unsupported data type encountered in input ", i),
                      *helper);
    // Check that the allocated buffer is sufficient for the input.
    const auto device_tensor =
        calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
  }
  VLOG(2) << "Filled map for sending";
  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  // If the calibrator was terminated before, it means an error has occurred.
  //
  // Note: setBatch() will wait until TRTInt8Calibrator::getBatch() is called
  // the first time before proceeding, so if buildCudaEngine() returns an
  // error, it means getBatch() is never called, and the setBatch() here will
  // hang until setDone() is called later by the calibration thread in
  // AllocateCalibrationResources(). In that case, this setBatch() will always
  // be able to detect the error and return false.
  OP_REQUIRES_ASYNC(ctx, calib_ctx->calibrator_->setBatch(input_data, *stream),
                    errors::Internal("Failed to feed calibration data"),
                    *helper);
  VLOG(2) << "Passed calibration data";
  ExecuteNativeSegment(ctx, helper);
}

Status TRTEngineOp::VerifyInputShapes(
    const std::vector<TensorShape>& input_concrete_shapes) {
  if (input_concrete_shapes.empty()) {
    return errors::InvalidArgument("Input shapes are empty, for ", name());
  }

  if (input_partial_shapes_.empty()) {
    if (!use_implicit_batch_) {
      return errors::InvalidArgument(
          "Explicit batch mode requires input_partial_shapes_ ",
          "to contain the dynamic input shapes to TRTEngineOp");
    }
    // If the graph was converted with an earlier version of TF-TRT, it can
    // happen that the input_partial_shapes_ vector is not set (see the
    // input_shapes attribute handling in the TRTEngineOp constructor).
    // In implicit batch mode it is allowed to have an empty
    // input_partial_shapes_, since it is only required in explicit batch mode
    // (see the input_shapes attribute of ConvertGraphDefToEngine in
    // TRTEngineOp::GetEngine).
  } else {
    // Additional consistency checks if input_partial_shapes_ is present.
    const string error_msg = StrCat(
        "Input shapes do not match input partial shapes stored in graph, for ",
        name(), ": ", DebugString(input_concrete_shapes),
        " != ", DebugString(input_partial_shapes_));
    if (input_concrete_shapes.size() != input_partial_shapes_.size()) {
      return errors::InvalidArgument(error_msg);
    }
    for (int i = 0; i < input_concrete_shapes.size(); i++) {
      if (input_concrete_shapes[i].dims() != input_partial_shapes_[i].dims()) {
        return errors::InvalidArgument(error_msg);
      }
    }
    for (int i = 0; i < input_concrete_shapes.size(); i++) {
      for (int d = 0; d < input_concrete_shapes[i].dims(); d++) {
        if (input_partial_shapes_[i].dim_size(d) != -1) {
          if (input_concrete_shapes[i].dim_size(d) !=
              input_partial_shapes_[i].dim_size(d)) {
            return errors::InvalidArgument(error_msg);
          }
        }
      }
    }
  }

  if (use_implicit_batch_) {
    if (input_concrete_shapes[0].dims() < 1) {
      return errors::InvalidArgument(
          "Input shapes contain scalar, for ", name(), ": ",
          TensorShapeUtils::ShapeListString(input_concrete_shapes));
    }
    const int batch_size = input_concrete_shapes[0].dim_size(0);
    if (batch_size < 1) {
      return errors::InvalidArgument(
          "Incorrect batch dimension, for ", name(), ": ",
          TensorShapeUtils::ShapeListString(input_concrete_shapes));
    }
    for (const TensorShape& shape : input_concrete_shapes) {
      if (batch_size != shape.dim_size(0)) {
        return errors::InvalidArgument(
            "Input shapes are inconsistent on the batch dimension, for ",
            name(), ": ",
            TensorShapeUtils::ShapeListString(input_concrete_shapes));
      }
    }
  }
  return Status::OK();
}

static bool AllowEngineNativeSegmentExecution() {
  bool value;
  Status status =
      ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
                         /*default_value=*/true, &value);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return value;
}

void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                               AsyncOpKernel::DoneCallback done) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
  auto helper = new AsyncHelper(done);
  core::ScopedUnref sc(helper);

  // Get the TRT resource.
  TRTEngineCacheResource* cache_res = nullptr;
  OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper);
  core::ScopedUnref unref_cache_res(cache_res);

  // Run calibration if in int8+calibration mode.
  // * Logic in TF 1.x:
  //   - During conversion: calibration_mode_ is true and the cache size is 0,
  //     so it will run calibration.
  //   - During inference: calibration_data will be set, so calibration_mode_
  //     is false and it won't trigger calibration.
  // * Logic in TF 2.0:
  //   - During conversion: similar to 1.x.
  //   - During inference: calibration_data will still be empty, but the cache
  //     will contain the calibrated engine, so it won't trigger calibration.
  //
  // TODO(laigd): consider the following alternatives:
  // 1. Serialize the state (calibration or inference) using
  //    TRTEngineInstance proto (or a new proto), so we know which mode we're
  //    in and don't run calibration during inference (which is invalid).
  // 2. Reuse the calibration_data attribute or use a new attribute in the
  //    NodeDef to indicate whether it's in calibration mode.
  if (calibration_mode_ && cache_res->cache_.size() == 0) {
    if (!cache_res->calib_ctx_) {
      // TODO(laigd): better encapsulation.
      mutex_lock lock(engine_mutex_);
      if (!cache_res->calib_ctx_) {
        OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
                             *helper);
      }
    }
    // TODO(laigd): check that the input shapes match the shapes of the
    // persistent tensor in the calibration resource.
    ExecuteCalibration(ctx, cache_res, helper);
    return;
  }

  // Get the shapes of the inputs to the engine.
  std::vector<TensorShape> input_concrete_shapes;
  input_concrete_shapes.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    input_concrete_shapes.push_back(ctx->input(i).shape());
  }

  Status verify_input_shape_status = VerifyInputShapes(input_concrete_shapes);
  // TODO(bixia): Fix the segmentation.
  if (!verify_input_shape_status.ok()) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX
        << "Running native segment for " << name()
        << " due to failure in verifying input shapes: "
        << verify_input_shape_status.error_message();
    ExecuteNativeSegment(ctx, helper);
    return;
  }

  if (!use_implicit_batch_ && has_dynamic_shape_input_) {
    if (profile_generation_mode_) {
      // Collecting new shapes for profiles can only be done once. After the
      // shapes are converted to TRT profiles, no shapes can be collected
      // anymore.
      OP_REQUIRES(ctx, cache_res->profiles_.GetNumProfiles() == 0,
                  errors::Unimplemented("Cannot collect new shapes when "
                                        "profiles are already created."));
      // Just collect the input shape info and return. The shapes are used to
      // generate optimization profiles during engine creation.
      cache_res->profiles_.AddShape(input_concrete_shapes);
      VLOG(1) << "Native segment is used while collecting shapes for profiles";
      ExecuteNativeSegment(ctx, helper);
      return;
    } else if (cache_res->profiles_.GetNumProfiles() == 0) {
      // Add the current shape if we did not collect any shapes so far.
      if (!cache_res->profiles_.HasShape()) {
        cache_res->profiles_.AddShape(input_concrete_shapes);
      }
      // Create profiles out of the shapes collected during profile generation.
      cache_res->profiles_.InitProfiles(input_partial_shapes_);
    }
  }
  StatusOr<std::pair<EngineContext*, int>> status =
      GetEngine(input_concrete_shapes, ctx, cache_res);
  OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);

  EngineContext* engine_context = status.ValueOrDie().first;
  int trt_context_idx = status.ValueOrDie().second;
  auto may_execute_native_segment = [&] {
    if (!AllowEngineNativeSegmentExecution()) {
      ctx->CtxFailure(
          errors::Aborted("User disallowed engine native segment execution"));
      return false;
    }
    return true;
  };
  if (!engine_context->cuda_engine) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX
        << "Engine retrieval for input shapes: "
        << TensorShapeUtils::ShapeListString(input_concrete_shapes)
        << " failed. Running native segment for " << name();
    if (may_execute_native_segment()) {
      ExecuteNativeSegment(ctx, helper);
    }
    return;
  }
  Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
  if (!stat.ok()) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
                                      << ". Retrying with native segment for "
                                      << name();
    if (!may_execute_native_segment()) {
      return;
    }
    // Release any outputs that are already allocated; ExecuteNativeSegment
    // will re-allocate them and would fail if they were still allocated.
    // The Tensor pointer in the returned TensorValue must be explicitly
    // deleted.
    for (int i = 0; i < ctx->num_outputs(); i++) {
      delete ctx->release_output(i).tensor;
    }
    ExecuteNativeSegment(ctx, helper);
    return;
  }
}

Status TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
                                     EngineContext* engine_context,
                                     int trt_context_idx) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ExecuteTrtEngine",
      tensorflow::profiler::TraceMeLevel::kInfo);
  VLOG(1) << "Executing TRT engine: " << name();
  auto& cuda_engine = engine_context->cuda_engine;

  if (VLOG_IS_ON(2)) {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
    VLOG(2) << " Network name: " << cuda_engine->getName();
#endif  // #if IS_TRT_VERSION_GE(6, 0, 0, 0)
    VLOG(2) << " Activation size: " << cuda_engine->getDeviceMemorySize()
            << " bytes";
    VLOG(2) << " Workspace size: " << cuda_engine->getWorkspaceSize()
            << " bytes";
    VLOG(2) << " Datatype of " << cuda_engine->getNbBindings()
            << " inputs/outputs";
    string binding_types = "";
    for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
      binding_types += " " + string(cuda_engine->getBindingName(i)) + ": " +
                       DebugString(cuda_engine->getBindingDataType(i)) + "\n";
    }
    VLOG(2) << binding_types;
  }

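  // One device pointer per engine binding; the entries are filled in by
  // SetTrtEngineInputs()/SetTrtEngineOutputs() below and consumed by
  // TrtEnqueue().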
  const int num_binding = cuda_engine->getNbBindings();
  std::vector<void*> buffers(num_binding);

  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a
  // mutex for it.
  mutex_lock lock(engine_context->mu);
  nvinfer1::IExecutionContext* execution_context;
  TF_RETURN_IF_ERROR(
      engine_context->GetExecutionContext(trt_context_idx, &execution_context));

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Selected execution context: " << trt_context_idx;
  }
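  // In implicit batch mode the batch size comes from the leading dimension of
  // the first input; in explicit batch mode it is unused (the shapes are set
  // on the execution context instead), so 0 is passed.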
  const int num_batch =
      use_implicit_batch_ ? ctx->input(0).shape().dim_size(0) : 0;

  TF_RETURN_IF_ERROR(SetTrtEngineInputs(cuda_engine.get(), execution_context,
                                        trt_context_idx, buffers,
                                        use_implicit_batch_, num_batch, ctx));

  TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine.get(), execution_context,
                                         trt_context_idx, buffers,
                                         use_implicit_batch_, num_batch, ctx));

  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));

  TF_RETURN_IF_ERROR(TrtEnqueue(execution_context, buffers, *stream,
                                use_implicit_batch_, num_batch));
  return Status::OK();
}

Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
                                           TRTEngineCacheResource** cache_res) {
  // Canonicalize the op name by removing the scopes if any. This is mainly
  // because in TFv2 the function graph can be instantiated in various ways,
  // and it will insert scope names into the names of the TRTEngineOps. Using
  // the instantiated op name directly would therefore result in many different
  // engine caches, but we still want all ops representing the same subgraph to
  // share the same cache.
  absl::string_view resource_name = name();
  size_t last_slash = resource_name.find_last_of('/');
  if (last_slash != absl::string_view::npos) {
    resource_name.remove_prefix(last_slash + 1);
  }

  // Get the engine cache.
  return ctx->resource_manager()->LookupOrCreate(
      std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
      {[this, ctx](TRTEngineCacheResource** cr) -> Status {
        *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
        return Status::OK();
      }});
}

StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> TRTEngineOp::BuildEngine(
    const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
    bool use_calibration, TRTInt8Calibrator* calibrator,
    TRTEngineCacheResource* cache_resource) {
  VLOG(1) << "Building a new TensorRT engine for " << name()
          << " with input shapes: "
          << TensorShapeUtils::ShapeListString(input_concrete_shapes);

  // Use concrete shapes for implicit batch mode and partial shapes for
  // explicit batch mode.
  const std::vector<PartialTensorShape>& conversion_input_shapes =
      use_implicit_batch_
          ? std::vector<PartialTensorShape>(input_concrete_shapes.begin(),
                                            input_concrete_shapes.end())
          : input_partial_shapes_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
  auto status = convert::ConvertGraphDefToEngine(
      segment_graph_def_, precision_mode_, batch_size, workspace_size_,
      conversion_input_shapes, &logger, cache_resource->allocator_.get(),
      calibrator, &engine, use_calibration, use_implicit_batch_, nullptr,
      &cache_resource->profiles_, name());
  if (!status.ok()) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX
        << "Engine creation for " << name() << " failed. "
        << "The native segment will be used instead. "
        << "Reason: " << status;
    // Store an empty engine in the cache for these input shapes so we don't
    // try to build the same failing engine again.
    cache_resource->cache_.emplace(input_concrete_shapes,
                                   absl::make_unique<EngineContext>());
    return status;
  }
  return engine;
}

StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
    const std::vector<TensorShape>& input_concrete_shapes, OpKernelContext* ctx,
    TRTEngineCacheResource* cache_res) {
  static EngineContext empty_context;
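  // Sentinel returned when no usable engine is available; its cuda_engine is
  // null, which tells ComputeAsync() to fall back to the native segment.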

  mutex_lock lock(engine_mutex_);
  // Using the first input to get the batch size is reliable:
  // VerifyInputShapes() guarantees that the first input is not a scalar, so we
  // can always read the implicit-batch batch size from it. For explicit batch
  // mode, this value is not used.
  const int batch_size = input_concrete_shapes[0].dim_size(0);
  // TODO(Tamas): remove the need for batch_size in explicit_batch mode
  auto& cache = cache_res->cache_;
  auto allocator = cache_res->allocator_.get();
  if (allocator == nullptr) {
    return std::pair<EngineContext*, int>(&empty_context, 0);
  }

  // Handle the static engine case. For static engines, the cache will have a
  // single element containing the only engine.
  if (static_engine_) {
    if (cache.size()) {
      // TODO(laigd): need a better shape compatibility check for the case
      // where implicit batch is disabled.
      if (!use_implicit_batch_ ||
          AreShapesCompatible(input_concrete_shapes, cache.begin()->first)) {
        return std::pair<EngineContext*, int>(cache.begin()->second.get(), 0);
      }
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }

    TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
    infer->setGpuAllocator(allocator);
    // Need to initialize plugins in order to deserialize engines that contain
    // plugins.
    MaybeInitializeTrtPlugins(&logger);
    TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
        infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                     serialized_segment_.size(), nullptr));
    if (!static_engine) {
      if (!allow_build_at_runtime_) {
        // Store an empty engine in the cache so we don't try to load the same
        // failing engine again.
        cache.emplace(input_concrete_shapes,
                      absl::make_unique<EngineContext>());
        return std::pair<EngineContext*, int>(&empty_context, 0);
      }
      if (segment_graph_def_.node().empty()) {
        Status status = ImportSegmentGraphDef(ctx->function_library(),
                                              ctx->device()->name());
        if (!status.ok()) {
          LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
                                            << name() << " failed. "
                                            << "Reason: " << status;
        }
      }
      auto result = BuildEngine(input_concrete_shapes, batch_size,
                                /*use_calibration=*/false,
                                /*calibrator=*/nullptr, cache_res);
      if (!result.ok()) {
        return std::pair<EngineContext*, int>(&empty_context, 0);
      }
      static_engine = std::move(result.ValueOrDie());
    }
    auto raw_static_engine = static_engine.get();
    const auto max_batch_size = raw_static_engine->getMaxBatchSize();
    // The static engine is cached with max_batch_size as the batch dimension,
    // so that all compatible inputs map to this single engine.
    std::vector<TensorShape> engine_input_shapes(input_concrete_shapes);
    for (int i = 0; i < engine_input_shapes.size(); i++) {
      engine_input_shapes[i].set_dim(0, max_batch_size);
    }
    auto exec_context_status =
        ExecutionContext::Create(raw_static_engine, allocator);
    if (!exec_context_status.ok()) {
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }

    // TODO(laigd): here we assume engine_input_shapes matches the actual input
    // shapes of the engine; we should verify that.
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(
                      std::move(static_engine),
                      std::move(exec_context_status.ValueOrDie())));
    // The runtime is safe to delete after engine creation.
    VLOG(1) << "Size of serialized TRT engine: "
            << serialized_segment_.capacity();
    string tmp;
    // Swap with a temporary empty string to deallocate the CPU memory.
    serialized_segment_.swap(tmp);
    if (use_implicit_batch_ && (max_batch_size < batch_size)) {
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }
    return std::pair<EngineContext*, int>(cache.at(engine_input_shapes).get(),
                                          0);
  }  // static_engine_

  int profile_id = -1;
  if (!use_implicit_batch_) {
    profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
    // Since all profiles are already created at this point, finding no
    // compatible profile results in falling back to native TF.
    if (profile_id == -1) {
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }
  }

  EngineContext* engine_contexts;
  if (use_implicit_batch_) {
    engine_contexts = cache_res->GetEngineContext(input_concrete_shapes);
  } else {
    engine_contexts = cache_res->GetEngineContext(profile_id);
  }

  // If the cache does not have a compatible engine, then create a new engine.
  if (engine_contexts == nullptr) {
    if (!allow_build_at_runtime_) {
      LOG_FIRST_FEW_WARNING_WITH_PREFIX
          << "Found no engine in cache matching input shapes. "
          << "Not building a new engine because "
          << "allow_build_at_runtime=False. "
          << "The native segment will be used instead.";
      // Store an empty engine in the cache for these input shapes so we don't
      // try to build the same failing engine again.
      cache.emplace(input_concrete_shapes, absl::make_unique<EngineContext>());
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }

    // Up to this point, calibrator_ can never be empty, since otherwise it
    // means calibration_mode_ is true and this path won't get executed.
    auto result = BuildEngine(input_concrete_shapes, batch_size,
                              use_calibration_, calibrator_.get(), cache_res);
    if (!result.ok()) {
      return std::pair<EngineContext*, int>(&empty_context, 0);
    }
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine =
        std::move(result.ValueOrDie());
    std::vector<ExecutionContext> exec_contexts;
    TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
        engine.get(), exec_contexts, allocator));
    cache.emplace(input_concrete_shapes,
                  absl::make_unique<EngineContext>(std::move(engine),
                                                   std::move(exec_contexts)));
    VLOG(1) << "Added new engine to cache of " << name()
            << ". Cache size: " << cache.size();
    engine_contexts = cache.at(input_concrete_shapes).get();
    // Query which profile of the new engine matches the actual input.
    profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
  }
  return std::pair<EngineContext*, int>(engine_contexts,
                                        use_implicit_batch_ ? 0 : profile_id);
}

// TODO(hinsu): Move this allocation to the CalibrationContext constructor, if
// possible.
Status TRTEngineOp::AllocateCalibrationResources(
    OpKernelContext* ctx, TRTEngineCacheResource* cache_res) {
  cache_res->calib_ctx_ = absl::make_unique<CalibrationContext>();
  auto* cres = cache_res->calib_ctx_.get();

  // Get the input shapes.
  const int batch_size = ctx->input(0).dim_size(0);
  const int num_inputs = ctx->num_inputs();
  std::vector<TensorShape> shapes;
  cres->device_tensors_.resize(num_inputs);
  VLOG(1) << "Constructing calibrator";
  for (int i = 0; i < num_inputs; i++) {
    // Allocate workspace on device for the inputs.
    const Tensor& t = ctx->input(i);
    shapes.emplace_back(t.shape());
    Tensor* device_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor));
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    void* device_address = GetTensorAddress(device_tensor);
    if (device_address == nullptr) {
      return errors::InvalidArgument(
          "Unsupported data type encountered in input ", i);
    }
    cres->device_buffers_.emplace(
        StrCat(IONamePrefixes::kInputPHName, i),
        std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
  }
  cres->calibrator_.reset(
      new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
  const int platform_gpu_id =
      ctx->device()->tensorflow_gpu_device_info()->gpu_id;
  if (platform_gpu_id < 0) {
    LOG(ERROR) << "Can't get gpu_device_info from context->device()";
    return errors::InvalidArgument(
        "Context->device doesn't contain device info!");
  }

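  // The calibration itself runs on a dedicated thread: the builder call in
  // ConvertGraphDefToEngine() below blocks while the calibrator's getBatch()
  // waits for data fed by ExecuteCalibration() via setBatch().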
  cache_res->Ref();
  cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
                                    cache_res]() {
    core::ScopedUnref sc(cache_res);

    VLOG(1) << "Starting calibration thread on device " << platform_gpu_id
            << ", Calibration Resource @ " << cres;
    auto err = cudaSetDevice(platform_gpu_id);
    if (err != cudaSuccess) {
      // TODO(aaroey): should return error here.
      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                 << " in calibration thread";
    }
    std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
                                                   shapes.end());
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine() consuming the calibration data
    // that is set by the TF op, and drive the builder until the calibrator
    // returns false. The engine is discarded after the calibration table is
    // generated.
    //
    // TODO(aaroey): maybe set the max batch size using the python
    // calibration wrapper class.
    auto s = convert::ConvertGraphDefToEngine(
        this->segment_graph_def_, TrtPrecisionMode::INT8,
        cres->calibrator_->getBatchSize(), this->workspace_size_,
        partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
        cres->calibrator_.get(), &cres->engine_, /*use_calibration=*/true,
        this->use_implicit_batch_, /*convert_successfully=*/nullptr,
        /*profiles=*/nullptr, name());
    if (!s.ok()) {
      LOG(ERROR) << "Calibration failed: " << s;
      cres->calibrator_->setDone();  // Ignore further pushes.
    } else {
      // Transfer the ownership of the engine to the engine cache, so we can
      // dump it out during conversion for TF 2.0.
      mutex_lock lock(this->engine_mutex_);
      this->calibrator_ = std::move(cres->calibrator_);
      auto exec_context_status = ExecutionContext::Create(
          cres->engine_.get(), cache_res->allocator_.get());
      if (!exec_context_status.ok()) {
        LOG(ERROR) << "Calibration failed: " << exec_context_status.status();
        this->calibrator_->setDone();  // Ignore further pushes.
      } else {
        cache_res->cache_.emplace(
            shapes, absl::make_unique<EngineContext>(
                        std::move(cres->engine_),
                        std::move(exec_context_status.ValueOrDie())));
      }
    }

    VLOG(1) << "Calibration loop terminated " << this->name();
  }));
  VLOG(1) << "Initialized calibrator resource";
  return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT