/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/convert/weights.h"

#include <functional>
#include <numeric>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

namespace convert {
TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type)
    : shape_(0, DimsAdapter::StorageType{}), type_(type), volume_(0) {}

StatusOr<TRT_ShapedWeights> TRT_ShapedWeights::CreateWithTensor(
    nvinfer1::DataType type, DimsAdapter dims, Tensor tensor) {
  TRT_ShapedWeights weights(type);
  weights.shape_ = dims;
  weights.tensor_ = std::forward<Tensor>(tensor);
  weights.volume_ = weights.shape_.Volume();
  if (weights.shape_.NumDims() == 0) {
    DCHECK(weights.shape_.IsEmpty() || weights.shape_.IsScalar());
  }
  return weights;
}
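
// Illustrative sketch of calling CreateWithTensor (not taken from this file;
// how the DimsAdapter is built from the TensorShape is an assumption here and
// may differ from the real helper API):
//
//   Tensor t(DT_FLOAT, TensorShape({2, 3}));
//   DimsAdapter dims = /* assumed: adapter describing the {2, 3} shape */;
//   StatusOr<TRT_ShapedWeights> w =
//       TRT_ShapedWeights::CreateWithTensor(nvinfer1::DataType::kFLOAT, dims, t);
//   // On success, w->Shape().Volume() == 6 and w->GetTrtWeights().count == 6.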

nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
  return nvinfer1::Weights{type_, GetPointer<int8>(), volume_};
}

Status TRT_ShapedWeights::SetShape(DimsAdapter dims) {
  if (volume_ != dims.Volume()) {
    VLOG(2) << "Changing shape from " << shape_.DebugString() << ", to "
            << dims.DebugString();
    return errors::Internal("SetShape would change number of elements");
  }
  shape_ = std::move(dims);
  return Status::OK();
}
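
// Example of the SetShape invariant above: weights with shape {2, 3}
// (volume 6) can be reshaped to {6} or {1, 2, 3}, but a reshape to {2, 4}
// would change the element count and returns errors::Internal instead.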
size_t TRT_ShapedWeights::size_bytes() const {
  size_t data_type_size = -1;
  switch (type_) {
    case nvinfer1::DataType::kFLOAT:
    case nvinfer1::DataType::kINT32:
      data_type_size = 4;
      break;
    case nvinfer1::DataType::kHALF:
      data_type_size = 2;
      break;
    case nvinfer1::DataType::kINT8:
    case nvinfer1::DataType::kBOOL:
      data_type_size = 1;
      break;
  }
  return volume_ * data_type_size;
}
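
// Example for size_bytes() above: kHALF weights with volume_ == 6 occupy
// 6 * 2 == 12 bytes, while kFLOAT or kINT32 weights of the same volume occupy
// 24 bytes. A data type not handled in the switch leaves data_type_size at
// its initial value of size_t(-1), so the result is only meaningful for the
// listed types.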

string TRT_ShapedWeights::DebugString() const {
  return absl::StrCat(
      "TRT_ShapedWeights(shape=", shape_.DebugString(),
      ", type=", tensorflow::tensorrt::DebugString(type_),
      ", values=", reinterpret_cast<uintptr_t>(GetPointer<int8>()), ")");
}
TRT_TensorOrWeights::TRT_TensorOrWeights(ITensorProxyPtr tensor)
    : tensor_proxy_ptr_(tensor),
      initialized_(true),
      arg_type_(TRT_ArgumentType::TENSOR) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size)
    : tensor_proxy_ptr_(tensor),
      batch_size_(batch_size),
      initialized_(true),
      arg_type_(TRT_ArgumentType::TENSOR) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor,
                                         int batch_size)
    : tensor_proxy_ptr_(tensor),
      batch_size_(batch_size),
      initialized_(true),
      arg_type_(TRT_ArgumentType::TENSOR) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                                         const nvinfer1::Dims& trt_dims,
                                         int batch_size)
    : tensor_proxy_ptr_(new SimpleITensor(trt_dtype, trt_dims)),
      batch_size_(batch_size),
      initialized_(true),
      arg_type_(TRT_ArgumentType::TENSOR) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
    : weights_(weights),
      initialized_(true),
      arg_type_(TRT_ArgumentType::WEIGHTS) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(const ResourceHandle& resource)
    : resource_(resource),
      initialized_(true),
      arg_type_(TRT_ArgumentType::RESOURCE) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
    : tensor_proxy_ptr_(rhs.tensor_proxy_ptr_),
      batch_size_(rhs.batch_size_),
      resource_(rhs.resource_),
      weights_(rhs.weights_),
      initialized_(rhs.initialized_),
      arg_type_(rhs.arg_type_) {}

void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) {
  tensor_proxy_ptr_ = rhs.tensor_proxy_ptr_;
  batch_size_ = rhs.batch_size_;
  weights_ = rhs.weights_;
  resource_ = rhs.resource_;
  initialized_ = rhs.initialized_;
  arg_type_ = rhs.arg_type_;
}
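
// Illustrative sketch of the constructors above (local variable names are
// hypothetical): the argument kind is fixed at construction time and drives
// the accessors defined below.
//
//   TRT_TensorOrWeights from_weights(shaped_weights);  // arg_type_ == WEIGHTS
//   TRT_TensorOrWeights from_tensor(itensor_proxy);    // arg_type_ == TENSOR
//   from_weights.GetTrtDims();  // weights().Shape().AsTrtDims()
//   from_tensor.GetTrtDims();   // tensor()->getDimensions()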

ITensorProxyPtr TRT_TensorOrWeights::tensor() const {
  DCHECK(is_tensor());
  return tensor_proxy_ptr_;
}

ResourceHandle TRT_TensorOrWeights::resource() const {
  DCHECK(is_resource());
  return resource_;
}

nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
  switch (arg_type_) {
    case TRT_ArgumentType::TENSOR:
      return tensor()->getDimensions();
    case TRT_ArgumentType::WEIGHTS:
      return weights().Shape().AsTrtDims();
    case TRT_ArgumentType::RESOURCE:
      return {0, {}};  // Scalar.
  }
}

Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
  if (!initialized_) {
    return errors::Internal("The object is not initialized");
  }
  switch (arg_type_) {
    case TRT_ArgumentType::TENSOR: {
      nvinfer1::DataType trt_type = tensor()->getType();
      return TrtTypeToTfType(trt_type, tf_type);
    }
    case TRT_ArgumentType::WEIGHTS:
      *tf_type = weights().GetTensor().dtype();
      return Status::OK();
    case TRT_ArgumentType::RESOURCE:
      *tf_type = DataType::DT_RESOURCE;
      return Status::OK();
  }
}
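
// Example of the GetTfType mapping above: a TENSOR argument of
// nvinfer1::DataType::kFLOAT reports DT_FLOAT via TrtTypeToTfType, a WEIGHTS
// argument reports the dtype of its backing Tensor, and a RESOURCE argument
// always reports DT_RESOURCE.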

string TRT_TensorOrWeights::DebugString() const {
  string output = "TRT_TensorOrWeights(type=";
  if (is_tensor()) {
    absl::StrAppend(&output,
                    "tensor=", tensorflow::tensorrt::DebugString(tensor()),
                    ", batch_size=", batch_size_);
  } else {
    absl::StrAppend(&output, "weights=", weights_.DebugString());
  }
  absl::StrAppend(&output, ")");
  return output;
}

StatusOr<TRT_ShapedWeights> TrtWeightStore::GetTempWeights(
    nvinfer1::DataType trt_dtype, const DimsAdapter& dims) {
  DataType tf_dtype;
  TF_RETURN_IF_ERROR(TrtTypeToTfType(trt_dtype, &tf_dtype));
  TensorShape shape;
  TF_RETURN_IF_ERROR(dims.TensorShape(&shape));
  // TODO(jie): check weights size_bytes. 0 means type error
  Tensor tensor(tf_dtype, shape);
  StatusOr<TRT_ShapedWeights> weights =
      TRT_ShapedWeights::CreateWithTensor(trt_dtype, dims, tensor);
  TRT_ENSURE_OK(weights);
  store_.emplace_back(std::move(tensor));
  return weights;
}
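
// Illustrative usage sketch for GetTempWeights above (`dims` and `src` are
// hypothetical caller-side variables, and GetPointer<float>() is assumed to
// be available on TRT_ShapedWeights): the weight store allocates the backing
// Tensor and keeps it in store_, so the returned weights remain valid for the
// store's lifetime.
//
//   TrtWeightStore store;
//   StatusOr<TRT_ShapedWeights> w =
//       store.GetTempWeights(nvinfer1::DataType::kFLOAT, dims);
//   if (w.ok()) {
//     std::memcpy(w->GetPointer<float>(), src, w->size_bytes());
//   }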

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT