• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H
18 
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
26 
27 #if GOOGLE_CUDA && GOOGLE_TENSORRT
28 #include "third_party/tensorrt/NvInfer.h"
29 
30 namespace tensorflow {
31 
32 namespace tensorrt {
33 
34 // SimpleITensor implements part of the ITensor interfaces to support the TF-TRT
35 // validator, as well as some TF-TRT tests. The former use case only utilizes
36 // the interfaces related to shape and type information.
37 class SimpleITensor {
38  public:
SimpleITensor(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims)39   SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
40       : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}
41 
SimpleITensor()42   SimpleITensor() : dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}
SimpleITensor(const nvinfer1::Dims & dims)43   SimpleITensor(const nvinfer1::Dims& dims)
44       : trt_dims_(dims), dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}
45 
SimpleITensor(const std::vector<int> & dims)46   SimpleITensor(const std::vector<int>& dims) {
47     trt_dims_.nbDims = dims.size();
48     for (int i = 0; i < dims.size(); ++i) {
49       trt_dims_.d[i] = dims[i];
50     }
51     dynamic_range_min_ = 0.0f;
52     dynamic_range_max_ = 0.0f;
53   }
54 
setName(const char * name)55   void setName(const char* name) {}
56 
getName()57   const char* getName() const { return ""; }
58 
setDimensions(nvinfer1::Dims dimensions)59   void setDimensions(nvinfer1::Dims dimensions) { trt_dims_ = dimensions; }
60 
getDimensions()61   nvinfer1::Dims getDimensions() const { return trt_dims_; }
62 
setType(nvinfer1::DataType trt_dtype)63   void setType(nvinfer1::DataType trt_dtype) { trt_dtype_ = trt_dtype; }
64 
getType()65   nvinfer1::DataType getType() const { return trt_dtype_; }
66 
isNetworkInput()67   bool isNetworkInput() const { return false; }
68 
isNetworkOutput()69   bool isNetworkOutput() const { return false; }
70 
setBroadcastAcrossBatch(bool broadcastAcrossBatch)71   void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {}
72 
getBroadcastAcrossBatch()73   bool getBroadcastAcrossBatch() const { return false; }
74 
getLocation()75   nvinfer1::TensorLocation getLocation() const { return location_; }
76 
setLocation(nvinfer1::TensorLocation location)77   void setLocation(nvinfer1::TensorLocation location) { location_ = location; }
setDynamicRange(float min,float max)78   bool setDynamicRange(float min, float max) {
79     dynamic_range_max_ = max;
80     dynamic_range_min_ = min;
81     return true;
82   }
83 
getDynamicRange()84   float getDynamicRange() const {
85     return (std::abs(dynamic_range_min_) + dynamic_range_max_) / 2.f;
86   }
dynamicRangeIsSet()87   bool dynamicRangeIsSet() const { return true; }
88 
resetDynamicRange()89   void resetDynamicRange() {
90     dynamic_range_min_ = 0.f;
91     dynamic_range_max_ = 0.f;
92   }
getDynamicRangeMin()93   float getDynamicRangeMin() const { return dynamic_range_min_; }
94 
getDynamicRangeMax()95   float getDynamicRangeMax() const { return dynamic_range_max_; }
96 
setAllowedFormats(nvinfer1::TensorFormats formats)97   void setAllowedFormats(nvinfer1::TensorFormats formats) {}
98 
getAllowedFormats()99   nvinfer1::TensorFormats getAllowedFormats() const { return 1; }
100 
isShapeTensor()101   bool isShapeTensor() const { return false; }
isExecutionTensor()102   bool isExecutionTensor() const { return true; }
103 
104  private:
105   nvinfer1::DataType trt_dtype_;
106   nvinfer1::Dims trt_dims_;
107   std::string name_;
108   nvinfer1::TensorLocation location_;
109   float dynamic_range_min_;
110   float dynamic_range_max_;
111 };
112 
113 enum class TensorType : int { kTRT, kSIMPLE };
114 
115 class ITensorProxy {
116  public:
117   //! ITensor not owned
ITensorProxy(nvinfer1::ITensor * trt_tensor)118   ITensorProxy(nvinfer1::ITensor* trt_tensor)
119       : trt_tensor_(trt_tensor), ttype_(TensorType::kTRT) {}
120 
121   //! SimpleITensor owned
ITensorProxy(SimpleITensor * simple_itensor)122   ITensorProxy(SimpleITensor* simple_itensor)
123       : simple_tensor_(simple_itensor), ttype_(TensorType::kSIMPLE) {}
124 
125   //! SimpleITensor owned
ITensorProxy(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims)126   explicit ITensorProxy(nvinfer1::DataType trt_dtype,
127                         const nvinfer1::Dims& trt_dims)
128       : simple_tensor_(std::unique_ptr<SimpleITensor>(
129             new SimpleITensor(trt_dtype, trt_dims))),
130         ttype_(TensorType::kSIMPLE) {}
131 
132   //! Variants for testing purposes
ITensorProxy()133   ITensorProxy()
134       : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor())),
135         ttype_(TensorType::kSIMPLE) {}
136 
ITensorProxy(const nvinfer1::Dims & dims)137   explicit ITensorProxy(const nvinfer1::Dims& dims)
138       : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
139         ttype_(TensorType::kSIMPLE) {}
140 
ITensorProxy(const std::vector<int> & dims)141   explicit ITensorProxy(const std::vector<int>& dims)
142       : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
143         ttype_(TensorType::kSIMPLE) {}
144 
is_trt_tensor()145   bool is_trt_tensor() const {
146     assert(validate());
147     assert(ttype_ == TensorType::kTRT);
148     return trt_tensor_ != nullptr;
149   }
150 
is_simple_tensor()151   bool is_simple_tensor() const {
152     assert(validate());
153     assert(ttype_ == TensorType::kSIMPLE);
154     return simple_tensor_ != nullptr;
155   }
156 
ttype()157   TensorType ttype() const { return ttype_; }
158 
trt_tensor()159   nvinfer1::ITensor* trt_tensor() const {
160     assert(trt_tensor_ != nullptr);
161     assert(ttype_ == TensorType::kTRT);
162     return trt_tensor_;
163   }
164 
simple_tensor()165   SimpleITensor* simple_tensor() const {
166     assert(simple_tensor_ != nullptr);
167     assert(ttype_ == TensorType::kSIMPLE);
168     return simple_tensor_.get();
169   }
170 
setName(const char * name)171   void setName(const char* name) {
172     switch (ttype_) {
173       case TensorType::kTRT:
174         return trt_tensor_->setName(name);
175       case TensorType::kSIMPLE:
176         return simple_tensor_->setName(name);
177     }
178     assert(0 && "Unsupported itensor_ type");
179   }
180 
getName()181   const char* getName() const {
182     switch (ttype_) {
183       case TensorType::kTRT:
184         return trt_tensor_->getName();
185       case TensorType::kSIMPLE:
186         return simple_tensor_->getName();
187     }
188     assert(0 && "Unsupported itensor_ type");
189   }
190 
setDimensions(nvinfer1::Dims dimensions)191   void setDimensions(nvinfer1::Dims dimensions) {
192     switch (ttype_) {
193       case TensorType::kTRT:
194         return trt_tensor_->setDimensions(dimensions);
195       case TensorType::kSIMPLE:
196         return simple_tensor_->setDimensions(dimensions);
197     }
198     assert(0 && "Unsupported itensor_ type");
199   }
200 
getDimensions()201   nvinfer1::Dims getDimensions() const {
202     switch (ttype_) {
203       case TensorType::kTRT:
204         return trt_tensor_->getDimensions();
205       case TensorType::kSIMPLE:
206         return simple_tensor_->getDimensions();
207     }
208     assert(0 && "Unsupported itensor_ type");
209   }
210 
setType(nvinfer1::DataType type)211   void setType(nvinfer1::DataType type) {
212     switch (ttype_) {
213       case TensorType::kTRT:
214         return trt_tensor_->setType(type);
215       case TensorType::kSIMPLE:
216         return simple_tensor_->setType(type);
217     }
218     assert(0 && "Unsupported itensor_ type");
219   }
220 
getType()221   nvinfer1::DataType getType() const {
222     switch (ttype_) {
223       case TensorType::kTRT:
224         return trt_tensor_->getType();
225       case TensorType::kSIMPLE:
226         return simple_tensor_->getType();
227     }
228     assert(0 && "Unsupported itensor_ type");
229   }
230 
isNetworkInput()231   bool isNetworkInput() const {
232     switch (ttype_) {
233       case TensorType::kTRT:
234         return trt_tensor_->isNetworkInput();
235       case TensorType::kSIMPLE:
236         return simple_tensor_->isNetworkInput();
237     }
238     assert(0 && "Unsupported itensor_ type");
239   }
240 
isNetworkOutput()241   bool isNetworkOutput() const {
242     switch (ttype_) {
243       case TensorType::kTRT:
244         return trt_tensor_->isNetworkOutput();
245       case TensorType::kSIMPLE:
246         return simple_tensor_->isNetworkOutput();
247     }
248     assert(0 && "Unsupported itensor_ type");
249   }
250 
setBroadcastAcrossBatch(bool broadcastAcrossBatch)251   void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {
252     switch (ttype_) {
253       case TensorType::kTRT:
254         return trt_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
255       case TensorType::kSIMPLE:
256         return simple_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
257     }
258     assert(0 && "Unsupported itensor_ type");
259   }
260 
getBroadcastAcrossBatch()261   bool getBroadcastAcrossBatch() const {
262     switch (ttype_) {
263       case TensorType::kTRT:
264         return trt_tensor_->getBroadcastAcrossBatch();
265       case TensorType::kSIMPLE:
266         return simple_tensor_->getBroadcastAcrossBatch();
267     }
268     assert(0 && "Unsupported itensor_ type");
269   }
270 
getLocation()271   nvinfer1::TensorLocation getLocation() const {
272     switch (ttype_) {
273       case TensorType::kTRT:
274         return trt_tensor_->getLocation();
275       case TensorType::kSIMPLE:
276         return simple_tensor_->getLocation();
277     }
278     assert(0 && "Unsupported itensor_ type");
279   }
280 
setLocation(nvinfer1::TensorLocation location)281   void setLocation(nvinfer1::TensorLocation location) {
282     switch (ttype_) {
283       case TensorType::kTRT:
284         return trt_tensor_->setLocation(location);
285       case TensorType::kSIMPLE:
286         return simple_tensor_->setLocation(location);
287     }
288     assert(0 && "Unsupported itensor_ type");
289   }
290 
setDynamicRange(float min,float max)291   bool setDynamicRange(float min, float max) {
292     switch (ttype_) {
293       case TensorType::kTRT:
294         return trt_tensor_->setDynamicRange(min, max);
295       case TensorType::kSIMPLE:
296         return simple_tensor_->setDynamicRange(min, max);
297     }
298     assert(0 && "Unsupported itensor_ type");
299   }
300 
dynamicRangeIsSet()301   bool dynamicRangeIsSet() const {
302     switch (ttype_) {
303       case TensorType::kTRT:
304         return trt_tensor_->dynamicRangeIsSet();
305       case TensorType::kSIMPLE:
306         return simple_tensor_->dynamicRangeIsSet();
307     }
308     assert(0 && "Unsupported itensor_ type");
309   }
310 
resetDynamicRange()311   void resetDynamicRange() {
312     switch (ttype_) {
313       case TensorType::kTRT:
314         return trt_tensor_->resetDynamicRange();
315       case TensorType::kSIMPLE:
316         return simple_tensor_->resetDynamicRange();
317     }
318     assert(0 && "Unsupported itensor_ type");
319   }
getDynamicRangeMin()320   float getDynamicRangeMin() const {
321     switch (ttype_) {
322       case TensorType::kTRT:
323         return trt_tensor_->getDynamicRangeMin();
324       case TensorType::kSIMPLE:
325         return simple_tensor_->getDynamicRangeMin();
326     }
327     assert(0 && "Unsupported itensor_ type");
328   }
329 
getDynamicRangeMax()330   float getDynamicRangeMax() const {
331     switch (ttype_) {
332       case TensorType::kTRT:
333         return trt_tensor_->getDynamicRangeMax();
334       case TensorType::kSIMPLE:
335         return simple_tensor_->getDynamicRangeMax();
336     }
337     assert(0 && "Unsupported itensor_ type");
338   }
339 #if !IS_TRT_VERSION_GE(8, 0, 0, 0)
getDynamicRange()340   float getDynamicRange() const {
341     switch (ttype_) {
342       case TensorType::kTRT:
343         return trt_tensor_->getDynamicRange();
344       case TensorType::kSIMPLE:
345         return simple_tensor_->getDynamicRange();
346     }
347     assert(0 && "Unsupported itensor_ type");
348   }
349 #endif
setAllowedFormats(nvinfer1::TensorFormats formats)350   void setAllowedFormats(nvinfer1::TensorFormats formats) {
351     switch (ttype_) {
352       case TensorType::kTRT:
353         return trt_tensor_->setAllowedFormats(formats);
354       case TensorType::kSIMPLE:
355         return simple_tensor_->setAllowedFormats(formats);
356     }
357     assert(0 && "Unsupported itensor_ type");
358   }
359 
getAllowedFormats()360   nvinfer1::TensorFormats getAllowedFormats() const {
361     switch (ttype_) {
362       case TensorType::kTRT:
363         return trt_tensor_->getAllowedFormats();
364       case TensorType::kSIMPLE:
365         return simple_tensor_->getAllowedFormats();
366     }
367     assert(0 && "Unsupported itensor_ type");
368   }
369 
isShapeTensor()370   bool isShapeTensor() const {
371     switch (ttype_) {
372       case TensorType::kTRT:
373         return trt_tensor_->isShapeTensor();
374       case TensorType::kSIMPLE:
375         return simple_tensor_->isShapeTensor();
376     }
377     assert(0 && "Unsupported itensor_ type");
378   }
379 
isExecutionTensor()380   bool isExecutionTensor() const {
381     switch (ttype_) {
382       case TensorType::kTRT:
383         return trt_tensor_->isExecutionTensor();
384       case TensorType::kSIMPLE:
385         return simple_tensor_->isExecutionTensor();
386     }
387     assert(0 && "Unsupported itensor_ type");
388   }
389 
390  private:
validate()391   bool validate() const {
392     return (trt_tensor_ && !simple_tensor_) || (!trt_tensor_ && simple_tensor_);
393   }
394 
395   // When ITensorProxy represents an ITensor, the ITensor can be either passed
396   // by the caller via the constructor that takes an ITensor* as parameter, or
397   // be created as a SimpleITensor.
398   //
399   // In the first case, the ITensor pointer is stored in 'tensor_' below, and
400   // the ITensor itself is not owned by this class. This method is used by
401   // Converter (e.g. AddInputTensor) and op converters during TRT network
402   // construction, where the TRT network owns the ITensor.
403   //
404   nvinfer1::ITensor* trt_tensor_ = nullptr;  // Not owned.
405   // In the second case, the created SimpleITensor is stored in
406   // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake
407   // implementation of ITensor and is used for testing and by TrtNodeValidator
408   //  to validate the graph nodes.
409   std::shared_ptr<SimpleITensor> simple_tensor_ = nullptr;
410 
411   TensorType ttype_;
412 };
413 
// Shared-ownership handle to an ITensorProxy with pointer-like syntax.
// Copying an ITensorProxyPtr shares the underlying proxy via std::shared_ptr.
// NOTE(review): operator* returns a raw pointer (not a reference), mirroring
// operator-> — so `*p` and `p.operator->()` yield the same value.
class ITensorProxyPtr {
 public:
  ITensorProxyPtr(std::nullptr_t) : p_(nullptr) {}   // empty handle
  ITensorProxyPtr(ITensorProxy* p) : p_(p) {}        // adopts the proxy
  ITensorProxyPtr(nvinfer1::ITensor* p) : p_(new ITensorProxy(p)) {}
  ITensorProxyPtr(SimpleITensor* p) : p_(new ITensorProxy(p)) {}

