/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "third_party/tensorrt/NvInfer.h"
#include "third_party/tensorrt/NvInferRuntimeCommon.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

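// Attributes of a TF variable node, gathered during validation and reused
// when the node is converted.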
struct VarAttributes {
  TensorShapeProto shape_proto;
  TensorShape shape;
  string name;
  DataType dtype;
  string shared_name;
  string container;
};

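// Reads the current value of a TF variable at conversion time and returns it
// as TRT weights. This is done by building a one-off FunctionDef containing
// either a VariableV2 node (is_resource == false) or a ReadVariableOp node
// fed by the resource handle in params->inputs (is_resource == true), running
// it synchronously on a cloned function library runtime, and copying the
// resulting device tensor to the host.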
template <typename T, bool is_resource>
Status ReadVariableHelper(const OpConverterParams* params,
                          const VarAttributes& attrs,
                          TRT_ShapedWeights* weights) {
  Tensor tensor(attrs.dtype, attrs.shape);
  auto ctx = params->converter->context();
  TRT_ENSURE(ctx != nullptr);
  auto tensor_flat = tensor.flat<T>();

  // Clone the function library runtime in order to get a mutable library
  // definition, to which we can add and then run a function containing the
  // variable op.
  auto lib = ctx->function_library();
  std::unique_ptr<FunctionLibraryDefinition> lib_def;
  std::unique_ptr<ProcessFunctionLibraryRuntime> lib_pflr;
  FunctionLibraryRuntime* lib_clone;  // Not owned.
  TF_RETURN_IF_ERROR(lib->Clone(&lib_def, &lib_pflr, &lib_clone));

  // Create function definition.
  FunctionDef fdef;
  std::vector<Tensor> args;
  string func_name = attrs.name + "/func";
  if (is_resource) {
    // Create input tensor with the resource handle.
    const auto& inputs = params->inputs;
    const TRT_TensorOrWeights& handle = inputs.at(0);
    args.emplace_back(handle.resource());

    fdef = FunctionDefHelper::Define(
        func_name,                                             // Name
        {"in: resource"},                                      // Args
        {absl::StrCat("out: ", DataTypeString(attrs.dtype))},  // Returns
        {},                                                    // Attr def
        // Nodes
        {{{attrs.name},
          "ReadVariableOp",
          {"in"},  // Name of the Placeholder or VarHandleOp
          {{"dtype", attrs.dtype}}},
         {{"out"}, "Identity", {attrs.name}, {{"T", attrs.dtype}}}});
  } else {
    fdef = FunctionDefHelper::Define(
        func_name,                                             // Name
        {},                                                    // Args
        {absl::StrCat("out: ", DataTypeString(attrs.dtype))},  // Returns
        {},                                                    // Attr def
        // Nodes
        {{{attrs.name},
          "VariableV2",
          {},
          {{"dtype", attrs.dtype},
           {"shape", attrs.shape_proto},
           {"container", attrs.container},
           {"shared_name", attrs.shared_name}}},
         {{"out"}, "Identity", {attrs.name}, {{"T", attrs.dtype}}}});
  }

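  // For orientation, a sketch of the function generated above in the resource
  // case (names taken from the code; not literal FunctionDef syntax):
  //
  //   <attrs.name>/func(in: resource) -> (out: <dtype>)
  //     <attrs.name> = ReadVariableOp(in)      # attr: dtype
  //     out          = Identity(<attrs.name>)  # attr: T = dtype
  //
  // FunctionDefHelper::Define binds the declared return value "out" to the
  // node named "out", which is why the trailing Identity node is needed.
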
  // Add function definition to the library.
  TF_RETURN_IF_ERROR(lib_def->AddFunctionDef(fdef));

  // Instantiate function.
  FunctionLibraryRuntime::Handle func_handle;
  FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.state_handle = "";
  inst_ops.target = ctx->device()->name();
  AttrValueMap attr_list;
  TF_RETURN_IF_ERROR(lib_clone->Instantiate(func_name, AttrSlice(&attr_list),
                                            inst_ops, &func_handle));

  FunctionLibraryRuntime::Options opts;
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();

  // Owns the function's output tensors; freed automatically on scope exit.
  auto outputs_wrapper = std::make_unique<std::vector<Tensor>>();
  std::vector<Tensor>* rets = outputs_wrapper.get();

  // Run the new function synchronously.
  TF_RETURN_IF_ERROR(lib_clone->RunSync(opts, func_handle, args, rets));

  TRT_ENSURE(ctx->op_device_context() != nullptr);
  TRT_ENSURE(ctx->op_device_context()->stream() != nullptr);

  // Copy the variable's value from device to host on the op's CUDA stream.
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  TRT_ENSURE(stream != nullptr);

  auto ret = cudaMemcpyAsync(tensor_flat.data(), rets->at(0).flat<T>().data(),
                             rets->at(0).NumElements() * sizeof(T),
                             cudaMemcpyDeviceToHost, *stream);
  if (ret != cudaSuccess) {
    return errors::Internal("Could not copy the variable ", attrs.name);
  }
  cudaStreamSynchronize(*stream);

  TF_RETURN_IF_ERROR(
      TfTensorToTrtWeights(tensor, params->weight_store, weights));

  return Status::OK();
}

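// Converts VariableV2 nodes to TRT constant weights by reading the variable's
// current value at conversion time via ReadVariableHelper. During validation,
// zero-filled weights with the variable's shape and dtype stand in for the
// real value.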
class ConvertVariableV2 : public OpConverterBase<ConvertVariableV2> {
 public:
  ConvertVariableV2(OpConverterParams* params)
      : OpConverterBase<ConvertVariableV2>(params) {}

  static constexpr std::array<InputArgSpec, 0> InputSpec() { return {}; }

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr const char* NodeDefDataTypeAttributeName() {
    return "dtype";
  }

  template <typename T>
  Status ValidateImpl() {
    const auto& node_def = params_->node_def;

    // Verify and consume node attributes.
    StatusOr<TensorShapeProto> shape_proto =
        GetAttrValue<TensorShapeProto>("shape");
    StatusOr<string> shared_name = GetAttrValue<string>("shared_name");
    StatusOr<string> container = GetAttrValue<string>("container");
    TRT_ENSURE_OK(shape_proto);
    TRT_ENSURE_OK(shared_name);
    TRT_ENSURE_OK(container);

    attrs_.shape_proto = *shape_proto;
    attrs_.shape = TensorShape(*shape_proto);
    attrs_.name = node_def.name();
    attrs_.shared_name = *shared_name;
    attrs_.container = *container;

    // The variable's value is not read during validation; a zero-filled
    // tensor of the right shape and dtype stands in for it.
    Tensor tensor(attrs_.dtype, attrs_.shape);
    auto tensor_flat = tensor.flat<T>();
    for (int64_t i = 0; i < tensor_flat.size(); i++) {
      tensor_flat(i) = T(0.0f);
    }

    TRT_ShapedWeights weights;
    TF_RETURN_IF_ERROR(
        TfTensorToTrtWeights(tensor, params_->weight_store, &weights));

    // Only push outputs during validation and when outputs are expected.
    if (params_->validation_only && params_->outputs != nullptr) {
      AddOutput(TRT_TensorOrWeights(weights));
    }
    return Status::OK();
  }

  Status Validate() {
    const auto& node_def = params_->node_def;
    StatusOr<DataType> dtype = GetAttrValue<DataType>("dtype");
    TRT_ENSURE_OK(dtype);
    attrs_.dtype = *dtype;

    switch (attrs_.dtype) {
      case DT_FLOAT:
        return ValidateImpl<float>();
      case DT_HALF:
        return ValidateImpl<Eigen::half>();
      default:
        // Note: this should have been caught by ValidateNodeDefDataType, but
        // the compiler expects that all paths be handled in switch.
        return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
                                     " is not supported for ", node_def.op(),
                                     ", at ", node_def.name());
    }
  }

  template <typename T>
  Status ConvertImpl() {
    TRT_ShapedWeights weights;
    TF_RETURN_IF_ERROR(ReadVariableHelper<T, false>(params_, attrs_, &weights));
    AddOutput(TRT_TensorOrWeights(weights));
    return Status::OK();
  }

  Status Convert() {
    const auto& node_def = params_->node_def;

    switch (attrs_.dtype) {
      case DT_FLOAT:
        return ConvertImpl<float>();
      case DT_HALF:
        return ConvertImpl<Eigen::half>();
      default:
        // Note: this should have been caught by ValidateNodeDefDataType, but
        // the compiler expects that all paths be handled in switch.
        return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
                                     " is not supported for ", node_def.op(),
                                     ", at ", node_def.name());
    }
  }

 private:
  VarAttributes attrs_{};
};
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertVariableV2>(),
                                  {"VariableV2"});

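// Converts ReadVariableOp nodes (resource variables) to TRT constant weights.
// The variable's resource handle arrives as the converter's single input; the
// output shape is taken from the node's "_shape" attribute. Only explicit
// batch mode is supported.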
class ConvertReadVariableOp : public OpConverterBase<ConvertReadVariableOp> {
 public:
  ConvertReadVariableOp(OpConverterParams* params)
      : OpConverterBase<ConvertReadVariableOp>(params) {}

  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return {InputArgSpec::Create("resource", TrtInputArg::kResource)};
  }

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr const char* NodeDefDataTypeAttributeName() {
    return "dtype";
  }

  template <typename T>
  Status ValidateImpl() {
    const auto& node_def = params_->node_def;

    // Verify and consume node attributes. The variable's shape is carried by
    // the node's "_shape" attribute.
    StatusOr<TensorShapeProto> shape_proto =
        GetAttrValue<TensorShapeProto>("_shape");
    TRT_ENSURE_OK(shape_proto);

    attrs_.shape_proto = *shape_proto;
    attrs_.shape = TensorShape(*shape_proto);
    attrs_.name = node_def.name();

    // The variable's value is not read during validation; a zero-filled
    // tensor of the right shape and dtype stands in for it.
    Tensor tensor(attrs_.dtype, attrs_.shape);
    auto tensor_flat = tensor.flat<T>();
    for (int64_t i = 0; i < tensor_flat.size(); i++) {
      tensor_flat(i) = T(0.0f);
    }

    TRT_ShapedWeights weights;
    TF_RETURN_IF_ERROR(
        TfTensorToTrtWeights(tensor, params_->weight_store, &weights));

    // Only push outputs during validation and when outputs are expected.
    if (params_->validation_only && params_->outputs != nullptr) {
      AddOutput(TRT_TensorOrWeights(weights));
    }
    return Status::OK();
  }

  Status Validate() {
    const auto& node_def = params_->node_def;
    if (params_->use_implicit_batch) {
      return errors::Unimplemented("Implicit batch mode not supported, at ",
                                   node_def.name());
    }

    StatusOr<DataType> dtype = GetAttrValue<DataType>("dtype");
    TRT_ENSURE_OK(dtype);
    attrs_.dtype = *dtype;

    switch (attrs_.dtype) {
      case DT_FLOAT:
        return ValidateImpl<float>();
      case DT_HALF:
        return ValidateImpl<Eigen::half>();
      default:
        // Note: this should have been caught by ValidateNodeDefDataType, but
        // the compiler expects that all paths be handled in switch.
        return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
                                     " is not supported for ", node_def.op(),
                                     ", at ", node_def.name());
    }
  }

  template <typename T>
  Status ConvertImpl() {
    TRT_ShapedWeights weights;
    TF_RETURN_IF_ERROR(ReadVariableHelper<T, true>(params_, attrs_, &weights));
    AddOutput(TRT_TensorOrWeights(weights));
    return Status::OK();
  }

  Status Convert() {
    const auto& node_def = params_->node_def;

    switch (attrs_.dtype) {
      case DT_FLOAT:
        return ConvertImpl<float>();
      case DT_HALF:
        return ConvertImpl<Eigen::half>();
      default:
        // Note: this should have been caught by ValidateNodeDefDataType, but
        // the compiler expects that all paths be handled in switch.
        return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
                                     " is not supported for ", node_def.op(),
                                     ", at ", node_def.name());
    }
  }

 private:
  VarAttributes attrs_{};
};
REGISTER_DEFAULT_TRT_OP_CONVERTER(
    MakeConverterFunction<ConvertReadVariableOp>(), {"ReadVariableOp"});

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT