/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

// Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
class TRT_ShapedWeights {
 public:
  // Constructs weights of the given TRT data type (kFLOAT when omitted).
  explicit TRT_ShapedWeights(
      nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);

  // Constructs a weights from another weights.
  //
  // NOTE: this does not copy the underlying buffer but only increase its
  // reference count.
  TRT_ShapedWeights(const TRT_ShapedWeights& rhs) = default;

  // Returns these weights as an nvinfer1::Weights value (defined out of line).
  nvinfer1::Weights GetTrtWeights() const;

  // Returns a read-only reference to the tensor backing these weights.
  const Tensor& GetTensor() const { return tensor_; }

  // Returns a pointer of type const T to the underlying buffer of the tensor.
  //
  // The buffer is viewed as a 1-D array of T whose length is the tensor's
  // total byte size (NumElements() * DataTypeSize(dtype())) divided by
  // sizeof(T), using integer division.
  template <typename T>
  const T* GetPointer() const {
    int64 num_elem =
        (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T);
    return tensor_.bit_casted_shaped<T, 1>({num_elem}).data();
  }

  // Returns a pointer of type T to the underlying buffer of the tensor.
  //
  // Same element-count computation as the const overload above.
  template <typename T>
  T* GetPointer() {
    int64 num_elem =
        (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T);
    return tensor_.bit_casted_shaped<T, 1>({num_elem}).data();
  }

  // Fills all the weight values with value.
  //
  // Dispatches on type_: kFLOAT and kINT32 assign `value` to each of the
  // volume_ elements, kHALF assigns Eigen::half(value). Returns
  // InvalidArgument for any other TRT data type, and OK otherwise.
  template <typename T>
  Status SetValues(T value) {
    switch (type_) {
      case nvinfer1::DataType::kFLOAT: {
        float* ptr = tensor_.flat<float>().data();
        std::fill(ptr, ptr + volume_, value);
        break;
      }
      case nvinfer1::DataType::kHALF: {
        Eigen::half* ptr = tensor_.flat<Eigen::half>().data();
        std::fill(ptr, ptr + volume_, Eigen::half(value));
        break;
      }
      case nvinfer1::DataType::kINT32: {
        int32* ptr = tensor_.flat<int32>().data();
        std::fill(ptr, ptr + volume_, value);
        break;
      }
      default:
        return errors::InvalidArgument(
            "Unsupported data type ", tensorflow::tensorrt::DebugString(type_));
    }
    return Status::OK();
  }

  // Sets the shape to `dims`, reporting failure through the returned Status
  // (defined out of line).
  Status SetShape(DimsAdapter dims);

  // Overwrites shape_ with `dims` and touches nothing else; in particular
  // volume_ is not updated here.
  void SetShapeUnsafe(DimsAdapter dims) { shape_ = std::move(dims); }

  // Returns total number of elements. Returning 0 means either some dim is 0
  // or the number of dims is 0. Note that a TF scalar constant is marked as
  // Dims{0, {1}}, and has a count() == 1.
  int64_t count() const { return volume_; }

  size_t size_bytes() const;

  string DebugString() const;

  // Returns a read-only span over the first volume_ elements of the tensor's
  // flat<T>() buffer. The span does not own the data.
  template <typename T>
  absl::Span<const T> GetSpan() const {
    return absl::Span<const T>(tensor_.flat<T>().data(), volume_);
  }

  // Returns a copy of the elements exposed by GetSpan<T>() as a std::vector.
  template <typename T>
  std::vector<T> ToVector() const {
    auto span = GetSpan<T>();
    return std::vector<T>(span.data(), span.data() + span.size());
  }

  // Returns the TRT data type these weights were created with.
  nvinfer1::DataType TrtDType() const { return type_; }

  // Accessors for the stored shape. The mutable overload hands out a
  // reference to shape_ itself, so edits made through it bypass SetShape().
  const DimsAdapter& Shape() const { return shape_; }
  DimsAdapter& Shape() { return shape_; }

 private:
  // The shape of the weights. Defaults to the empty shape.
  DimsAdapter shape_;

  // This creation method is only used by TrtWeightStore, which creates the
  // underlying buffer.
  static StatusOr<TRT_ShapedWeights> CreateWithTensor(nvinfer1::DataType type,
                                                      DimsAdapter dims,
                                                      Tensor tensor);

  nvinfer1::DataType type_;

  // All weights should be stored inside TrtWeightStore to make sure lifetime of
  // all the underlying tensors are available until the engine is built. For
  // this reason, tensor_ should never be reassigned to a different value that
  // is not already present in the TrtWeightStore.
  Tensor tensor_;
  // Contains the volume of the weight's shape.
  int64_t volume_;

  friend class TrtWeightStore;
};

// Container for TRT_ShapedWeights. We need this container because TRT does not
// manage the lifetime of the weights buffer, it only keeps a pointer to it and
// requires that the data referenced by the pointer be available until the
// building of engine is complete. For more information see
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
//
// TODO(laigd): consider adding garbage collection to the unused weights.
class TrtWeightStore {
 public:
  // Gets a TRT_ShapedWeights with 'type' and 'dims'.
  StatusOr<TRT_ShapedWeights> GetTempWeights(nvinfer1::DataType trt_type,
                                             const DimsAdapter& dims);

  // Gets a TRT_ShapedWeights with the same data type and dimensions as
  // 'weights'.
  StatusOr<TRT_ShapedWeights> GetTempWeights(const TRT_ShapedWeights& weights) {
    return GetTempWeights(weights.TrtDType(), weights.Shape());
  }

 private:
  // The backend storage of the TRT_ShapedWeights.
  std::vector<Tensor> store_;
};

// Enumerates the possible types of arguments of a converter. This determines
// what object is contained in TRT_TensorOrWeights, and converters can require
// a specific type for each of their arguments.
enum class TRT_ArgumentType {
  TENSOR = 0,
  WEIGHTS = 1,
  RESOURCE = 2,
};

// Represents a TRT-style input to a TF node, it can be either a
// ITensorProxyPtr (representing nvinfer1::ITensor* or SimpleITensor),
// or TRT_ShapedWeights which is compile-time constant.
//
// TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
class TRT_TensorOrWeights {
 public:
  // Creates an uninitialized wrapper: initialized_ stays false, so
  // is_tensor(), is_weights() and is_resource() all return false.
  TRT_TensorOrWeights() {}
  TRT_TensorOrWeights(ITensorProxyPtr);
  TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size);

  // Constructs a wrapper for the given ITensor.
  // This is used by Converter when building the TRT network, where the ITensor
  // is owned by the TRT network being built. See comment for 'trt_tensor_'
  // in trt_tensor_proxy.h.
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);

  // Creates a SimpleITensor for trt_dtype and trt_dims and takes ownership of
  // the object. Constructs a wrapper for the SimpleITensor. This is used by
  // TrtNodeValidator to encapsulate the type and shape information for
  // validation of graph nodes, and the created ITensor is fake and temporary,
  // and should not be used to build any TRT network. See comment for
  // 'simple_tensor_' in trt_tensor_proxy.h.
  explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                               const nvinfer1::Dims& trt_dims, int batch_size);

  // Constructs a wrapper for the given weights.
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);

  // Constructs a wrapper for the given resource handle.
  explicit TRT_TensorOrWeights(const ResourceHandle& resource);

  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);

  // NOTE: returns void rather than a reference, so assignments cannot be
  // chained.
  void operator=(const TRT_TensorOrWeights& rhs);

  // Kind predicates. Each one is true only when the wrapper has been
  // initialized and holds the corresponding argument type.
  bool is_tensor() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::TENSOR;
  }
  bool is_weights() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::WEIGHTS;
  }
  bool is_resource() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::RESOURCE;
  }

  ITensorProxyPtr tensor() const;

  ResourceHandle resource() const;

  // Returns the wrapped weights. The is_weights() precondition is enforced
  // with DCHECK only.
  TRT_ShapedWeights& weights() {
    DCHECK(is_weights());
    return weights_;
  }

  const TRT_ShapedWeights& weights() const {
    DCHECK(is_weights());
    return weights_;
  }

  nvinfer1::Dims GetTrtDims() const;

  Status GetTfType(DataType* tf_type) const;

  int batch_size() const { return batch_size_; }

  string DebugString() const;

  // Returns the TRT data type of the wrapped argument: the tensor's type when
  // arg_type_ is TENSOR, otherwise weights_.TrtDType(). For a RESOURCE
  // argument a message is logged and the weights_ type is still returned.
  // initialized_ is not consulted here.
  nvinfer1::DataType TrtDType() const {
    if (arg_type_ == TRT_ArgumentType::RESOURCE) {
      VLOG(0) << "Calling TrtDType() with a RESOURCE argument is undefined "
                 "behavior.";
    }
    return arg_type_ == TRT_ArgumentType::TENSOR ? tensor_proxy_ptr_->getType()
                                                 : weights_.TrtDType();
  }

 private:
  void set_batch_size(int batch_size) { batch_size_ = batch_size; }

  // First dimension of the TF tensor (NOT tensor_) that is represented by
  // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
  // dimensions (obtained via tensor_->getDimensions()) do not contain the batch
  // dimension. For example, when a TF tensor with shape (A,B,C) is represented
  // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A.
  //
  // This requires that all tensors in the subgraph that is converted to a TRT
  // engine have the same batch size, represented by the first dimension of
  // their shape, and Converter will verify this during conversion. The drawback
  // is that currently it cannot convert a graph that doesn't have the batch
  // size represented in the shapes or the batch sizes are different. See
  // b/118387490 for more details.
  //
  // If use_implicit_batch is false, batch_size_ is unused and
  // tensor_->getDimensions() will contain the entire shape (A,B,C).
  //
  // NOTE(review): there is no member named tensor_ in this class; the text
  // above presumably refers to the tensor held by tensor_proxy_ptr_ - confirm.
  //
  // tensor_proxy_ptr_ is used when arg_type_ == TENSOR.
  ITensorProxyPtr tensor_proxy_ptr_ = nullptr;
  int batch_size_ = -1;

  // For DT_RESOURCE arguments (there is no corresponding type in TRT).
  // resource_ is used when arg_type_ == RESOURCE.
  ResourceHandle resource_;

  // weights_ is used when arg_type_ == WEIGHTS.
  TRT_ShapedWeights weights_;
  bool initialized_ = false;
  TRT_ArgumentType arg_type_ = TRT_ArgumentType::WEIGHTS;

  friend class Converter;
};
}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_