/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_

#include <list>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/datavec.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

// Stores optimization profile parameters (min/opt/max of each input shape).
//
// A TensorRT optimization profile describes the possible min/max values of
// each dynamic input shape along with an optimum value. These values are used
// by the TensorRT builder to select the best kernel for the optimum value
// among those kernels that are valid for all input tensors in the [min, max]
// range.
struct OptimizationProfileConfig {
  // Length of each vector == 2*num_inputs to the engine. min[0:num_inputs-1]
  // are the min input dimensions for execution tensors. If the engine has
  // shape input tensors, then min[num_inputs + i] stores the shape values for
  // input i. For inputs that are not shape tensors, min = opt = max = {0, {}}.
  std::vector<nvinfer1::Dims> min;
  std::vector<nvinfer1::Dims> opt;
  std::vector<nvinfer1::Dims> max;
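
  // Illustrative layout (a sketch derived from the comment above, not part of
  // the TensorRT API): for an engine with two inputs where input 0 is an
  // execution tensor and input 1 is a shape tensor holding two values,
  // min/opt/max each have four entries:
  //   min[0], min[1]   - min dimensions of inputs 0 and 1,
  //   min[2] = {0, {}} - placeholder, since input 0 is not a shape tensor,
  //   min[3]           - min values of the two elements of shape tensor 1.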

  string DebugString() const {
    using absl::StrCat;
    return StrCat("[min: ", tensorflow::tensorrt::DebugString(min),
                  ", opt: ", tensorflow::tensorrt::DebugString(opt),
                  ", max: ", tensorflow::tensorrt::DebugString(max), "]");
  }

  // Sets the min/opt/max dimensions for the profile.
  //
  // The given min/opt/max dimensions should satisfy the condition
  // min <= opt <= max. Additionally TRT requires that the min/opt/max values
  // are compatible with the network input. Compatibility is defined the
  // following way: let dim be the shape of an input binding and min/opt/max
  // the corresponding profile dims. TRT requires that dim.d[k] must be -1 if
  // (min.d[k] != dim.d[k] || opt.d[k] != dim.d[k] || max.d[k] != dim.d[k]).
  //
  // Parameters:
  // network - TensorRT network, used to enumerate all the input tensors
  // profile - on exit the profile information will be set for each input
  //           tensor
  Status SetDimensions(const nvinfer1::INetworkDefinition* network,
                       nvinfer1::IOptimizationProfile* profile) const {
    int n_inputs = network->getNbInputs();
    if (min.size() != 2 * n_inputs || opt.size() != 2 * n_inputs ||
        max.size() != 2 * n_inputs) {
      return errors::Internal("Incorrect number of profile config parameters");
    }
    for (int i = 0; i < n_inputs; i++) {
      const ITensorProxyPtr input = network->getInput(i);
      const char* name = input->getName();
      if (input->isShapeTensor()) {
        int idx = i + n_inputs;
        VLOG(2) << "Setting shape values for " << name << ", "
                << ::tensorflow::tensorrt::DebugString(opt[idx]);
        profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMIN,
                                min[idx].d, min[idx].nbDims);
        profile->setShapeValues(name, nvinfer1::OptProfileSelector::kOPT,
                                opt[idx].d, opt[idx].nbDims);
        profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMAX,
                                max[idx].d, max[idx].nbDims);
      }
      if (input->isExecutionTensor()) {
        VLOG(2) << "Setting input dimensions for " << name << ", "
                << ::tensorflow::tensorrt::DebugString(opt[i]);
        profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN,
                               min[i]);
        profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT,
                               opt[i]);
        profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX,
                               max[i]);
      }
    }
    return Status::OK();
  }

  // Returns true if the profile range completely includes the given shapes.
  bool IncludesShapes(const std::vector<TensorShape>& shapes,
                      bool has_shape_tensor,
                      const std::vector<nvinfer1::Dims>& shape_values) const {
    // min, max, and opt must have the same size, which is already verified in
    // SetDimensions.
    if (min.size() != shapes.size() * 2 ||
        (has_shape_tensor && min.size() != shape_values.size() * 2)) {
      VLOG(2) << "Profile size mismatch min size " << min.size()
              << " vs input shapes size " << shapes.size() << " "
              << shape_values.size();
      return false;
    }
    for (int i = 0; i < shapes.size(); i++) {
      auto current_shape = shapes[i];
      // min, max, and opt must have the same nbDims, which is already
      // verified in SetDimensions.
      if (min[i].nbDims != current_shape.dims()) {
        return false;
      }
      // Check if the range [min, max] includes current_shape.
      for (int dim = 0; dim < current_shape.dims(); dim++) {
        if ((min[i].d[dim] > current_shape.dim_size(dim)) ||
            (max[i].d[dim] < current_shape.dim_size(dim))) {
          return false;
        }
      }
    }
    // Check shape values.
    if (has_shape_tensor) {
      int offset = shapes.size();
      for (int i = 0; i < shape_values.size(); i++) {
        auto shape_val = shape_values[i];
        // min, max, and opt must have the same nbDims, which is already
        // verified in SetDimensions.
        if (min[i + offset].nbDims != shape_val.nbDims) {
          return false;
        }
        // Check if the range [min, max] includes shape_val.
        for (int dim = 0; dim < shape_val.nbDims; dim++) {
          if (min[i + offset].d[dim] > shape_val.d[dim] ||
              max[i + offset].d[dim] < shape_val.d[dim]) {
            return false;
          }
        }
      }
    }
    return true;
  }
};
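
// Example usage of OptimizationProfileConfig (a hedged sketch; `network`,
// `builder`, and `config` are assumed to be an already-populated
// INetworkDefinition, an IBuilder, and its IBuilderConfig):
//
//   OptimizationProfileConfig cfg;
//   cfg.min = ...;  // 2 * network->getNbInputs() entries, see layout above
//   cfg.opt = ...;
//   cfg.max = ...;
//   nvinfer1::IOptimizationProfile* profile =
//       builder->createOptimizationProfile();
//   TF_RETURN_IF_ERROR(cfg.SetDimensions(network, profile));
//   config->addOptimizationProfile(profile);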

// Manages optimization profiles during TRT engine construction.
//
// An optimization profile describes a range of dimensions for each TRT network
// input, and the optimal dimensions that the auto-tuner should use for
// optimization.
//
// This class stores the list of input shapes that were seen during the
// build/profile_generation_mode phase, and using them it creates a set of
// OptimizationProfileConfigs. These configs will be added to IBuilderConfig
// before the engine is created.
class TrtShapeOptimizationProfile {
 public:
  TrtShapeOptimizationProfile() {}

  // Stores input shape information during profile_generation_mode.
  void AddShape(const std::vector<TensorShape>& shapes) {
    input_shapes_.push_back(shapes);
    input_shape_values_.push_back(actual_shape_values_);
    VLOG(1) << "Collected shape(s) " << DebugString(shapes)
            << " for profiles.";
  }

  // Collects ShapeTensorCompatible tensor values. This is needed both during
  // profile_generation_mode and during normal inference calls.
  Status CollectShapeValues(OpKernelContext* ctx);

  // Collects ShapeTensorCompatible tensor values, used only for unit tests.
  Status CollectShapeValues(const DataVec& input);

  void clear() { profiles_.clear(); }

  // Returns the profile number that should be used to execute the network
  // with the given input shapes. Returns -1 if none of the cached profiles
  // are compatible with the given input shapes.
  int GetProfileNumber(const std::vector<TensorShape>& shapes);

  // Creates optimization profiles and adds them to the builder config.
  Status ConfigureBuilder(nvinfer1::IBuilder* builder,
                          nvinfer1::IBuilderConfig* config,
                          const nvinfer1::INetworkDefinition* network);

  // Creates execution contexts for each optimization profile.
  Status CreateExecutionContexts(nvinfer1::ICudaEngine* engine,
                                 std::vector<ExecutionContext>* exec_contexts);

  Status SetInputShapeBinding(int input_index, int binding_index,
                              nvinfer1::ICudaEngine* cuda_engine,
                              nvinfer1::IExecutionContext* exec_context) const;

  // Creates optimization profiles profiles_ for the set of concrete input
  // shapes collected in input_shapes_. The input_partial_shapes of the
  // network is used to ensure that the created optimization profiles are
  // compatible with the network.
  void InitProfiles(
      const std::vector<PartialTensorShape>& input_partial_shapes,
      ProfileStrategy strategy);

  // Returns the number of created profiles.
  int GetNumProfiles() const;

  bool HasShape() const { return !input_shapes_.empty(); }
  bool NeedProfiles() const { return need_profiles_; }

  // Restores profiles from the engine (used after deserialization).
  Status RestoreProfiles(const nvinfer1::ICudaEngine* engine);

  // Whether the network has any shape tensors.
  bool HasShapeTensor() const { return has_shape_tensor_; }

  void SetShapeTensorMask(const nvinfer1::INetworkDefinition* network);

  // Whether the optimization profiles describe input that can be handled with
  // a static engine (only 1 profile with min=max).
  bool IsStaticCompatible() {
    return strategy_ == ProfileStrategy::kOptimal && profiles_.size() == 1 &&
           !HasShapeTensor();
    // TODO(tfeher): remove !HasShapeTensor() condition once the
    // FixShapeValueProfile workaround is turned off.
  }

 private:
  // Set of input shape vectors that we collect during
  // profile_generation_mode.
  std::vector<std::vector<TensorShape>> input_shapes_;

  // Input shape values that we collect during profile_generation_mode. If the
  // tensor is not compatible with a TRT shape tensor then an empty shape is
  // stored.
  std::vector<std::vector<nvinfer1::Dims>> input_shape_values_;

  // Shape values present in the current inference call.
  std::vector<nvinfer1::Dims> actual_shape_values_;

  // The optimization profiles generated from input_shapes_.
  std::vector<OptimizationProfileConfig> profiles_;

  // Whether the network has any shape tensors. Initially we assume that the
  // network might have a shape value input. This will be updated when the
  // network is created / the engine is deserialized.
  bool has_shape_tensor_ = true;

  // Whether the network/engine requires optimization profiles.
  bool need_profiles_ = false;

  // Whether an input tensor is a shape tensor.
  std::vector<bool> is_shape_tensor_;

  // Optimization profile generation strategy.
  ProfileStrategy strategy_;

  // Adds optimization profiles to the builder config.
  Status AddProfiles(nvinfer1::IBuilder* builder,
                     nvinfer1::IBuilderConfig* config,
                     const nvinfer1::INetworkDefinition* network);

  void SetShapeTensorMask(const nvinfer1::ICudaEngine* engine, int n_inputs);
  void SetShapeTensorMask(
      const std::vector<PartialTensorShape>& input_partial_shapes);

  void ImplicitBatchModeCompatibleStrategy(
      const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
  void OptimalStrategy(
      const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
  Status RangeStrategy(
      const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
};
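
// Typical call sequence (an illustrative sketch, not a runnable example;
// `ctx`, `builder`, `config`, `network`, and the shape vectors are assumed
// to exist in the caller):
//
//   TrtShapeOptimizationProfile profiles;
//   // Profile generation mode: record the shapes seen by each inference.
//   TF_RETURN_IF_ERROR(profiles.CollectShapeValues(ctx));
//   profiles.AddShape(input_shapes);
//   // Build time: create profiles from the collected shapes and register
//   // them with the builder.
//   profiles.InitProfiles(input_partial_shapes, strategy);
//   TF_RETURN_IF_ERROR(profiles.ConfigureBuilder(builder, config, network));
//   // Inference time: select the profile matching the actual input shapes.
//   int profile_idx = profiles.GetProfileNumber(shapes);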

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_