1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
17
18 #include <algorithm>
19 #include <functional>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
23
24 #if GOOGLE_CUDA && GOOGLE_TENSORRT
25 namespace tensorflow {
26 namespace tensorrt {
27
28 // Returns a vector of nvinfer1::Dims for a vector of TensorShapes.
29 template <typename TensorShapeType>
GetDimVec(std::vector<TensorShapeType> shape_vec)30 std::vector<nvinfer1::Dims> GetDimVec(std::vector<TensorShapeType> shape_vec) {
31 std::vector<nvinfer1::Dims> dimvec(shape_vec.size());
32 absl::c_transform(shape_vec, dimvec.begin(), [](TensorShapeType shape) {
33 return TensorShapeToTrtDims(shape, false);
34 });
35 return dimvec;
36 }
37
38 // In dynamic shape mode the optimization profile dims are only allowed to
39 // differ from the network input dims where the network input dims have -1
40 // values. We enforce this condition by changing prof_dims if necessary.
EnforceCompatibility(nvinfer1::Dims * prof_dims,const PartialTensorShape & input_shape)41 void EnforceCompatibility(nvinfer1::Dims* prof_dims,
42 const PartialTensorShape& input_shape) {
43 for (int i = 0; i < input_shape.dims(); i++) {
44 if (input_shape.dim_size(i) != -1) {
45 prof_dims->d[i] = input_shape.dim_size(i);
46 }
47 }
48 }
49
SetImplicitBatchModeCompatibleProfile(const std::vector<nvinfer1::Dims> & dimvec,std::vector<nvinfer1::Dims> * min,std::vector<nvinfer1::Dims> * opt,std::vector<nvinfer1::Dims> * max)50 void SetImplicitBatchModeCompatibleProfile(
51 const std::vector<nvinfer1::Dims>& dimvec, std::vector<nvinfer1::Dims>* min,
52 std::vector<nvinfer1::Dims>* opt, std::vector<nvinfer1::Dims>* max) {
53 *min = dimvec;
54 for (auto& dim : *min) {
55 dim.d[0] = 1; // Set min batch size to 1.
56 }
57 *opt = dimvec;
58 *max = dimvec;
59 }
60
ImplicitBatchModeCompatibleStrategy()61 void TrtShapeOptimizationProfile::ImplicitBatchModeCompatibleStrategy() {
62 for (auto& shape_vec : input_shapes_) {
63 if (!shape_vec.empty()) {
64 std::vector<nvinfer1::Dims> dimvec = GetDimVec(shape_vec);
65 std::vector<nvinfer1::Dims> min, opt, max;
66 SetImplicitBatchModeCompatibleProfile(dimvec, &min, &opt, &max);
67 OptimizationProfileConfig profConfig{min, opt, max};
68 profiles_.push_back(std::move(profConfig));
69 }
70 }
71 }
72
OptimalStrategy()73 void TrtShapeOptimizationProfile::OptimalStrategy() {
74 for (auto& shape_vec : input_shapes_) {
75 if (!shape_vec.empty()) {
76 std::vector<nvinfer1::Dims> min = GetDimVec(shape_vec);
77 std::vector<nvinfer1::Dims> opt = min;
78 std::vector<nvinfer1::Dims> max = min;
79 OptimizationProfileConfig profConfig{min, opt, max};
80 profiles_.push_back(std::move(profConfig));
81 }
82 }
83 }
84
85 // Adjust shape value profile to prevent TRT from removing shape value input
86 // bindings whose value is redundant (only a single value matches the profile).
87 // This should be removed once the NVIDIA bug 3153064 is fixed.
FixShapeValueProfile(OptimizationProfileConfig * prof,const std::vector<bool> & is_shape_tensor)88 void FixShapeValueProfile(OptimizationProfileConfig* prof,
89 const std::vector<bool>& is_shape_tensor) {
90 for (int i = 0; i < prof->min.size(); i++) {
91 if (is_shape_tensor[i] &&
92 std::equal(prof->min[i].d, prof->min[i].d + prof->min[i].nbDims,
93 prof->max[i].d)) {
94 VLOG(2) << "Adjust profile for shape value tensor " << i;
95 prof->max[i].d[0]++;
96 }
97 }
98 }
99
InitProfiles(const std::vector<PartialTensorShape> & input_partial_shapes)100 void TrtShapeOptimizationProfile::InitProfiles(
101 const std::vector<PartialTensorShape>& input_partial_shapes) {
102 if (input_shapes_.size() == 0) {
103 VLOG(1) << "Not creating profiles without input_shapes. "
104 "You have to enable profile generation mode first (build).";
105 return;
106 }
107 switch (strategy_) {
108 case ProfileStrategy::kImplicitBatchModeCompatible:
109 VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy";
110 ImplicitBatchModeCompatibleStrategy();
111 break;
112 case ProfileStrategy::kOptimal:
113 VLOG(1) << "Creating profiles with Optimal strategy";
114 OptimalStrategy();
115 break;
116 }
117 // Define a mask that describe which input could be a shape tensor. Note that
118 // here we can have false positives. The shape tensor mask will be updated
119 // once the network is constructed.
120 SetShapeTensorMask(input_partial_shapes);
121 if (input_partial_shapes.size() > 0) {
122 for (OptimizationProfileConfig& prof : profiles_) {
123 // TODO: Remove this when the bug is fixed.
124 FixShapeValueProfile(&prof, is_shape_tensor_);
125 for (int i = 0; i < input_partial_shapes.size(); i++) {
126 auto network_input = input_partial_shapes[i];
127 EnforceCompatibility(&prof.min[i], network_input);
128 EnforceCompatibility(&prof.opt[i], network_input);
129 EnforceCompatibility(&prof.max[i], network_input);
130 }
131 }
132 }
133 }
134
135 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
AddProfiles(nvinfer1::IBuilder * builder,nvinfer1::IBuilderConfig * config,const nvinfer1::INetworkDefinition * network)136 Status TrtShapeOptimizationProfile::AddProfiles(
137 nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
138 const nvinfer1::INetworkDefinition* network) {
139 // Create a vector of optimization profiles.
140 for (int i = 0; i < profiles_.size(); i++) {
141 auto* optProfile = builder->createOptimizationProfile();
142 Status status = profiles_[i].SetDimensions(network, optProfile);
143 if (!status.ok()) {
144 return status;
145 }
146 int idx = -1;
147 if (optProfile->isValid()) {
148 idx = config->addOptimizationProfile(optProfile);
149 }
150 if (idx >= 0) {
151 if (i != idx) {
152 return errors::Internal(
153 "Profile index of engine config is different from resource profile "
154 "index: ",
155 i, " != ", idx);
156 }
157 VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
158 << " to builder config.";
159 } else {
160 LOG(ERROR) << "Failed to add optimization profile "
161 << profiles_[i].DebugString()
162 << ". This usually happens when profile is invalid.";
163 }
164 }
165 if (!profiles_.empty() && config->getNbOptimizationProfiles() == 0) {
166 return errors::Internal("Failure in adding an optimization profile.");
167 }
168 need_profiles_ = config->getNbOptimizationProfiles() > 0;
169 // Update the the mask that flag shape tensors. The network is known now,
170 // the mask will be correct.
171 SetShapeTensorMask(network);
172 // if TRT_VERSION < 6, then we do not need to add.
173 return Status::OK();
174 }
175 #endif
176
177 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
ConfigureBuilder(nvinfer1::IBuilder * builder,nvinfer1::IBuilderConfig * config,const nvinfer1::INetworkDefinition * network)178 Status TrtShapeOptimizationProfile::ConfigureBuilder(
179 nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
180 const nvinfer1::INetworkDefinition* network) {
181 TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
182 return Status::OK();
183 }
184 #endif
185
186 // Sets the shape tensor mask using the network definition.
SetShapeTensorMask(const nvinfer1::INetworkDefinition * network)187 void TrtShapeOptimizationProfile::SetShapeTensorMask(
188 const nvinfer1::INetworkDefinition* network) {
189 int n_inputs = network->getNbInputs();
190 is_shape_tensor_.resize(n_inputs, false);
191 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
192 for (int i = 0; i < n_inputs; i++) {
193 const nvinfer1::ITensor* input = network->getInput(i);
194 is_shape_tensor_[i] = input->isShapeTensor();
195 if (is_shape_tensor_[i]) {
196 VLOG(2) << "Found shape tensor " << input->getName() << ' at ' << i;
197 }
198 }
199 #endif
200 has_shape_tensor_ =
201 absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
202 }
203
204 // Sets the shape tensor mask using the input partial shapes. This only tells
205 // whether the tensors are shape value compatible, only the final network
206 // definition or the engine would give concrete answers.
SetShapeTensorMask(const std::vector<PartialTensorShape> & input_partial_shapes)207 void TrtShapeOptimizationProfile::SetShapeTensorMask(
208 const std::vector<PartialTensorShape>& input_partial_shapes) {
209 is_shape_tensor_.resize(input_partial_shapes.size(), false);
210 for (int i = 0; i < input_partial_shapes.size(); i++) {
211 is_shape_tensor_[i] = IsTrtShapeTensorCompatible(input_partial_shapes[i]);
212 if (is_shape_tensor_[i]) {
213 VLOG(2) << "Found shape compatible tensor at " << i;
214 }
215 }
216 has_shape_tensor_ =
217 absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
218 }
219
GetProfileNumber(const std::vector<TensorShape> & shapes)220 int TrtShapeOptimizationProfile::GetProfileNumber(
221 const std::vector<TensorShape>& shapes) {
222 if (!need_profiles_) return 0;
223 for (int i = 0; i < profiles_.size(); i++) {
224 if (profiles_[i].IncludesShapes(shapes)) {
225 return i;
226 }
227 }
228 VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
229 << ".";
230 return -1;
231 }
232
CreateExecutionContexts(nvinfer1::ICudaEngine * engine,std::vector<ExecutionContext> & exec_context,TRTBaseAllocator * memory_allocator)233 Status TrtShapeOptimizationProfile::CreateExecutionContexts(
234 nvinfer1::ICudaEngine* engine, std::vector<ExecutionContext>& exec_context,
235 TRTBaseAllocator* memory_allocator) {
236 int i = 0;
237 // The following loop runs once if we have static shapes, to create a single
238 // execution context without profiles. In dynamic mode we create one context
239 // for each profile and set the corresponding optimization profile.
240 do {
241 VLOG(1) << "Creating execution context " << i;
242 auto exec_context_status =
243 ExecutionContext::Create(engine, memory_allocator);
244 if (!exec_context_status.ok()) {
245 return errors::Internal("Failed to create execution context");
246 }
247 if (i > 0) {
248 // This condition is needed for two reasons:
249 // - using static shapes we do not have any profiles so we cannot call
250 // set optimizationprofiles.
251 // - The 0th profile is set implicitly for the first execution context
252 // therefore we do not need to set.
253 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
254 bool stat = exec_context_status.ValueOrDie()
255 .GetIExecutionContext()
256 ->setOptimizationProfile(i);
257 if (!stat) {
258 return errors::Internal("Could not set TRT optimization profile.");
259 }
260 #endif
261 }
262 exec_context.push_back(std::move(exec_context_status.ValueOrDie()));
263 i++;
264 } while (i < profiles_.size());
265
266 return Status::OK();
267 }
268
RestoreProfiles(const nvinfer1::ICudaEngine * engine)269 Status TrtShapeOptimizationProfile::RestoreProfiles(
270 const nvinfer1::ICudaEngine* engine) {
271 need_profiles_ = false;
272 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
273 if (!engine) {
274 // We do not need to restore profiles for an empty engine.
275 return Status::OK();
276 }
277 #if IS_TRT_VERSION_GE(7, 0, 0, 0)
278 if (engine->hasImplicitBatchDimension()) {
279 // Nothing to do, we cannot have profiles in implicit batch mode.
280 return Status::OK();
281 }
282 #endif
283 int n_profiles = engine->getNbOptimizationProfiles();
284 need_profiles_ = n_profiles > 0;
285 int n_inputs = GetNumberOfEngineInputs(engine);
286 VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
287 << n_inputs << " inputs";
288 for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
289 OptimizationProfileConfig cfg;
290 for (int j = 0; j < n_inputs; j++) {
291 nvinfer1::Dims min = engine->getProfileDimensions(
292 j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
293 nvinfer1::Dims max = engine->getProfileDimensions(
294 j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
295 nvinfer1::Dims opt = engine->getProfileDimensions(
296 j, prof_idx, nvinfer1::OptProfileSelector::kOPT);
297 cfg.min.push_back(min);
298 cfg.max.push_back(max);
299 cfg.opt.push_back(opt);
300 }
301 VLOG(2) << "Restored profile " << cfg.DebugString();
302 profiles_.push_back(std::move(cfg));
303 }
304 #endif
305 return Status::OK();
306 }
307
GetNumProfiles() const308 int TrtShapeOptimizationProfile::GetNumProfiles() const {
309 return profiles_.size();
310 }
311
312 } // namespace tensorrt
313 } // namespace tensorflow
314 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
315