/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_
16 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_
17 
18 #if GOOGLE_CUDA && GOOGLE_TENSORRT
19 
#include <algorithm>
#include <cstdint>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/tensorrt/NvInfer.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 namespace convert {
32 
33 // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
34 class TRT_ShapedWeights {
35  public:
36   explicit TRT_ShapedWeights(
37       nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
38 
39   // Constructs a weights from another weights.
40   //
41   // NOTE: this does not copy the underlying buffer but only increase its
42   // reference count.
43   TRT_ShapedWeights(const TRT_ShapedWeights& rhs) = default;
44 
45   nvinfer1::Weights GetTrtWeights() const;
46 
GetTensor()47   const Tensor& GetTensor() const { return tensor_; }
48 
49   // Returns a pointer of type const T to the underlying buffer of the tensor.
50   template <typename T>
GetPointer()51   const T* GetPointer() const {
52     int64 num_elem =
53         (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T);
54     return tensor_.bit_casted_shaped<T, 1>({num_elem}).data();
55   }
56 
57   // Returns a pointer of type T to the underlying buffer of the tensor.
58   template <typename T>
GetPointer()59   T* GetPointer() {
60     int64 num_elem =
61         (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T);
62     return tensor_.bit_casted_shaped<T, 1>({num_elem}).data();
63   }
64 
65   // Fills all the weight values with value.
66   template <typename T>
SetValues(T value)67   Status SetValues(T value) {
68     switch (type_) {
69       case nvinfer1::DataType::kFLOAT: {
70         float* ptr = tensor_.flat<float>().data();
71         std::fill(ptr, ptr + volume_, value);
72         break;
73       }
74       case nvinfer1::DataType::kHALF: {
75         Eigen::half* ptr = tensor_.flat<Eigen::half>().data();
76         std::fill(ptr, ptr + volume_, Eigen::half(value));
77         break;
78       }
79       case nvinfer1::DataType::kINT32: {
80         int32* ptr = tensor_.flat<int32>().data();
81         std::fill(ptr, ptr + volume_, value);
82         break;
83       }
84       default:
85         return errors::InvalidArgument(
86             "Unsupported data type ", tensorflow::tensorrt::DebugString(type_));
87     }
88     return Status::OK();
89   }
90 
91   Status SetShape(DimsAdapter dims);
SetShapeUnsafe(DimsAdapter dims)92   void SetShapeUnsafe(DimsAdapter dims) { shape_ = std::move(dims); }
93 
94   // Returns total number of elements. Returning 0 means either some dim is 0
95   // or the number of dims is 0. Note that a TF scalar constant is marked as
96   // Dims{0, {1}}, and has a count() == 1.
count()97   int64_t count() const { return volume_; }
98 
99   size_t size_bytes() const;
100 
101   string DebugString() const;
102 
103   template <typename T>
GetSpan()104   absl::Span<const T> GetSpan() const {
105     return absl::Span<const T>(tensor_.flat<T>().data(), volume_);
106   }
107 
108   template <typename T>
ToVector()109   std::vector<T> ToVector() const {
110     auto span = GetSpan<T>();
111     return std::vector<T>(span.data(), span.data() + span.size());
112   }
113 
TrtDType()114   nvinfer1::DataType TrtDType() const { return type_; }
115 
Shape()116   const DimsAdapter& Shape() const { return shape_; }
Shape()117   DimsAdapter& Shape() { return shape_; }
118 
119  private:
120   // The shape of the weights. Defaults to the empty shape.
121   DimsAdapter shape_;
122 
123   // This creation method is only used by TrtWeightStore, which creates the
124   // underlying buffer.
125   static StatusOr<TRT_ShapedWeights> CreateWithTensor(nvinfer1::DataType type,
126                                                       DimsAdapter dims,
127                                                       Tensor tensor);
128 
129   nvinfer1::DataType type_;
130 
131   // All weights should be stored inside TrtWeightStore to make sure lifetime of
132   // all the underlying tensors are available until the engine is built. For
133   // this reason, tensor_ should never be reassigned to a different value that
134   // is not already present in the TrtWeightStore.
135   Tensor tensor_;
136   // Contains the volume of the weight's shape.
137   int64_t volume_;
138 
139   friend class TrtWeightStore;
140 };
141 
142 // Container for TRT_ShapedWeights. We need this container because TRT does not
143 // manage the lifetime of the weights buffer, it only keeps a pointer to it and
144 // requires that the data referenced by the pointer be available until the
145 // building of engine is complete. For more information see
146 // https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
147 //
148 // TODO(laigd): consider adding garbage collection to the unused weights.
149 class TrtWeightStore {
150  public:
151   // Gets a TRT_ShapedWeights with 'type' and 'dims'.
152   StatusOr<TRT_ShapedWeights> GetTempWeights(nvinfer1::DataType trt_type,
153                                              const DimsAdapter& dims);
154 
155   // Gets a TRT_ShapedWeights with the same data type and dimensions as
156   // 'weights'.
GetTempWeights(const TRT_ShapedWeights & weights)157   StatusOr<TRT_ShapedWeights> GetTempWeights(const TRT_ShapedWeights& weights) {
158     return GetTempWeights(weights.TrtDType(), weights.Shape());
159   }
160 
161  private:
162   // The backend storage of the TRT_ShapedWeights.
163   std::vector<Tensor> store_;
164 };
165 
// Enumerates the possible types of arguments of a converter. This determines
// what object is contained in TRT_TensorOrWeights, and converters can require
// a specific type for each of their arguments.
enum class TRT_ArgumentType {
  TENSOR = 0,    // wraps an ITensorProxyPtr (run-time tensor value).
  WEIGHTS = 1,   // wraps a TRT_ShapedWeights (compile-time constant).
  RESOURCE = 2,  // wraps a ResourceHandle (DT_RESOURCE argument).
};
174 
// Represents a TRT-style input to a TF node, it can be either a
// ITensorProxyPtr (representing nvinfer1::ITensor* or SimpleITensor),
// or TRT_ShapedWeights which is compile-time constant.
//
// TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
class TRT_TensorOrWeights {
 public:
  // Default constructor leaves the object uninitialized: is_tensor(),
  // is_weights() and is_resource() all return false until *this is assigned
  // from an initialized instance.
  TRT_TensorOrWeights() {}
  // Wraps a tensor proxy. NOTE(review): non-explicit, so implicit conversion
  // from ITensorProxyPtr is allowed — confirm this is intentional.
  TRT_TensorOrWeights(ITensorProxyPtr);
  // Wraps a tensor proxy and records its batch size (see the comment on
  // 'batch_size_' below).
  TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size);

  // Constructs a wrapper for the given ITensor.
  // This is used by Converter when building the TRT network, where the ITensor
  // is owned by the TRT network being built. See comment for 'trt_tensor_'
  // in trt_proxy_tensor.h.
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);

  // Creates a SimpleITensor for trt_dtype and trt_dims and takes ownership of
  // the object. Constructs a wrapper for the SimpleITensor. This is used by
  // TrtNodeValidator to encapsulate the type and shape information for
  // validation of graph nodes, and the created ITensor is fake and temporary,
  // and should not be used to build any TRT network. See comment for
  // 'simple_tensor_' in trt_proxy_tensor.h.
  explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                               const nvinfer1::Dims& trt_dims, int batch_size);

  // Constructs a wrapper for the given weights.
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);

  // Constructs a wrapper for the given resource handle.
  explicit TRT_TensorOrWeights(const ResourceHandle& resource);

  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);

  // NOTE(review): returns void rather than TRT_TensorOrWeights&, so
  // assignments cannot be chained (a = b = c).
  void operator=(const TRT_TensorOrWeights& rhs);

  // Type predicates: each is true only when the object is initialized AND
  // holds the corresponding argument kind.
  bool is_tensor() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::TENSOR;
  }
  bool is_weights() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::WEIGHTS;
  }
  bool is_resource() const {
    return initialized_ && arg_type_ == TRT_ArgumentType::RESOURCE;
  }

  // Returns the wrapped tensor proxy; only meaningful when is_tensor().
  ITensorProxyPtr tensor() const;

  // Returns the wrapped resource handle; only meaningful when is_resource().
  ResourceHandle resource() const;

  // Returns the wrapped weights; callers must ensure is_weights() (DCHECKed).
  TRT_ShapedWeights& weights() {
    DCHECK(is_weights());
    return weights_;
  }

  const TRT_ShapedWeights& weights() const {
    DCHECK(is_weights());
    return weights_;
  }

  nvinfer1::Dims GetTrtDims() const;

  // Writes the TF data type corresponding to the wrapped value into *tf_type;
  // failures are reported through the returned Status.
  Status GetTfType(DataType* tf_type) const;

  int batch_size() const { return batch_size_; }

  string DebugString() const;

  // Returns the TRT data type of the wrapped tensor or weights. Calling this
  // on a RESOURCE argument is undefined behavior: it only logs a warning and
  // then falls through to the weights' dtype.
  nvinfer1::DataType TrtDType() const {
    if (arg_type_ == TRT_ArgumentType::RESOURCE) {
      VLOG(0) << "Calling TrtDType() with a RESOURCE argument is undefined "
                 "behavior.";
    }
    return arg_type_ == TRT_ArgumentType::TENSOR ? tensor_proxy_ptr_->getType()
                                                 : weights_.TrtDType();
  }

 private:
  void set_batch_size(int batch_size) { batch_size_ = batch_size; }

  // First dimension of the TF tensor (NOT tensor_) that is represented by
  // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
  // dimensions (obtained via tensor_->getDimensions()) do not contain the batch
  // dimension. For example, when a TF tensor with shape (A,B,C) is represented
  // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A.
  //
  // This requires that all tensors in the subgraph that is converted to a TRT
  // engine have the same batch size are represented by the first dimension of
  // their shape, and Converter will verify this during conversion. The drawback
  // is that currently it cannot convert a graph that doesn't have the batch
  // size represented in the shapes or the batch sizes are different. See
  // b/118387490 for more details.
  //
  // If use_implicit_batch is false, batch_size_ is unused and
  // tensor_->getDimensions() will contain the entire shape (A,B,C).
  //
  // tensor_proxy_ptr_ is used when arg_type_ == TENSOR.
  ITensorProxyPtr tensor_proxy_ptr_ = nullptr;
  int batch_size_ = -1;

  // For DT_RESOURCE arguments (there is no corresponding type in TRT).
  // resource_ is used when arg_type_ == RESOURCE.
  ResourceHandle resource_;

  // weights_ is used when arg_type_ == WEIGHTS.
  TRT_ShapedWeights weights_;
  bool initialized_ = false;
  TRT_ArgumentType arg_type_ = TRT_ArgumentType::WEIGHTS;

  friend class Converter;
};
286 }  // namespace convert
287 }  // namespace tensorrt
288 }  // namespace tensorflow
289 
290 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
291 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_WEIGHTS_H_
292