1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ 18 19 #include <list> 20 #include <string> 21 #include <unordered_set> 22 #include <vector> 23 24 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" 25 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" 26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/strings/str_util.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 33 #if GOOGLE_CUDA && GOOGLE_TENSORRT 34 35 #include "third_party/tensorrt/NvInfer.h" 36 37 namespace tensorflow { 38 namespace tensorrt { 39 40 // Stores optimization profile parameters (min/opt/max of each input shape). 41 // 42 // A TensorRT optimization profile describes the possible min/max values of 43 // each dynamic input shape along with an optimum value. These values are used 44 // by the TensorRT builder to select the best kernel for the optimum value among 45 // those kernels that are valid for all input tensors in the [min, max] range. 46 struct OptimizationProfileConfig { 47 // Length of vector == num_inputs to engine. 48 std::vector<nvinfer1::Dims> min; 49 std::vector<nvinfer1::Dims> opt; 50 std::vector<nvinfer1::Dims> max; 51 DebugStringOptimizationProfileConfig52 string DebugString() const { 53 using absl::StrCat; 54 return StrCat("[min: ", tensorflow::tensorrt::DebugString(min), 55 ", opt: : ", tensorflow::tensorrt::DebugString(opt), 56 ", max: ", tensorflow::tensorrt::DebugString(max), "]"); 57 } 58 59 #if IS_TRT_VERSION_GE(6, 0, 0, 0) 60 // Sets the min/opt/max dimensions for profile. 61 // 62 // The given min/opt/max dimensions should satisfy the condition 63 // min <= opt <= max. Additionally TRT requires that the min/opt/max values 64 // are compatible with the network input. Compatibility is defined the 65 // following way: let dim be the shape of an input binding and min/opt/max the 66 // corresponding profile dims. TRT requires that dim.d[k] must be -1 if 67 // (min.d[k] != dim.d[k] || opt.d[k] != dim.d[k] || max.d[k] != dim.d[k]). 68 // 69 // Parameters: 70 // network - TensorRT network, used to enumerate all the input tensors 71 // profile - on exit the profile information will be set for each input tensor SetDimensionsOptimizationProfileConfig72 Status SetDimensions(const nvinfer1::INetworkDefinition* network, 73 nvinfer1::IOptimizationProfile* profile) const { 74 int n_inputs = network->getNbInputs(); 75 if (min.size() != n_inputs || opt.size() != n_inputs || 76 max.size() != n_inputs) { 77 return errors::Internal("Incorrect number of profile config parameters"); 78 } 79 for (int i = 0; i < n_inputs; i++) { 80 const nvinfer1::ITensor* input = network->getInput(i); 81 const char* name = input->getName(); 82 VLOG(2) << "Setting input dimensions for " << name << ", " 83 << ::tensorflow::tensorrt::DebugString(opt[i]); 84 profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, min[i]); 85 profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, opt[i]); 86 profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, max[i]); 87 } 88 return Status::OK(); 89 } 90 #endif 91 92 // Returns true if profile range completely includes the given shapes. IncludesShapesOptimizationProfileConfig93 bool IncludesShapes(const std::vector<TensorShape>& shapes) const { 94 // min, max, and opt must have the same size which is already verified in 95 // SetDimensions. 96 if (min.size() != shapes.size()) { 97 return false; 98 } 99 for (int i = 0; i < shapes.size(); i++) { 100 auto current_shape = shapes[i]; 101 // min, max, and opt must have the same nbDims, which is already verified 102 // in SetDimensions. 103 if (min[i].nbDims != current_shape.dims()) { 104 return false; 105 } 106 // Check if range [min, max] includes current_shape. 107 for (int dim = 0; dim < current_shape.dims(); dim++) { 108 if ((min[i].d[dim] > current_shape.dim_size(dim)) || 109 (max[i].d[dim] < current_shape.dim_size(dim))) { 110 return false; 111 } 112 } 113 } 114 return true; 115 } 116 }; 117 118 // Optimization profile generation strategies. 119 enum class ProfileStrategy { 120 kImplicitBatchModeCompatible, 121 kOptimal, 122 }; 123 124 // Manages Optimization profiles during TRT Engine construction. 125 // 126 // An optimization profile describes a range of dimensions for each TRT network 127 // input, and the optimal dimensions that the auto-tuner should use for 128 // optimization. 129 // 130 // This class stores the list of input shapes that were seen during the 131 // build/profile_generation_mode phase, and using them it creates a set of 132 // OptimizationProfileConfigs. These configs will be added to IBuilderConfig 133 // before the engine is created. 134 class TrtShapeOptimizationProfile { 135 public: 136 TrtShapeOptimizationProfile( 137 ProfileStrategy strategy = ProfileStrategy::kImplicitBatchModeCompatible) strategy_(strategy)138 : strategy_(strategy) {} 139 140 // Stores input shape information during profile_generation_mode. AddShape(const std::vector<TensorShape> & shapes)141 void AddShape(const std::vector<TensorShape>& shapes) { 142 input_shapes_.insert(shapes); 143 VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles."; 144 } 145 clear()146 void clear() { profiles_.clear(); } 147 148 // Returns the profile number that should be used to execute the network with 149 // the given input shapes. Returns -1 if none of cached profiles are 150 // compatible with the given input shapes. 151 int GetProfileNumber(const std::vector<TensorShape>& shapes); 152 153 #if IS_TRT_VERSION_GE(6, 0, 0, 0) 154 // Creates optimization profiles and add them to the builder config. 155 Status ConfigureBuilder(nvinfer1::IBuilder* builder, 156 nvinfer1::IBuilderConfig* config, 157 const nvinfer1::INetworkDefinition* network); 158 #endif 159 160 // Creates execution contexts for each optimization profile. 161 Status CreateExecutionContexts(nvinfer1::ICudaEngine* engine, 162 std::vector<ExecutionContext>& exec_context, 163 TRTBaseAllocator* memory_allocator); 164 165 // Creates optimization profiles profiles_ for the set of concrete input 166 // shapes collected in input_shapes_. The input_partial_shapes of the network 167 // is used to ensure that the created optimization profiles are compatible 168 // with the network. 169 void InitProfiles( 170 const std::vector<PartialTensorShape>& input_partial_shapes); 171 172 // Returns number of created profiles. 173 int GetNumProfiles() const; 174 HasShape()175 bool HasShape() const { return !input_shapes_.empty(); } NeedProfiles()176 bool NeedProfiles() const { return need_profiles_; } 177 178 // Restores profiles from the engine (used after deserialization). 179 Status RestoreProfiles(const nvinfer1::ICudaEngine* engine); 180 181 private: 182 // Set of input shape vetors that we collect during profile_generation_mode. 183 std::unordered_set<std::vector<TensorShape>, VectorTensorShapeHasher> 184 input_shapes_; 185 186 // The optimization profiles generated from input_shapes_. 187 std::vector<OptimizationProfileConfig> profiles_; 188 189 // Whether the network has any shape tensors. 190 bool has_shape_tensor_; 191 192 // Whether the network/engine requires optimization profiles. 193 bool need_profiles_ = false; 194 195 // Whether an input tensor is a shape tensor. 196 std::vector<bool> is_shape_tensor_; 197 198 // Optimization profile generation strategy. 199 ProfileStrategy strategy_; 200 201 #if IS_TRT_VERSION_GE(6, 0, 0, 0) 202 // Adds optimization profiles to the builder config. 203 Status AddProfiles(nvinfer1::IBuilder* builder, 204 nvinfer1::IBuilderConfig* config, 205 const nvinfer1::INetworkDefinition* network); 206 #endif 207 208 void SetShapeTensorMask(const nvinfer1::INetworkDefinition* network); 209 void SetShapeTensorMask( 210 const std::vector<PartialTensorShape>& input_partial_shapes); 211 212 void ImplicitBatchModeCompatibleStrategy(); 213 void OptimalStrategy(); 214 }; 215 216 } // namespace tensorrt 217 } // namespace tensorflow 218 219 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT 220 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ 221