/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"

#include <algorithm>
#include <functional>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/platform/stream_executor.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "third_party/gpus/cuda/include/cuda_runtime_api.h"

namespace tensorflow {
namespace tensorrt {

// Returns a vector of nvinfer1::Dims for a vector of TensorShapes.
template <typename TensorShapeType>
std::vector<nvinfer1::Dims> GetDimVec(std::vector<TensorShapeType> shape_vec) {
  std::vector<nvinfer1::Dims> dimvec(shape_vec.size());
  absl::c_transform(shape_vec, dimvec.begin(), [](TensorShapeType shape) {
    nvinfer1::Dims dims;
    TF_CHECK_OK(TensorShapeToTrtDims(shape, false, &dims));
    return dims;
  });
  return dimvec;
}
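
// Usage sketch (illustrative only, not part of the build); the exact d[]
// values assume TensorShapeToTrtDims keeps all dims (use_implicit_batch is
// passed as false above):
//
//   std::vector<TensorShape> shapes = {TensorShape({2, 3}), TensorShape({4})};
//   std::vector<nvinfer1::Dims> dims = GetDimVec(shapes);
//   // dims[0]: nbDims=2, d={2,3};  dims[1]: nbDims=1, d={4}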

// In dynamic shape mode the optimization profile dims are only allowed to
// differ from the network input dims where the network input dims have -1
// values. We enforce this condition by changing prof_dims if necessary.
void EnforceCompatibility(nvinfer1::Dims* prof_dims,
                          const PartialTensorShape& input_shape) {
  for (int i = 0; i < input_shape.dims(); i++) {
    if (input_shape.dim_size(i) != -1) {
      prof_dims->d[i] = input_shape.dim_size(i);
    }
  }
}
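
// Example (illustrative): for a network input with partial shape [-1, 28],
// a profile dim vector {4, 32} may only differ in the wildcard (-1) position,
// so EnforceCompatibility rewrites it to {4, 28}.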

void SetImplicitBatchModeCompatibleProfile(
    const std::vector<nvinfer1::Dims>& dimvec, std::vector<nvinfer1::Dims>* min,
    std::vector<nvinfer1::Dims>* opt, std::vector<nvinfer1::Dims>* max) {
  *min = dimvec;
  for (auto& dim : *min) {
    // Shape value tensors can have a -1 value as a wildcard. We do not change
    // the batch dim in that case.
    if (dim.d[0] != -1) dim.d[0] = 1;  // Set min batch size to 1.
  }
  *opt = dimvec;
  *max = dimvec;
}
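
// Example (illustrative): dimvec = {8, 224, 224} yields min = {1, 224, 224}
// and opt = max = {8, 224, 224}: a profile optimal for the observed batch
// size while remaining usable down to batch size 1.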

void TrtShapeOptimizationProfile::ImplicitBatchModeCompatibleStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  for (auto& shape_vec : collected_shapes) {
    std::vector<nvinfer1::Dims> min, opt, max;
    SetImplicitBatchModeCompatibleProfile(shape_vec, &min, &opt, &max);
    VLOG(2) << "Initializing optimization profile config with min="
            << DebugString(min) << ", opt=max=" << DebugString(max);
    OptimizationProfileConfig profConfig{min, opt, max};
    profiles_.push_back(std::move(profConfig));
  }
}

// Applies a binary operation for each dimension of the input shapes.
// x[i].d[k] = op(x[i].d[k], y[i].d[k]), where i enumerates the input tensors,
// and k enumerates the dimensions of the tensors. The BinaryOperation may be
// std::min, std::max etc.
template <typename BinaryOperation>
Status ShapeProfileBinaryOp(std::vector<nvinfer1::Dims>* x,
                            const std::vector<nvinfer1::Dims>& y,
                            BinaryOperation op) {
  if (x->size() != y.size())
    return errors::InvalidArgument(
        "Number of input tensors differ during profile creation");
  for (int i = 0; i < x->size(); i++) {
    if (x->at(i).nbDims != y[i].nbDims)
      return errors::InvalidArgument(
          "Number of input dimensions differ during profile creation");
    for (int j = 0; j < x->at(i).nbDims; j++) {
      x->at(i).d[j] = op(x->at(i).d[j], y[i].d[j]);
    }
  }
  return Status::OK();
}
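
// Example (illustrative): with x = {(4, 8)} and y = {(6, 2)}, applying
// op = min gives x = {(4, 2)}, the element-wise minimum of the two shape
// lists; op = max would give {(6, 8)}.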

Status TrtShapeOptimizationProfile::RangeStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  if (collected_shapes.empty()) return Status::OK();

  std::vector<nvinfer1::Dims> min = collected_shapes[0];
  std::vector<nvinfer1::Dims> max = min;

  for (int i = 1; i < collected_shapes.size(); i++) {
    TF_RETURN_IF_ERROR(
        ShapeProfileBinaryOp(&min, collected_shapes[i],
                             [](int a, int b) { return std::min(a, b); }));
    TF_RETURN_IF_ERROR(
        ShapeProfileBinaryOp(&max, collected_shapes[i],
                             [](int a, int b) { return std::max(a, b); }));
  }
  VLOG(2) << "Initializing optimization profile config with min="
          << DebugString(min) << ", opt=max=" << DebugString(max);
  OptimizationProfileConfig profConfig{min, max, max};
  profiles_.push_back(std::move(profConfig));
  return Status::OK();
}
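
// Example (illustrative): collected_shapes = {{(1, 10)}, {(4, 6)}} produces a
// single profile with min = (1, 6) and opt = max = (4, 10): one profile that
// covers the whole observed range and is optimized for the maximum shape.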

void TrtShapeOptimizationProfile::OptimalStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  for (auto& shape_vec : collected_shapes) {
    std::vector<nvinfer1::Dims> min = shape_vec;
    std::vector<nvinfer1::Dims> opt = min;
    std::vector<nvinfer1::Dims> max = min;
    VLOG(2) << "Initializing optimization profile config with min=opt=max="
            << DebugString(min);
    OptimizationProfileConfig profConfig{min, opt, max};
    profiles_.push_back(std::move(profConfig));
  }
}
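
// Example (illustrative): each distinct collected shape becomes its own
// profile with min = opt = max, so every observed shape runs under a profile
// tuned exactly for it, at the cost of one engine profile per unique shape.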

// Collects the values of tensors that are shape-tensor compatible. The values
// are stored in the actual_shape_values_ member variable.
Status TrtShapeOptimizationProfile::CollectShapeValues(OpKernelContext* ctx) {
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  actual_shape_values_.resize(ctx->num_inputs());
  if (is_shape_tensor_.empty()) {
    is_shape_tensor_.resize(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); i++) {
      is_shape_tensor_[i] = IsTrtShapeTensorCompatible(ctx->input(i));
    }
  }
  int n_shape_val = 0;
  // First copy all the shape value candidates into the actual_shape_values_
  // vector.
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (is_shape_tensor_[i]) {
      // We have to copy the shape values to the host, because TRT's
      // ExecutionContext::setInputShapeBinding expects a host pointer.
      n_shape_val++;
      const Tensor& input = ctx->input(i);
      actual_shape_values_[i].nbDims = input.NumElements();
      auto ret = cudaMemcpyAsync(
          actual_shape_values_[i].d, input.flat<int32>().data(),
          input.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost, *stream);
      if (ret != 0) {
        return errors::Internal("Could not copy shape tensor values");
      }
      VLOG(2) << "Input " << i << " is (probably) a shape tensor, n_values="
              << input.NumElements();
    } else {
      actual_shape_values_[i] = {0, {}};
    }
  }
  if (n_shape_val > 0) {
    // If we have any shape value candidates, then wait until the data is
    // copied to the host.
    cudaStreamSynchronize(*stream);
  }
  return Status::OK();
}
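
// Note (illustrative): after this call, actual_shape_values_[i] holds a
// host-side copy of input i's values if it looks like a shape tensor, e.g.
// an int32 input holding {1, 28, 28} is stored as nbDims=3, d={1, 28, 28};
// all other inputs are stored as empty dims {0, {}}.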

// Collects the values of tensors that are shape-tensor compatible. To be used
// for unit tests.
Status TrtShapeOptimizationProfile::CollectShapeValues(const DataVec& input) {
  actual_shape_values_.resize(input.size());
  for (int i = 0; i < input.size(); i++) {
    if (is_shape_tensor_[i]) {
      if (!IsTrtShapeTensorCompatible(input[i].tensor)) {
        return errors::Internal("Inconsistent shape tensor ", input[i].name,
                                ", ", i);
      }
      int n_elements = input[i].tensor.NumElements();
      actual_shape_values_[i].nbDims = n_elements;
      // During unit tests, the data is in unified memory.
      std::copy(input[i].tensor.flat<int32>().data(),
                input[i].tensor.flat<int32>().data() + n_elements,
                actual_shape_values_[i].d);
      VLOG(2) << "Collected tensor shape values "
              << DebugString(actual_shape_values_[i]);
    } else {
      actual_shape_values_[i] = {0, {}};
    }
  }
  return Status::OK();
}

// Adjusts the shape value profile to prevent TRT from removing shape value
// input bindings whose value is redundant (only a single value matches the
// profile). This should be removed once NVIDIA bug 3153064 is fixed.
void FixShapeValueProfile(OptimizationProfileConfig* prof,
                          const std::vector<bool>& is_shape_tensor) {
  int shape_value_offset = is_shape_tensor.size();
  for (int i = 0; i < is_shape_tensor.size(); i++) {
    if (is_shape_tensor[i] &&
        std::equal(prof->min[shape_value_offset + i].d,
                   prof->min[shape_value_offset + i].d +
                       prof->min[shape_value_offset + i].nbDims,
                   prof->max[shape_value_offset + i].d)) {
      prof->max[shape_value_offset + i].d[0]++;
      VLOG(2) << "Adjusted profile for shape value tensor " << i << " "
              << DebugString(prof->max[shape_value_offset + i]);
    } else {
      VLOG(2) << "Not adjusting profile for input " << i
              << " (is_shape_tensor=" << is_shape_tensor[i] << ")";
    }
  }
}
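
// Example (illustrative): if a shape value tensor has identical min and max
// profile values, say {2, 8}, TRT may drop its binding; bumping max.d[0] to 3
// keeps the binding alive at the cost of a slightly wider profile.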

// Checks whether rhs is already contained in values.
bool AlreadyCollected(const std::vector<std::vector<nvinfer1::Dims>>& values,
                      const std::vector<nvinfer1::Dims>& rhs) {
  for (auto& lhs : values) {
    bool ret = lhs.size() == rhs.size();
    for (int i = 0; ret && i < lhs.size(); i++) {
      ret &= lhs[i].nbDims == rhs[i].nbDims;
      for (int j = 0; ret && j < lhs[i].nbDims; j++) {
        ret &= (lhs[i].d[j] == rhs[i].d[j]);
      }
    }
    if (ret) return true;
  }
  return false;
}
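
// Example (illustrative): deduplication during shape collection. If values
// already contains {(4, 8)}, then AlreadyCollected(values, {(4, 8)}) returns
// true, while a different shape list such as {(4, 9)} returns false.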

void TrtShapeOptimizationProfile::InitProfiles(
    const std::vector<PartialTensorShape>& input_partial_shapes,
    ProfileStrategy strategy) {
  strategy_ = strategy;
  if (input_shapes_.size() == 0) {
    VLOG(1) << "Not creating profiles without input_shapes. "
               "You have to enable profile generation mode first (build).";
    return;
  }
  // Preprocess the vector of input shapes and shape values:
  // - Converts TensorShape -> nvinfer1::Dims.
  // - Concatenates the shape values after the input shapes:
  //   dimvec = [dim0, dim1,..., shapeval0, shapeval1, ...]
  // - Ensures that the list is unique.
  std::vector<std::vector<nvinfer1::Dims>> collected_shapes;
  for (int i = 0; i < input_shapes_.size(); i++) {
    auto shape_vec = input_shapes_[i];
    VLOG(2) << "InitProfiles, processing shape " << i;
    if (!shape_vec.empty()) {
      std::vector<nvinfer1::Dims> dimvec = GetDimVec(shape_vec);
      dimvec.insert(dimvec.end(), input_shape_values_[i].begin(),
                    input_shape_values_[i].end());
      // TODO(tfeher): This condition should not apply for explicit profiles.
      // In that case consecutive elements in collected_shapes contain the
      // user-defined values of min, opt and max, and it is valid to have
      // min = opt and opt = max.
      if (!AlreadyCollected(collected_shapes, dimvec)) {
        collected_shapes.push_back(dimvec);
      }
    }
  }
  switch (strategy_) {
    case ProfileStrategy::kImplicitBatchModeCompatible:
      VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy";
      ImplicitBatchModeCompatibleStrategy(collected_shapes);
      break;
    // Treat all other strategies the same as kOptimal for now. Implementing
    // those is outlined in the dynamic shape support implementation plan.
    case ProfileStrategy::kRange:
      VLOG(1) << "Creating profiles with Range strategy";
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kRangeOptimal:
      VLOG(1) << "Creating profiles with RangeOptimal strategy";
      OptimalStrategy(collected_shapes);
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kOptimal:
      VLOG(1) << "Creating profiles with Optimal strategy";
      OptimalStrategy(collected_shapes);
      break;
  }
  // Define a mask that describes which inputs could be shape tensors. Note
  // that here we can have false positives. The shape tensor mask will be
  // updated once the network is constructed.
  SetShapeTensorMask(input_partial_shapes);
  if (input_partial_shapes.size() > 0) {
    for (OptimizationProfileConfig& prof : profiles_) {
      // TODO: Remove this when the bug is fixed.
      FixShapeValueProfile(&prof, is_shape_tensor_);
      for (int i = 0; i < input_partial_shapes.size(); i++) {
        auto network_input = input_partial_shapes[i];
        EnforceCompatibility(&prof.min[i], network_input);
        EnforceCompatibility(&prof.opt[i], network_input);
        EnforceCompatibility(&prof.max[i], network_input);
      }
    }
  }
}
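
// Example (illustrative): with two inputs where input 0 is a regular tensor
// of shape (4, 8) and input 1 is a shape value tensor holding {16, 16}, a
// collected entry is laid out as
//   dimvec = [dims of input 0, dims of input 1, values 0 (empty), values 1]
// i.e. all input dims first, then one shape-value entry per input, matching
// the shape_value_offset = n_inputs convention in FixShapeValueProfile and
// the n_inputs * 2 sizing in RestoreProfiles below.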

Status TrtShapeOptimizationProfile::AddProfiles(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  // Create a vector of optimization profiles.
  for (int i = 0; i < profiles_.size(); i++) {
    auto* optProfile = builder->createOptimizationProfile();
    Status status = profiles_[i].SetDimensions(network, optProfile);
    if (!status.ok()) {
      return status;
    }
    int idx = -1;
    if (optProfile->isValid()) {
      idx = config->addOptimizationProfile(optProfile);
    }
    if (idx >= 0) {
      if (i != idx) {
        return errors::Internal(
            "Profile index of engine config is different from source profile "
            "index: ",
            i, " != ", idx);
      }
      VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
              << " with idx " << idx << " to builder config.";
    } else {
      LOG(ERROR) << "Failed to add optimization profile "
                 << profiles_[i].DebugString()
                 << ". This usually happens when the profile is invalid.";
    }
  }
  if (!profiles_.empty() && config->getNbOptimizationProfiles() == 0) {
    return errors::Internal("Failure in adding an optimization profile.");
  }
  need_profiles_ = config->getNbOptimizationProfiles() > 0;
  // Update the mask that flags shape tensors. Now that the network is known,
  // the mask will be correct.
  SetShapeTensorMask(network);
  // If TRT_VERSION < 6, then we do not need to add profiles.
  return Status::OK();
}

Status TrtShapeOptimizationProfile::ConfigureBuilder(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
  return Status::OK();
}

// Sets the shape tensor mask from the TRT engine definition.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const nvinfer1::ICudaEngine* engine, int n_inputs) {
  is_shape_tensor_.resize(n_inputs, false);
  for (int i = 0; i < n_inputs; i++) {
    is_shape_tensor_[i] = engine->isShapeBinding(i);
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape tensor at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

// Sets the shape tensor mask using the network definition.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const nvinfer1::INetworkDefinition* network) {
  int n_inputs = network->getNbInputs();
  is_shape_tensor_.resize(n_inputs, false);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    is_shape_tensor_[i] = input->isShapeTensor();
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape tensor " << input->getName() << " at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

// Sets the shape tensor mask using the input partial shapes. This only tells
// whether the tensors are compatible with shape values; only the final
// network definition or the engine can give a definitive answer.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const std::vector<PartialTensorShape>& input_partial_shapes) {
  is_shape_tensor_.resize(input_partial_shapes.size(), false);
  for (int i = 0; i < input_partial_shapes.size(); i++) {
    is_shape_tensor_[i] = IsTrtShapeTensorCompatible(input_partial_shapes[i]);
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape-compatible tensor at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

int TrtShapeOptimizationProfile::GetProfileNumber(
    const std::vector<TensorShape>& shapes) {
  if (!need_profiles_) return 0;
  // TODO(tfeher): Return the best profile, not just the first compatible one.
  for (int i = 0; i < profiles_.size(); i++) {
    if (profiles_[i].IncludesShapes(shapes, HasShapeTensor(),
                                    actual_shape_values_)) {
      return i;
    }
  }
  VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
          << ".";
  return -1;
}
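
// Example (illustrative): with profile 0 covering batch sizes [1, 4] and
// profile 1 covering [1, 16], an input with batch size 3 returns 0 (the
// first compatible profile, not necessarily the tightest one), while batch
// size 32 matches nothing and returns -1.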

Status TrtShapeOptimizationProfile::CreateExecutionContexts(
    nvinfer1::ICudaEngine* engine,
    std::vector<ExecutionContext>* exec_contexts) {
  int i = 0;
  // The following loop runs once if we have static shapes, to create a single
  // execution context without profiles. In dynamic mode we create one context
  // for each profile and set the corresponding optimization profile.
  do {
    VLOG(1) << "Creating execution context " << i;
    ExecutionContext context = ExecutionContext::Create(engine);
    if (i > 0) {
      // This condition is needed for two reasons:
      // - With static shapes we do not have any profiles, so we cannot call
      //   setOptimizationProfile.
      // - The 0th profile is set implicitly for the first execution context,
      //   therefore we do not need to set it.
      if (!context->setOptimizationProfile(i)) {
        return errors::Internal("Could not set TRT optimization profile.");
      }
    }
    exec_contexts->push_back(std::move(context));
    i++;
  } while (i < profiles_.size());

  return Status::OK();
}
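
// Example (illustrative): an engine built with 3 profiles yields 3 execution
// contexts, where context k is bound to optimization profile k; a static
// shape engine without profiles yields exactly one context.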

Status TrtShapeOptimizationProfile::SetInputShapeBinding(
    int input_index, int binding_index, nvinfer1::ICudaEngine* cuda_engine,
    nvinfer1::IExecutionContext* exec_context) const {
  if (cuda_engine->isShapeBinding(binding_index)) {
    // Input shape binding data has to be in host memory. That is the reason
    // we can't use input_tensor.flat().data(), which contains the same
    // values in device memory. Instead, we use data that was copied to the
    // host by CollectShapeValues.
    VLOG(2) << "Setting input shape binding for idx " << binding_index
            << ", with values "
            << DebugString(actual_shape_values_.at(input_index));
    bool ret = exec_context->setInputShapeBinding(
        binding_index, actual_shape_values_.at(input_index).d);
    if (!ret) {
      return errors::Internal("Could not set input shape binding for idx ",
                              binding_index);
    }
  }
  return Status::OK();
}

// If binding_idx is a shape tensor binding, returns the associated
// min/max/opt shape values from profile prof_idx.
nvinfer1::Dims GetDimsFromShapeVal(int prof_idx, int binding_idx,
                                   nvinfer1::OptProfileSelector selector,
                                   const nvinfer1::ICudaEngine* engine) {
  if (engine->isShapeBinding(binding_idx)) {
    const int32* shape_val_ptr =
        engine->getProfileShapeValues(binding_idx, prof_idx, selector);
    if (shape_val_ptr) {
      VLOG(2) << "Found shape value in prof " << prof_idx << ", binding "
              << binding_idx;
      nvinfer1::Dims dims = engine->getBindingDimensions(binding_idx);
      // nbDims == 0 represents a scalar, -1 represents an invalid dim.
      int n_values = (dims.nbDims == 0) ? 1 : dims.d[0];
      if (n_values > 0) {
        dims.nbDims = n_values;
        std::copy(shape_val_ptr, shape_val_ptr + n_values, dims.d);
      }
      return dims;
    }
  }
  return {0, {0}};
}

Status TrtShapeOptimizationProfile::RestoreProfiles(
    const nvinfer1::ICudaEngine* engine) {
  need_profiles_ = false;
  if (!engine) {
    // We do not need to restore profiles for an empty engine.
    return Status::OK();
  }
  if (engine->hasImplicitBatchDimension()) {
    // Nothing to do, we cannot have profiles in implicit batch mode.
    return Status::OK();
  }
  int n_profiles = engine->getNbOptimizationProfiles();
  need_profiles_ = n_profiles > 0;
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
  int n_bindings = engine->getNbBindings();
  int K = n_bindings / n_profiles;
#endif
  int n_inputs = GetNumberOfEngineInputs(engine);
  VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
          << n_inputs << " inputs";
  SetShapeTensorMask(engine, n_inputs);
  for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
    OptimizationProfileConfig cfg;

    cfg.min.resize(n_inputs * 2);
    cfg.max.resize(n_inputs * 2);
    cfg.opt.resize(n_inputs * 2);
    // Restore the shape values.
    for (int j = 0; j < n_inputs; j++) {
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
      // TODO(tfeher): Consider getting the binding idx from
      // GetTrtBindingIndex. To make that work we need to construct the input
      // name similarly to how it is done in SetTrtEngineInputs.
      int binding_idx = prof_idx * K + j;
#else
      int binding_idx = j;
#endif
      nvinfer1::Dims min = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kOPT);
      cfg.min[j] = min;
      cfg.max[j] = max;
      cfg.opt[j] = opt;

      cfg.min[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMIN, engine);
      cfg.max[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMAX, engine);
      cfg.opt[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kOPT, engine);
    }
    VLOG(2) << "Restored profile " << cfg.DebugString();
    profiles_.push_back(std::move(cfg));
  }
  return Status::OK();
}
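
// Example (illustrative): for TRT >= 7.1.3 the bindings of all profiles are
// laid out consecutively, so with n_bindings = 6 and n_profiles = 3 there are
// K = 2 bindings per profile, and input j of profile p lives at binding index
// p * K + j (e.g. profile 1, input 0 -> binding index 2).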

int TrtShapeOptimizationProfile::GetNumProfiles() const {
  return profiles_.size();
}

}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT