1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H 18 19 #include <cassert> 20 #include <cstddef> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "tensorflow/compiler/tf2tensorrt/common/utils.h" 26 27 #if GOOGLE_CUDA && GOOGLE_TENSORRT 28 #include "third_party/tensorrt/NvInfer.h" 29 30 namespace tensorflow { 31 32 namespace tensorrt { 33 34 // SimpleITensor implements part of the ITensor interfaces to support the TF-TRT 35 // validator, as well as some TF-TRT tests. The former use case only utilizes 36 // the interfaces related to shape and type information. 37 class SimpleITensor { 38 public: SimpleITensor(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims)39 SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims) 40 : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {} 41 SimpleITensor()42 SimpleITensor() : dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {} SimpleITensor(const nvinfer1::Dims & dims)43 SimpleITensor(const nvinfer1::Dims& dims) 44 : trt_dims_(dims), dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {} 45 SimpleITensor(const std::vector<int> & dims)46 SimpleITensor(const std::vector<int>& dims) { 47 trt_dims_.nbDims = dims.size(); 48 for (int i = 0; i < dims.size(); ++i) { 49 trt_dims_.d[i] = dims[i]; 50 } 51 dynamic_range_min_ = 0.0f; 52 dynamic_range_max_ = 0.0f; 53 } 54 setName(const char * name)55 void setName(const char* name) {} 56 getName()57 const char* getName() const { return ""; } 58 setDimensions(nvinfer1::Dims dimensions)59 void setDimensions(nvinfer1::Dims dimensions) { trt_dims_ = dimensions; } 60 getDimensions()61 nvinfer1::Dims getDimensions() const { return trt_dims_; } 62 setType(nvinfer1::DataType trt_dtype)63 void setType(nvinfer1::DataType trt_dtype) { trt_dtype_ = trt_dtype; } 64 getType()65 nvinfer1::DataType getType() const { return trt_dtype_; } 66 isNetworkInput()67 bool isNetworkInput() const { return false; } 68 isNetworkOutput()69 bool isNetworkOutput() const { return false; } 70 setBroadcastAcrossBatch(bool broadcastAcrossBatch)71 void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {} 72 getBroadcastAcrossBatch()73 bool getBroadcastAcrossBatch() const { return false; } 74 getLocation()75 nvinfer1::TensorLocation getLocation() const { return location_; } 76 setLocation(nvinfer1::TensorLocation location)77 void setLocation(nvinfer1::TensorLocation location) { location_ = location; } setDynamicRange(float min,float max)78 bool setDynamicRange(float min, float max) { 79 dynamic_range_max_ = max; 80 dynamic_range_min_ = min; 81 return true; 82 } 83 getDynamicRange()84 float getDynamicRange() const { 85 return (std::abs(dynamic_range_min_) + dynamic_range_max_) / 2.f; 86 } dynamicRangeIsSet()87 bool dynamicRangeIsSet() const { return true; } 88 resetDynamicRange()89 void resetDynamicRange() { 90 dynamic_range_min_ = 0.f; 91 dynamic_range_max_ = 0.f; 92 } getDynamicRangeMin()93 float getDynamicRangeMin() const { return dynamic_range_min_; } 94 getDynamicRangeMax()95 float getDynamicRangeMax() const { return dynamic_range_max_; } 96 setAllowedFormats(nvinfer1::TensorFormats formats)97 void setAllowedFormats(nvinfer1::TensorFormats formats) {} 98 getAllowedFormats()99 nvinfer1::TensorFormats getAllowedFormats() const { return 1; } 100 isShapeTensor()101 bool isShapeTensor() const { return false; } isExecutionTensor()102 bool isExecutionTensor() const { return true; } 103 104 private: 105 nvinfer1::DataType trt_dtype_; 106 nvinfer1::Dims trt_dims_; 107 std::string name_; 108 nvinfer1::TensorLocation location_; 109 float dynamic_range_min_; 110 float dynamic_range_max_; 111 }; 112 113 enum class TensorType : int { kTRT, kSIMPLE }; 114 115 class ITensorProxy { 116 public: 117 //! ITensor not owned ITensorProxy(nvinfer1::ITensor * trt_tensor)118 ITensorProxy(nvinfer1::ITensor* trt_tensor) 119 : trt_tensor_(trt_tensor), ttype_(TensorType::kTRT) {} 120 121 //! SimpleITensor owned ITensorProxy(SimpleITensor * simple_itensor)122 ITensorProxy(SimpleITensor* simple_itensor) 123 : simple_tensor_(simple_itensor), ttype_(TensorType::kSIMPLE) {} 124 125 //! SimpleITensor owned ITensorProxy(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims)126 explicit ITensorProxy(nvinfer1::DataType trt_dtype, 127 const nvinfer1::Dims& trt_dims) 128 : simple_tensor_(std::unique_ptr<SimpleITensor>( 129 new SimpleITensor(trt_dtype, trt_dims))), 130 ttype_(TensorType::kSIMPLE) {} 131 132 //! Variants for testing purposes ITensorProxy()133 ITensorProxy() 134 : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor())), 135 ttype_(TensorType::kSIMPLE) {} 136 ITensorProxy(const nvinfer1::Dims & dims)137 explicit ITensorProxy(const nvinfer1::Dims& dims) 138 : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))), 139 ttype_(TensorType::kSIMPLE) {} 140 ITensorProxy(const std::vector<int> & dims)141 explicit ITensorProxy(const std::vector<int>& dims) 142 : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))), 143 ttype_(TensorType::kSIMPLE) {} 144 is_trt_tensor()145 bool is_trt_tensor() const { 146 assert(validate()); 147 assert(ttype_ == TensorType::kTRT); 148 return trt_tensor_ != nullptr; 149 } 150 is_simple_tensor()151 bool is_simple_tensor() const { 152 assert(validate()); 153 assert(ttype_ == TensorType::kSIMPLE); 154 return simple_tensor_ != nullptr; 155 } 156 ttype()157 TensorType ttype() const { return ttype_; } 158 trt_tensor()159 nvinfer1::ITensor* trt_tensor() const { 160 assert(trt_tensor_ != nullptr); 161 assert(ttype_ == TensorType::kTRT); 162 return trt_tensor_; 163 } 164 simple_tensor()165 SimpleITensor* simple_tensor() const { 166 assert(simple_tensor_ != nullptr); 167 assert(ttype_ == TensorType::kSIMPLE); 168 return simple_tensor_.get(); 169 } 170 setName(const char * name)171 void setName(const char* name) { 172 switch (ttype_) { 173 case TensorType::kTRT: 174 return trt_tensor_->setName(name); 175 case TensorType::kSIMPLE: 176 return simple_tensor_->setName(name); 177 } 178 assert(0 && "Unsupported itensor_ type"); 179 } 180 getName()181 const char* getName() const { 182 switch (ttype_) { 183 case TensorType::kTRT: 184 return trt_tensor_->getName(); 185 case TensorType::kSIMPLE: 186 return simple_tensor_->getName(); 187 } 188 assert(0 && "Unsupported itensor_ type"); 189 } 190 setDimensions(nvinfer1::Dims dimensions)191 void setDimensions(nvinfer1::Dims dimensions) { 192 switch (ttype_) { 193 case TensorType::kTRT: 194 return trt_tensor_->setDimensions(dimensions); 195 case TensorType::kSIMPLE: 196 return simple_tensor_->setDimensions(dimensions); 197 } 198 assert(0 && "Unsupported itensor_ type"); 199 } 200 getDimensions()201 nvinfer1::Dims getDimensions() const { 202 switch (ttype_) { 203 case TensorType::kTRT: 204 return trt_tensor_->getDimensions(); 205 case TensorType::kSIMPLE: 206 return simple_tensor_->getDimensions(); 207 } 208 assert(0 && "Unsupported itensor_ type"); 209 } 210 setType(nvinfer1::DataType type)211 void setType(nvinfer1::DataType type) { 212 switch (ttype_) { 213 case TensorType::kTRT: 214 return trt_tensor_->setType(type); 215 case TensorType::kSIMPLE: 216 return simple_tensor_->setType(type); 217 } 218 assert(0 && "Unsupported itensor_ type"); 219 } 220 getType()221 nvinfer1::DataType getType() const { 222 switch (ttype_) { 223 case TensorType::kTRT: 224 return trt_tensor_->getType(); 225 case TensorType::kSIMPLE: 226 return simple_tensor_->getType(); 227 } 228 assert(0 && "Unsupported itensor_ type"); 229 } 230 isNetworkInput()231 bool isNetworkInput() const { 232 switch (ttype_) { 233 case TensorType::kTRT: 234 return trt_tensor_->isNetworkInput(); 235 case TensorType::kSIMPLE: 236 return simple_tensor_->isNetworkInput(); 237 } 238 assert(0 && "Unsupported itensor_ type"); 239 } 240 isNetworkOutput()241 bool isNetworkOutput() const { 242 switch (ttype_) { 243 case TensorType::kTRT: 244 return trt_tensor_->isNetworkOutput(); 245 case TensorType::kSIMPLE: 246 return simple_tensor_->isNetworkOutput(); 247 } 248 assert(0 && "Unsupported itensor_ type"); 249 } 250 setBroadcastAcrossBatch(bool broadcastAcrossBatch)251 void setBroadcastAcrossBatch(bool broadcastAcrossBatch) { 252 switch (ttype_) { 253 case TensorType::kTRT: 254 return trt_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch); 255 case TensorType::kSIMPLE: 256 return simple_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch); 257 } 258 assert(0 && "Unsupported itensor_ type"); 259 } 260 getBroadcastAcrossBatch()261 bool getBroadcastAcrossBatch() const { 262 switch (ttype_) { 263 case TensorType::kTRT: 264 return trt_tensor_->getBroadcastAcrossBatch(); 265 case TensorType::kSIMPLE: 266 return simple_tensor_->getBroadcastAcrossBatch(); 267 } 268 assert(0 && "Unsupported itensor_ type"); 269 } 270 getLocation()271 nvinfer1::TensorLocation getLocation() const { 272 switch (ttype_) { 273 case TensorType::kTRT: 274 return trt_tensor_->getLocation(); 275 case TensorType::kSIMPLE: 276 return simple_tensor_->getLocation(); 277 } 278 assert(0 && "Unsupported itensor_ type"); 279 } 280 setLocation(nvinfer1::TensorLocation location)281 void setLocation(nvinfer1::TensorLocation location) { 282 switch (ttype_) { 283 case TensorType::kTRT: 284 return trt_tensor_->setLocation(location); 285 case TensorType::kSIMPLE: 286 return simple_tensor_->setLocation(location); 287 } 288 assert(0 && "Unsupported itensor_ type"); 289 } 290 setDynamicRange(float min,float max)291 bool setDynamicRange(float min, float max) { 292 switch (ttype_) { 293 case TensorType::kTRT: 294 return trt_tensor_->setDynamicRange(min, max); 295 case TensorType::kSIMPLE: 296 return simple_tensor_->setDynamicRange(min, max); 297 } 298 assert(0 && "Unsupported itensor_ type"); 299 } 300 dynamicRangeIsSet()301 bool dynamicRangeIsSet() const { 302 switch (ttype_) { 303 case TensorType::kTRT: 304 return trt_tensor_->dynamicRangeIsSet(); 305 case TensorType::kSIMPLE: 306 return simple_tensor_->dynamicRangeIsSet(); 307 } 308 assert(0 && "Unsupported itensor_ type"); 309 } 310 resetDynamicRange()311 void resetDynamicRange() { 312 switch (ttype_) { 313 case TensorType::kTRT: 314 return trt_tensor_->resetDynamicRange(); 315 case TensorType::kSIMPLE: 316 return simple_tensor_->resetDynamicRange(); 317 } 318 assert(0 && "Unsupported itensor_ type"); 319 } getDynamicRangeMin()320 float getDynamicRangeMin() const { 321 switch (ttype_) { 322 case TensorType::kTRT: 323 return trt_tensor_->getDynamicRangeMin(); 324 case TensorType::kSIMPLE: 325 return simple_tensor_->getDynamicRangeMin(); 326 } 327 assert(0 && "Unsupported itensor_ type"); 328 } 329 getDynamicRangeMax()330 float getDynamicRangeMax() const { 331 switch (ttype_) { 332 case TensorType::kTRT: 333 return trt_tensor_->getDynamicRangeMax(); 334 case TensorType::kSIMPLE: 335 return simple_tensor_->getDynamicRangeMax(); 336 } 337 assert(0 && "Unsupported itensor_ type"); 338 } 339 #if !IS_TRT_VERSION_GE(8, 0, 0, 0) getDynamicRange()340 float getDynamicRange() const { 341 switch (ttype_) { 342 case TensorType::kTRT: 343 return trt_tensor_->getDynamicRange(); 344 case TensorType::kSIMPLE: 345 return simple_tensor_->getDynamicRange(); 346 } 347 assert(0 && "Unsupported itensor_ type"); 348 } 349 #endif setAllowedFormats(nvinfer1::TensorFormats formats)350 void setAllowedFormats(nvinfer1::TensorFormats formats) { 351 switch (ttype_) { 352 case TensorType::kTRT: 353 return trt_tensor_->setAllowedFormats(formats); 354 case TensorType::kSIMPLE: 355 return simple_tensor_->setAllowedFormats(formats); 356 } 357 assert(0 && "Unsupported itensor_ type"); 358 } 359 getAllowedFormats()360 nvinfer1::TensorFormats getAllowedFormats() const { 361 switch (ttype_) { 362 case TensorType::kTRT: 363 return trt_tensor_->getAllowedFormats(); 364 case TensorType::kSIMPLE: 365 return simple_tensor_->getAllowedFormats(); 366 } 367 assert(0 && "Unsupported itensor_ type"); 368 } 369 isShapeTensor()370 bool isShapeTensor() const { 371 switch (ttype_) { 372 case TensorType::kTRT: 373 return trt_tensor_->isShapeTensor(); 374 case TensorType::kSIMPLE: 375 return simple_tensor_->isShapeTensor(); 376 } 377 assert(0 && "Unsupported itensor_ type"); 378 } 379 isExecutionTensor()380 bool isExecutionTensor() const { 381 switch (ttype_) { 382 case TensorType::kTRT: 383 return trt_tensor_->isExecutionTensor(); 384 case TensorType::kSIMPLE: 385 return simple_tensor_->isExecutionTensor(); 386 } 387 assert(0 && "Unsupported itensor_ type"); 388 } 389 390 private: validate()391 bool validate() const { 392 return (trt_tensor_ && !simple_tensor_) || (!trt_tensor_ && simple_tensor_); 393 } 394 395 // When ITensorProxy represents an ITensor, the ITensor can be either passed 396 // by the caller via the constructor that takes an ITensor* as parameter, or 397 // be created as a SimpleITensor. 398 // 399 // In the first case, the ITensor pointer is stored in 'tensor_' below, and 400 // the ITensor itself is not owned by this class. This method is used by 401 // Converter (e.g. AddInputTensor) and op converters during TRT network 402 // construction, where the TRT network owns the ITensor. 403 // 404 nvinfer1::ITensor* trt_tensor_ = nullptr; // Not owned. 405 // In the second case, the created SimpleITensor is stored in 406 // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake 407 // implementation of ITensor and is used for testing and by TrtNodeValidator 408 // to validate the graph nodes. 409 std::shared_ptr<SimpleITensor> simple_tensor_ = nullptr; 410 411 TensorType ttype_; 412 }; 413 414 class ITensorProxyPtr { 415 public: ITensorProxyPtr(std::nullptr_t)416 ITensorProxyPtr(std::nullptr_t) : p_(nullptr) {} ITensorProxyPtr(ITensorProxy * p)417 ITensorProxyPtr(ITensorProxy* p) : p_(p) {} ITensorProxyPtr(nvinfer1::ITensor * p)418 ITensorProxyPtr(nvinfer1::ITensor* p) : p_(new ITensorProxy(p)) {} ITensorProxyPtr(SimpleITensor * p)419 ITensorProxyPtr(SimpleITensor* p) : p_(new ITensorProxy(p)) {} 420 ITensorProxyPtr()421 ITensorProxyPtr() : p_(new ITensorProxy()) {} ITensorProxyPtr(const nvinfer1::Dims & dims)422 ITensorProxyPtr(const nvinfer1::Dims& dims) : p_(new ITensorProxy(dims)) {} ITensorProxyPtr(const std::vector<int> & dims)423 ITensorProxyPtr(const std::vector<int>& dims) : p_(new ITensorProxy(dims)) {} 424 425 std::shared_ptr<ITensorProxy> p_{nullptr}; 426 ITensorProxy* operator->() { return p_.get(); } 427 ITensorProxy* operator->() const { return p_.get(); } 428 ITensorProxy* operator*() { return p_.get(); } 429 ITensorProxy* operator*() const { return p_.get(); } 430 }; 431 432 inline bool operator==(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) { 433 if (p1.p_ == nullptr) { 434 return p2.p_ == nullptr; 435 } 436 if (p2.p_ == nullptr) { 437 return p1.p_ == nullptr; 438 } 439 return (p1->ttype() == p2->ttype()) && 440 ((p1->ttype() == TensorType::kTRT && 441 p1->trt_tensor() == p2->trt_tensor()) || 442 (p1->ttype() == TensorType::kSIMPLE && 443 p1->simple_tensor() == p2->simple_tensor())); 444 } 445 446 struct ITensorProxyHash { operatorITensorProxyHash447 size_t operator()(const ITensorProxyPtr& tensor) const { 448 return reinterpret_cast<std::uintptr_t>(tensor.p_.get()); 449 } 450 }; 451 452 } // namespace tensorrt 453 } // namespace tensorflow 454 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT 455 456 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H 457