1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2tensorrt/convert/weights.h"
16
17 #include <functional>
18 #include <numeric>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
22
23 #if GOOGLE_CUDA && GOOGLE_TENSORRT
24
25 namespace tensorflow {
26 namespace tensorrt {
27
28 namespace convert {
29
TRT_ShapedWeights(nvinfer1::DataType type)30 TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type)
31 : shape_(0, DimsAdapter::StorageType{}), type_(type), volume_(0) {}
32
CreateWithTensor(nvinfer1::DataType type,DimsAdapter dims,Tensor tensor)33 StatusOr<TRT_ShapedWeights> TRT_ShapedWeights::CreateWithTensor(
34 nvinfer1::DataType type, DimsAdapter dims, Tensor tensor) {
35 TRT_ShapedWeights weights(type);
36 weights.shape_ = dims;
37 weights.tensor_ = std::forward<Tensor>(tensor);
38 weights.volume_ = weights.shape_.Volume();
39 if (weights.shape_.NumDims() == 0) {
40 DCHECK(weights.shape_.IsEmpty() || weights.shape_.IsScalar());
41 }
42 return weights;
43 }
44
GetTrtWeights() const45 nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
46 return nvinfer1::Weights{type_, GetPointer<int8>(), volume_};
47 }
48
SetShape(DimsAdapter dims)49 Status TRT_ShapedWeights::SetShape(DimsAdapter dims) {
50 if (volume_ != dims.Volume()) {
51 VLOG(2) << "Changing shape from " << shape_.DebugString() << ", to "
52 << dims.DebugString();
53 return errors::Internal("SetShape would change number of elements");
54 }
55 shape_ = std::move(dims);
56 return Status::OK();
57 }
58
size_bytes() const59 size_t TRT_ShapedWeights::size_bytes() const {
60 size_t data_type_size = -1;
61 switch (type_) {
62 case nvinfer1::DataType::kFLOAT:
63 case nvinfer1::DataType::kINT32:
64 data_type_size = 4;
65 break;
66 case nvinfer1::DataType::kHALF:
67 data_type_size = 2;
68 break;
69 case nvinfer1::DataType::kINT8:
70 case nvinfer1::DataType::kBOOL:
71 data_type_size = 1;
72 break;
73 }
74 return volume_ * data_type_size;
75 }
76
DebugString() const77 string TRT_ShapedWeights::DebugString() const {
78 return absl::StrCat(
79 "TRT_ShapedWeights(shape=", shape_.DebugString(),
80 ", type=", tensorflow::tensorrt::DebugString(type_),
81 ", values=", reinterpret_cast<uintptr_t>(GetPointer<int8>()), ")");
82 }
83
TRT_TensorOrWeights(ITensorProxyPtr tensor)84 TRT_TensorOrWeights::TRT_TensorOrWeights(ITensorProxyPtr tensor)
85 : tensor_proxy_ptr_(tensor),
86 initialized_(true),
87 arg_type_(TRT_ArgumentType::TENSOR) {}
88
TRT_TensorOrWeights(ITensorProxyPtr tensor,int batch_size)89 TRT_TensorOrWeights::TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size)
90 : tensor_proxy_ptr_(tensor),
91 batch_size_(batch_size),
92 initialized_(true),
93 arg_type_(TRT_ArgumentType::TENSOR) {}
94
TRT_TensorOrWeights(nvinfer1::ITensor * tensor,int batch_size)95 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor,
96 int batch_size)
97 : tensor_proxy_ptr_(tensor),
98 batch_size_(batch_size),
99 initialized_(true),
100 arg_type_(TRT_ArgumentType::TENSOR) {}
101
TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims,int batch_size)102 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
103 const nvinfer1::Dims& trt_dims,
104 int batch_size)
105 : tensor_proxy_ptr_(new SimpleITensor(trt_dtype, trt_dims)),
106 batch_size_(batch_size),
107 initialized_(true),
108 arg_type_(TRT_ArgumentType::TENSOR) {}
109
TRT_TensorOrWeights(const TRT_ShapedWeights & weights)110 TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
111 : weights_(weights),
112 initialized_(true),
113 arg_type_(TRT_ArgumentType::WEIGHTS) {}
114
TRT_TensorOrWeights(const ResourceHandle & resource)115 TRT_TensorOrWeights::TRT_TensorOrWeights(const ResourceHandle& resource)
116 : resource_(resource),
117 initialized_(true),
118 arg_type_(TRT_ArgumentType::RESOURCE) {}
119
// Copy constructor: copies the tensor proxy pointer, batch size, resource
// handle, weights, initialization flag and argument type from `rhs`.
TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
    : tensor_proxy_ptr_(rhs.tensor_proxy_ptr_),
      batch_size_(rhs.batch_size_),
      resource_(rhs.resource_),
      weights_(rhs.weights_),
      initialized_(rhs.initialized_),
      arg_type_(rhs.arg_type_) {}