  // Variants for testing purposes: each wraps a fresh SimpleITensor-backed
  // proxy (see the matching ITensorProxy constructors).
  ITensorProxyPtr() : p_(new ITensorProxy()) {}
  ITensorProxyPtr(const nvinfer1::Dims& dims) : p_(new ITensorProxy(dims)) {}
  ITensorProxyPtr(const std::vector<int>& dims) : p_(new ITensorProxy(dims)) {}

  // Public by design; operator== and ITensorProxyHash read it directly.
  std::shared_ptr<ITensorProxy> p_{nullptr};
  ITensorProxy* operator->() { return p_.get(); }
  ITensorProxy* operator->() const { return p_.get(); }
  ITensorProxy* operator*() { return p_.get(); }
  ITensorProxy* operator*() const { return p_.get(); }
};
431 
432 inline bool operator==(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) {
433   if (p1.p_ == nullptr) {
434     return p2.p_ == nullptr;
435   }
436   if (p2.p_ == nullptr) {
437     return p1.p_ == nullptr;
438   }
439   return (p1->ttype() == p2->ttype()) &&
440          ((p1->ttype() == TensorType::kTRT &&
441            p1->trt_tensor() == p2->trt_tensor()) ||
442           (p1->ttype() == TensorType::kSIMPLE &&
443            p1->simple_tensor() == p2->simple_tensor()));
444 }
445 
// Hash functor for ITensorProxyPtr keyed on the address of the shared
// ITensorProxy object, so copies of the same handle hash identically.
// NOTE(review): this is identity hashing of the proxy, while operator==
// compares the *wrapped* tensor pointers — two distinct proxies wrapping the
// same ITensor compare equal but may hash differently; confirm callers only
// ever insert copies of a single handle per tensor.
struct ITensorProxyHash {
  size_t operator()(const ITensorProxyPtr& tensor) const {
    return reinterpret_cast<std::uintptr_t>(tensor.p_.get());
  }
};
451 
452 }  // namespace tensorrt
453 }  // namespace tensorflow
454 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
455 
456 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H
457