• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
17 
18 #include <algorithm>
19 #include <functional>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
23 #include "tensorflow/core/platform/stream_executor.h"
24 
25 #if GOOGLE_CUDA && GOOGLE_TENSORRT
26 
27 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 
32 // Returns a vector of nvinfer1::Dims for a vector of TensorShapes.
33 template <typename TensorShapeType>
GetDimVec(std::vector<TensorShapeType> shape_vec)34 std::vector<nvinfer1::Dims> GetDimVec(std::vector<TensorShapeType> shape_vec) {
35   std::vector<nvinfer1::Dims> dimvec(shape_vec.size());
36   absl::c_transform(shape_vec, dimvec.begin(), [](TensorShapeType shape) {
37     nvinfer1::Dims dims;
38     TF_CHECK_OK(TensorShapeToTrtDims(shape, false, &dims));
39     return dims;
40   });
41   return dimvec;
42 }
43 
44 // In dynamic shape mode the optimization profile dims are only allowed to
45 // differ from the network input dims where the network input dims have -1
46 // values. We enforce this condition by changing prof_dims if necessary.
EnforceCompatibility(nvinfer1::Dims * prof_dims,const PartialTensorShape & input_shape)47 void EnforceCompatibility(nvinfer1::Dims* prof_dims,
48                           const PartialTensorShape& input_shape) {
49   for (int i = 0; i < input_shape.dims(); i++) {
50     if (input_shape.dim_size(i) != -1) {
51       prof_dims->d[i] = input_shape.dim_size(i);
52     }
53   }
54 }
55 
SetImplicitBatchModeCompatibleProfile(const std::vector<nvinfer1::Dims> & dimvec,std::vector<nvinfer1::Dims> * min,std::vector<nvinfer1::Dims> * opt,std::vector<nvinfer1::Dims> * max)56 void SetImplicitBatchModeCompatibleProfile(
57     const std::vector<nvinfer1::Dims>& dimvec, std::vector<nvinfer1::Dims>* min,
58     std::vector<nvinfer1::Dims>* opt, std::vector<nvinfer1::Dims>* max) {
59   *min = dimvec;
60   for (auto& dim : *min) {
61     // Shape value tensors can have -1 value as a wildcard. We do not change
62     // in that case.
63     if (dim.d[0] != -1) dim.d[0] = 1;  // Set min batch size to 1.
64   }
65   *opt = dimvec;
66   *max = dimvec;
67 }
68 
ImplicitBatchModeCompatibleStrategy(const std::vector<std::vector<nvinfer1::Dims>> & collected_shapes)69 void TrtShapeOptimizationProfile::ImplicitBatchModeCompatibleStrategy(
70     const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
71   for (auto& shape_vec : collected_shapes) {
72     std::vector<nvinfer1::Dims> min, opt, max;
73     SetImplicitBatchModeCompatibleProfile(shape_vec, &min, &opt, &max);
74     VLOG(2) << "Initializing optimization profile config with min="
75             << DebugString(min) << ", opt=max=" << DebugString(max);
76     OptimizationProfileConfig profConfig{min, opt, max};
77     profiles_.push_back(std::move(profConfig));
78   }
79 }
80 
81 // Applies a binary operation for each dimension of the input shapes.
82 // x[i].d[k] = op(x[i].d[k], y[i].d[k]), where i enumerates the input tensors,
83 // and k enumerates the dimensions of the tensors. The BinaryOperation may be
84 // std::min, std::max etc.
85 template <typename BinaryOperation>
ShapeProfileBinaryOp(std::vector<nvinfer1::Dims> * x,const std::vector<nvinfer1::Dims> & y,BinaryOperation op)86 Status ShapeProfileBinaryOp(std::vector<nvinfer1::Dims>* x,
87                             const std::vector<nvinfer1::Dims>& y,
88                             BinaryOperation op) {
89   if (x->size() != y.size())
90     return errors::InvalidArgument(
91         "Number of input tensors differ during profile creation");
92   for (int i = 0; i < x->size(); i++) {
93     if (x->at(i).nbDims != y[i].nbDims)
94       return errors::InvalidArgument(
95           "Number of input dimensions differ during profile creation");
96     for (int j = 0; j < x->at(i).nbDims; j++) {
97       x->at(i).d[j] = op(x->at(i).d[j], y[i].d[j]);
98     }
99   }
100   return Status::OK();
101 }
102 
RangeStrategy(const std::vector<std::vector<nvinfer1::Dims>> & collected_shapes)103 Status TrtShapeOptimizationProfile::RangeStrategy(
104     const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
105   if (collected_shapes.empty()) return Status::OK();
106 
107   std::vector<nvinfer1::Dims> min = collected_shapes[0];
108   std::vector<nvinfer1::Dims> max = min;
109 
110   for (int i = 1; i < collected_shapes.size(); i++) {
111     TF_RETURN_IF_ERROR(
112         ShapeProfileBinaryOp(&min, collected_shapes[i],
113                              [](int a, int b) { return std::min(a, b); }));
114     TF_RETURN_IF_ERROR(
115         ShapeProfileBinaryOp(&max, collected_shapes[i],
116                              [](int a, int b) { return std::max(a, b); }));
117   }
118   VLOG(2) << "Initializing optimization profile config with min="
119           << DebugString(min) << ", opt=max=" << DebugString(max);
120   OptimizationProfileConfig profConfig{min, max, max};
121   profiles_.push_back(std::move(profConfig));
122   return Status::OK();
123 }
124 
OptimalStrategy(const std::vector<std::vector<nvinfer1::Dims>> & collected_shapes)125 void TrtShapeOptimizationProfile::OptimalStrategy(
126     const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
127   for (auto& shape_vec : collected_shapes) {
128     std::vector<nvinfer1::Dims> min = shape_vec;
129     std::vector<nvinfer1::Dims> opt = min;
130     std::vector<nvinfer1::Dims> max = min;
131     VLOG(2) << "Initializing optimization profile config with min=opt=max="
132             << DebugString(min);
133     OptimizationProfileConfig profConfig{min, opt, max};
134     profiles_.push_back(std::move(profConfig));
135   }
136 }
137 
// Collects the values of tensors that are ShapeTensorCompatible to. The values
// are stored in the actual_shape_values_ member variable.
Status TrtShapeOptimizationProfile::CollectShapeValues(OpKernelContext* ctx) {
  // Grab the raw CUDA stream backing the op's device context; the async
  // device-to-host copies below are enqueued on this stream.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  actual_shape_values_.resize(ctx->num_inputs());
  // Lazily initialize the shape tensor mask from the current inputs. This is
  // only a compatibility heuristic (it can contain false positives); the mask
  // is refined later from the network/engine via SetShapeTensorMask.
  if (is_shape_tensor_.empty()) {
    is_shape_tensor_.resize(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); i++) {
      is_shape_tensor_[i] = IsTrtShapeTensorCompatible(ctx->input(i));
    }
  }
  int n_shape_val = 0;  // Number of async copies issued below.
  // First copy all the shape value candidates into actual_shape_values_ vector.
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (is_shape_tensor_[i]) {
      // We have to copy the shape values to the host, because TRT's
      // ExecutionContext::setInputShapeBinding expects a host pointer.
      n_shape_val++;
      const Tensor& input = ctx->input(i);
      actual_shape_values_[i].nbDims = input.NumElements();
      auto ret = cudaMemcpyAsync(
          actual_shape_values_[i].d, input.flat<int32>().data(),
          input.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost, *stream);
      if (ret != 0) {
        return errors::Internal("Could not copy shape tensor values");
      }
      VLOG(2) << "Input " << i << " is (probably) a shape tensor, n_values="
              << input.NumElements();
    } else {
      // Not a shape tensor candidate: store an empty Dims value.
      actual_shape_values_[i] = {0, {}};
    }
  }
  if (n_shape_val > 0) {
    // If we have any shape values candidates, then wait until data is copied
    // to host.
    cudaStreamSynchronize(*stream);
  }
  return Status::OK();
}
181 
182 // Collects the values of tensors that are ShapeTensorCompatible to. To be used
183 // for unit tests.
CollectShapeValues(const DataVec & input)184 Status TrtShapeOptimizationProfile::CollectShapeValues(const DataVec& input) {
185   actual_shape_values_.resize(input.size());
186   for (int i = 0; i < input.size(); i++) {
187     if (is_shape_tensor_[i]) {
188       if (!IsTrtShapeTensorCompatible(input[i].tensor)) {
189         return errors::Internal("Inconsistent shape tensor ", input[i].name,
190                                 ", ", i);
191       }
192       int n_elements = input[i].tensor.NumElements();
193       actual_shape_values_[i].nbDims = n_elements;
194       // During unit tests, the data is in unified memory
195       std::copy(input[i].tensor.flat<int32>().data(),
196                 input[i].tensor.flat<int32>().data() + n_elements,
197                 actual_shape_values_[i].d);
198       VLOG(2) << "Collected tensor shape values "
199               << DebugString(actual_shape_values_[i]);
200     } else {
201       actual_shape_values_[i] = {0, {}};
202     }
203   }
204   return Status::OK();
205 }
206 
207 // Adjusts shape value profile to prevent TRT from removing shape value input
208 // bindings whose value is redundant (only a single value matches the profile).
209 // This should be removed once the NVIDIA bug 3153064 is fixed.
FixShapeValueProfile(OptimizationProfileConfig * prof,const std::vector<bool> & is_shape_tensor)210 void FixShapeValueProfile(OptimizationProfileConfig* prof,
211                           const std::vector<bool>& is_shape_tensor) {
212   int shape_value_offset = is_shape_tensor.size();
213   for (int i = 0; i < is_shape_tensor.size(); i++) {
214     if (is_shape_tensor[i] &&
215         std::equal(prof->min[shape_value_offset + i].d,
216                    prof->min[shape_value_offset + i].d +
217                        prof->min[shape_value_offset + i].nbDims,
218                    prof->max[shape_value_offset + i].d)) {
219       prof->max[shape_value_offset + i].d[0]++;
220       VLOG(2) << "Adjusted profile for shape value tensor " << i << " "
221               << DebugString(prof->max[shape_value_offset + i]);
222     } else {
223       VLOG(2) << i << " is not a shape tensor." << is_shape_tensor[i];
224     }
225   }
226 }
227 
228 // Checks whether rhs is already contained in values.
AlreadyCollected(const std::vector<std::vector<nvinfer1::Dims>> & values,const std::vector<nvinfer1::Dims> & rhs)229 bool AlreadyCollected(const std::vector<std::vector<nvinfer1::Dims>>& values,
230                       const std::vector<nvinfer1::Dims>& rhs) {
231   for (auto& lhs : values) {
232     bool ret = lhs.size() == rhs.size();
233     for (int i = 0; ret && i < lhs.size(); i++) {
234       ret &= lhs[i].nbDims == rhs[i].nbDims;
235       for (int j = 0; ret && j < lhs[i].nbDims; j++) {
236         ret &= (lhs[i].d[j] == rhs[i].d[j]);
237       }
238     }
239     if (ret) return true;
240   }
241   return false;
242 }
243 
// Creates optimization profiles from the shapes (and shape tensor values)
// collected during profile generation mode, according to the given strategy,
// then post-processes them for compatibility with the network inputs.
void TrtShapeOptimizationProfile::InitProfiles(
    const std::vector<PartialTensorShape>& input_partial_shapes,
    ProfileStrategy strategy) {
  strategy_ = strategy;
  if (input_shapes_.size() == 0) {
    VLOG(1) << "Not creating profiles without input_shapes. "
               "You have to enable profile generation mode first (build).";
    return;
  }
  // Preprocess the vector of input shapes and shape values:
  // - Converts TensorShape -> nvinfer::Dims.
  // - Concatenates the shape values after the input shapes:
  //   dimvec = [dim0, dim1,..., shapeval0, shapval1, ...]
  // - Ensures that the list is unique.
  std::vector<std::vector<nvinfer1::Dims>> collected_shapes;
  for (int i = 0; i < input_shapes_.size(); i++) {
    auto shape_vec = input_shapes_[i];
    VLOG(2) << "Initprofiles, processing shape " << i;
    if (!shape_vec.empty()) {
      std::vector<nvinfer1::Dims> dimvec = GetDimVec(shape_vec);
      dimvec.insert(dimvec.end(), input_shape_values_[i].begin(),
                    input_shape_values_[i].end());
      // TODO(tfeher): This condition should not apply for explicit profile. In
      // that case consecutive elements in collected_shapes contain the user
      // defined values of min, opt and max, and it is valid to have min == opt
      // and opt == max.
      if (!AlreadyCollected(collected_shapes, dimvec)) {
        collected_shapes.push_back(dimvec);
      }
    }
  }
  switch (strategy_) {
    case ProfileStrategy::kImplicitBatchModeCompatible:
      VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy";
      ImplicitBatchModeCompatibleStrategy(collected_shapes);
      break;
    // Treat all other strategies the same as kOptimal for now. Implementing
    // those is outlined in the dynamic shape support implementation plan.
    case ProfileStrategy::kRange:
      VLOG(1) << "Creating profiles with Range strategy";
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kRangeOptimal:
      // Emits both the per-shape optimal profiles and the single range
      // profile.
      VLOG(1) << "Creating profiles with RangeOptimal strategy";
      OptimalStrategy(collected_shapes);
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kOptimal:
      VLOG(1) << "Creating profiles with Optimal strategy";
      OptimalStrategy(collected_shapes);
      break;
  }
  // Define a mask that describe which input could be a shape tensor. Note
  // that here we can have false positives. The shape tensor mask will be
  // updated once the network is constructed.
  SetShapeTensorMask(input_partial_shapes);
  if (input_partial_shapes.size() > 0) {
    for (OptimizationProfileConfig& prof : profiles_) {
      // TODO: Remove this when the bug is fixed.
      FixShapeValueProfile(&prof, is_shape_tensor_);
      // Force profile dims to agree with the network input dims wherever the
      // network input dim is concrete (not -1).
      for (int i = 0; i < input_partial_shapes.size(); i++) {
        auto network_input = input_partial_shapes[i];
        EnforceCompatibility(&prof.min[i], network_input);
        EnforceCompatibility(&prof.opt[i], network_input);
        EnforceCompatibility(&prof.max[i], network_input);
      }
    }
  }
}
313 
// Registers all configured profiles with the TRT builder config. Also updates
// need_profiles_ and refreshes the shape tensor mask from the now-known
// network definition.
Status TrtShapeOptimizationProfile::AddProfiles(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  // Create a vector of optimization profiles.
  for (int i = 0; i < profiles_.size(); i++) {
    auto* optProfile = builder->createOptimizationProfile();
    Status status = profiles_[i].SetDimensions(network, optProfile);
    if (!status.ok()) {
      return status;
    }
    int idx = -1;
    if (optProfile->isValid()) {
      idx = config->addOptimizationProfile(optProfile);
    }
    if (idx >= 0) {
      // TRT should assign profile indices in insertion order; a mismatch
      // would break the profile <-> execution context mapping used later.
      if (i != idx) {
        return errors::Internal(
            "Profile index of engine config is different from source profile "
            "index: ",
            i, " != ", idx);
      }
      VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
              << " with idx " << idx << " to builder config.";
    } else {
      // Invalid profiles are skipped; the check below still fails only if no
      // profile at all could be added.
      LOG(ERROR) << "Failed to add optimization profile "
                 << profiles_[i].DebugString()
                 << ". This usually happens when profile is invalid.";
    }
  }
  if (!profiles_.empty() && config->getNbOptimizationProfiles() == 0) {
    return errors::Internal("Failure in adding an optimization profile.");
  }
  need_profiles_ = config->getNbOptimizationProfiles() > 0;
  // Update the mask that flags shape tensors. The network is known now,
  // the mask will be correct.
  SetShapeTensorMask(network);
  // if TRT_VERSION < 6, then we do not need to add.
  return Status::OK();
}
353 
ConfigureBuilder(nvinfer1::IBuilder * builder,nvinfer1::IBuilderConfig * config,const nvinfer1::INetworkDefinition * network)354 Status TrtShapeOptimizationProfile::ConfigureBuilder(
355     nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
356     const nvinfer1::INetworkDefinition* network) {
357   TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
358   return Status::OK();
359 }
360 
361 // Sets the shape tensor mask from the TRT engine definition.
SetShapeTensorMask(const nvinfer1::ICudaEngine * engine,int n_inputs)362 void TrtShapeOptimizationProfile::SetShapeTensorMask(
363     const nvinfer1::ICudaEngine* engine, int n_inputs) {
364   is_shape_tensor_.resize(n_inputs, false);
365   for (int i = 0; i < n_inputs; i++) {
366     is_shape_tensor_[i] = engine->isShapeBinding(i);
367     if (is_shape_tensor_[i]) {
368       VLOG(2) << "Found shape tensor at " << i;
369     }
370   }
371   has_shape_tensor_ =
372       absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
373 }
374 
375 // Sets the shape tensor mask using the network definition.
SetShapeTensorMask(const nvinfer1::INetworkDefinition * network)376 void TrtShapeOptimizationProfile::SetShapeTensorMask(
377     const nvinfer1::INetworkDefinition* network) {
378   int n_inputs = network->getNbInputs();
379   is_shape_tensor_.resize(n_inputs, false);
380   for (int i = 0; i < n_inputs; i++) {
381     const ITensorProxyPtr input = network->getInput(i);
382     is_shape_tensor_[i] = input->isShapeTensor();
383     if (is_shape_tensor_[i]) {
384       VLOG(2) << "Found shape tensor " << input->getName() << ' at ' << i;
385     }
386   }
387   has_shape_tensor_ =
388       absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
389 }
390 
391 // Sets the shape tensor mask using the input partial shapes. This only tells
392 // whether the tensors are shape value compatible, only the final network
393 // definition or the engine would give concrete answers.
SetShapeTensorMask(const std::vector<PartialTensorShape> & input_partial_shapes)394 void TrtShapeOptimizationProfile::SetShapeTensorMask(
395     const std::vector<PartialTensorShape>& input_partial_shapes) {
396   is_shape_tensor_.resize(input_partial_shapes.size(), false);
397   for (int i = 0; i < input_partial_shapes.size(); i++) {
398     is_shape_tensor_[i] = IsTrtShapeTensorCompatible(input_partial_shapes[i]);
399     if (is_shape_tensor_[i]) {
400       VLOG(2) << "Found shape compatible tensor at " << i;
401     }
402   }
403   has_shape_tensor_ =
404       absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
405 }
406 
GetProfileNumber(const std::vector<TensorShape> & shapes)407 int TrtShapeOptimizationProfile::GetProfileNumber(
408     const std::vector<TensorShape>& shapes) {
409   if (!need_profiles_) return 0;
410   // TODO(tfeher): Return the best profile not just the first compatible.
411   for (int i = 0; i < profiles_.size(); i++) {
412     if (profiles_[i].IncludesShapes(shapes, HasShapeTensor(),
413                                     actual_shape_values_)) {
414       return i;
415     }
416   }
417   VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
418           << ".";
419   return -1;
420 }
421 
CreateExecutionContexts(nvinfer1::ICudaEngine * engine,std::vector<ExecutionContext> * exec_contexts)422 Status TrtShapeOptimizationProfile::CreateExecutionContexts(
423     nvinfer1::ICudaEngine* engine,
424     std::vector<ExecutionContext>* exec_contexts) {
425   int i = 0;
426   // The following loop runs once if we have static shapes, to create a single
427   // execution context without profiles. In dynamic mode we create one context
428   // for each profile and set the corresponding optimization profile.
429   do {
430     VLOG(1) << "Creating execution context " << i;
431     ExecutionContext context = ExecutionContext::Create(engine);
432     if (i > 0) {
433       // This condition is needed for two reasons:
434       // - using static shapes we do not have any profiles so we cannot call
435       //   set optimizationprofiles.
436       // - The 0th profile is set implicitly for the first execution context
437       //   therefore we do not need to set.
438       if (!context->setOptimizationProfile(i)) {
439         return errors::Internal("Could not set TRT optimization profile.");
440       }
441     }
442     exec_contexts->push_back(std::move(context));
443     i++;
444   } while (i < profiles_.size());
445 
446   return Status::OK();
447 }
448 
SetInputShapeBinding(int input_index,int binding_index,nvinfer1::ICudaEngine * cuda_engine,nvinfer1::IExecutionContext * exec_context) const449 Status TrtShapeOptimizationProfile::SetInputShapeBinding(
450     int input_index, int binding_index, nvinfer1::ICudaEngine* cuda_engine,
451     nvinfer1::IExecutionContext* exec_context) const {
452   if (cuda_engine->isShapeBinding(binding_index)) {
453     // Input shape binding data has to be in host memory. That is the reason
454     // we can't use input_tensor.flat().data(). which contains the same
455     // values in device memory. Instead, we use data that was copied to host
456     // by CollectShapeValues.
457     VLOG(2) << "Setting input shape binding for idx " << binding_index
458             << ", with values "
459             << DebugString(actual_shape_values_.at(input_index));
460     bool ret = exec_context->setInputShapeBinding(
461         binding_index, actual_shape_values_.at(input_index).d);
462     if (!ret) {
463       return errors::Internal("Could not set input shape binding for idx ",
464                               binding_index);
465     }
466   }
467   return Status::OK();
468 }
469 
470 // If binding_idx is a shape tensor, then returns the associated min/max/opt
471 // shape values from prof_idx.
GetDimsFromShapeVal(int prof_idx,int binding_idx,nvinfer1::OptProfileSelector selector,const nvinfer1::ICudaEngine * engine)472 nvinfer1::Dims GetDimsFromShapeVal(int prof_idx, int binding_idx,
473                                    nvinfer1::OptProfileSelector selector,
474                                    const nvinfer1::ICudaEngine* engine) {
475   if (engine->isShapeBinding(binding_idx)) {
476     const int32* shape_val_ptr =
477         engine->getProfileShapeValues(binding_idx, prof_idx, selector);
478     if (shape_val_ptr) {
479       VLOG(2) << "Found shape value in prof " << prof_idx << ", binding "
480               << binding_idx;
481       nvinfer1::Dims dims = engine->getBindingDimensions(binding_idx);
482       // nbDims == 0 represent scalar, -1 represents invalid dim
483       int n_values = (dims.nbDims == 0) ? 1 : dims.d[0];
484       if (n_values > 0) {
485         dims.nbDims = n_values;
486         std::copy(shape_val_ptr, shape_val_ptr + n_values, dims.d);
487       }
488       return dims;
489     }
490   }
491   return {0, {0}};
492 }
493 
// Reconstructs the profiles_ vector from a deserialized engine, so that
// profile bookkeeping works the same as if the profiles had been created by
// InitProfiles.
Status TrtShapeOptimizationProfile::RestoreProfiles(
    const nvinfer1::ICudaEngine* engine) {
  need_profiles_ = false;
  if (!engine) {
    // We do not need to restore profiles for an empty engine.
    return Status::OK();
  }
  if (engine->hasImplicitBatchDimension()) {
    // Nothing to do, we cannot have profiles in implicit batch mode.
    return Status::OK();
  }
  int n_profiles = engine->getNbOptimizationProfiles();
  need_profiles_ = n_profiles > 0;
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
  // Each profile has its own copy of the bindings, so the stride between
  // consecutive profiles' bindings is n_bindings / n_profiles.
  int n_bindings = engine->getNbBindings();
  int K = n_bindings / n_profiles;
#endif
  int n_inputs = GetNumberOfEngineInputs(engine);
  VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
          << n_inputs << " inputs";
  SetShapeTensorMask(engine, n_inputs);
  for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
    OptimizationProfileConfig cfg;

    // Layout matches InitProfiles: the first n_inputs entries hold the input
    // shape dims, the second n_inputs entries hold the shape tensor values.
    cfg.min.resize(n_inputs * 2);
    cfg.max.resize(n_inputs * 2);
    cfg.opt.resize(n_inputs * 2);
    // restore shape values
    for (int j = 0; j < n_inputs; j++) {
#if IS_TRT_VERSION_GE(7, 1, 3, 0)
      // TODO(tfeher): consider getting the binding idx from
      // GetTrtBindingIndex. To make that work we need to construct the input
      // name similarly as it is done in SetTrtEngineInputs.
      int binding_idx = prof_idx * K + j;
#else
      int binding_idx = j;
#endif
      nvinfer1::Dims min = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kOPT);
      cfg.min[j] = min;
      cfg.max[j] = max;
      cfg.opt[j] = opt;

      // Shape values (empty Dims when binding j is not a shape tensor).
      cfg.min[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMIN, engine);
      cfg.max[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMAX, engine);
      cfg.opt[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kOPT, engine);
    }
    VLOG(2) << "Restored profile " << cfg.DebugString();
    profiles_.push_back(std::move(cfg));
  }
  return Status::OK();
}
553 
GetNumProfiles() const554 int TrtShapeOptimizationProfile::GetNumProfiles() const {
555   return profiles_.size();
556 }
557 
558 }  // namespace tensorrt
559 }  // namespace tensorflow
560 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
561