127
operator =(const TRT_TensorOrWeights & rhs)128 void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) {
129 tensor_proxy_ptr_ = rhs.tensor_proxy_ptr_;
130 batch_size_ = rhs.batch_size_;
131 weights_ = rhs.weights_;
132 resource_ = rhs.resource_;
133 initialized_ = rhs.initialized_;
134 arg_type_ = rhs.arg_type_;
135 }
136
tensor() const137 ITensorProxyPtr TRT_TensorOrWeights::tensor() const {
138 DCHECK(is_tensor());
139 return tensor_proxy_ptr_;
140 }
141
resource() const142 ResourceHandle TRT_TensorOrWeights::resource() const {
143 DCHECK(is_resource());
144 return resource_;
145 }
146
GetTrtDims() const147 nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
148 switch (arg_type_) {
149 case TRT_ArgumentType::TENSOR:
150 return tensor()->getDimensions();
151 case TRT_ArgumentType::WEIGHTS:
152 return weights().Shape().AsTrtDims();
153 case TRT_ArgumentType::RESOURCE:
154 return {0, {}}; // Scalar.
155 }
156 }
157
GetTfType(DataType * tf_type) const158 Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
159 if (!initialized_) {
160 return errors::Internal("The object is not initialized");
161 }
162 switch (arg_type_) {
163 case TRT_ArgumentType::TENSOR: {
164 nvinfer1::DataType trt_type = tensor()->getType();
165 return TrtTypeToTfType(trt_type, tf_type);
166 }
167 case TRT_ArgumentType::WEIGHTS:
168 *tf_type = weights().GetTensor().dtype();
169 return Status::OK();
170 case TRT_ArgumentType::RESOURCE:
171 *tf_type = DataType::DT_RESOURCE;
172 return Status::OK();
173 }
174 }
175
DebugString() const176 string TRT_TensorOrWeights::DebugString() const {
177 string output = "TRT_TensorOrWeights(type=";
178 if (is_tensor()) {
179 absl::StrAppend(&output,
180 "tensor=", tensorflow::tensorrt::DebugString(tensor()),
181 ", batch_size=", batch_size_);
182 } else {
183 absl::StrAppend(&output, "weights=", weights_.DebugString());
184 }
185 absl::StrAppend(&output, ")");
186 return output;
187 }
188
GetTempWeights(nvinfer1::DataType trt_dtype,const DimsAdapter & dims)189 StatusOr<TRT_ShapedWeights> TrtWeightStore::GetTempWeights(
190 nvinfer1::DataType trt_dtype, const DimsAdapter& dims) {
191 DataType tf_dtype;
192 TF_RETURN_IF_ERROR(TrtTypeToTfType(trt_dtype, &tf_dtype));
193 TensorShape shape;
194 TF_RETURN_IF_ERROR(dims.TensorShape(&shape));
195 // TODO(jie): check weights size_bytes. 0 means type error
196 Tensor tensor(tf_dtype, shape);
197 StatusOr<TRT_ShapedWeights> weights =
198 TRT_ShapedWeights::CreateWithTensor(trt_dtype, dims, tensor);
199 TRT_ENSURE_OK(weights);
200 store_.emplace_back(std::move(tensor));
201 return weights;
202 }
203
204 } // namespace convert
205 } // namespace tensorrt
206 } // namespace tensorflow
207
208 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
209