• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include "absl/strings/str_cat.h"
17 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
18 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
21 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/stream_executor.h"
26 #include "third_party/tensorrt/NvInfer.h"
27 #include "third_party/tensorrt/NvInferRuntimeCommon.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 namespace convert {
32 
// Node attributes shared by the VariableV2 and ReadVariableOp converters.
// Populated during Validate() and consumed by ReadVariableHelper() during
// Convert().
struct VarAttributes {
  TensorShapeProto shape_proto;  // Shape as declared on the node.
  TensorShape shape;             // Same shape, converted to TensorShape.
  string name;                   // Node name; also used to name the helper op.
  DataType dtype;
  string shared_name;  // VariableV2 only: the "shared_name" attribute.
  string container;    // VariableV2 only: the "container" attribute.
};
41 
42 template <typename T, bool is_resource>
ReadVariableHelper(const OpConverterParams * params,const VarAttributes & attrs,TRT_ShapedWeights * weights)43 Status ReadVariableHelper(const OpConverterParams* params,
44                           const VarAttributes& attrs,
45                           TRT_ShapedWeights* weights) {
46   Tensor tensor(attrs.dtype, attrs.shape);
47   auto ctx = params->converter->context();
48   TRT_ENSURE(ctx != nullptr);
49   auto tensor_flat = tensor.flat<T>();
50 
51   // Clone function library runtime in order to get a mutable library
52   // definition to add and run a function with the variable operation.
53   auto lib = ctx->function_library();
54   std::unique_ptr<FunctionLibraryDefinition> lib_def;
55   std::unique_ptr<ProcessFunctionLibraryRuntime> lib_pflr;
56   FunctionLibraryRuntime* lib_clone;  // Not owned.
57   TF_RETURN_IF_ERROR(lib->Clone(&lib_def, &lib_pflr, &lib_clone));
58 
59   // Create function definition.
60   FunctionDef fdef;
61   std::vector<Tensor> args;
62   string func_name = attrs.name + "/func";
63   if (is_resource) {
64     // Create input tensor with the resource handle.
65     const auto& inputs = params->inputs;
66     const TRT_TensorOrWeights& handle = inputs.at(0);
67     args.emplace_back(handle.resource());
68 
69     fdef = FunctionDefHelper::Define(
70         func_name,                                             // Name
71         {"in: resource"},                                      // Args
72         {absl::StrCat("out: ", DataTypeString(attrs.dtype))},  // Returns
73         {},                                                    // Attr def
74         // Nodes
75         {{{attrs.name},
76           "ReadVariableOp",
77           {"in"},  // Name of the Placeholder or VarHandleOp
78           {{"dtype", attrs.dtype}}},
79          {{"out"}, "Identity", {attrs.name}, {{"T", attrs.dtype}}}});
80   } else {
81     fdef = FunctionDefHelper::Define(
82         func_name,                                             // Name
83         {},                                                    // Args
84         {absl::StrCat("out: ", DataTypeString(attrs.dtype))},  // Returns
85         {},                                                    // Attr def
86         // Nodes
87         {{{attrs.name},
88           "VariableV2",
89           {},
90           {{"dtype", attrs.dtype},
91            {"shape", attrs.shape_proto},
92            {"container", attrs.container},
93            {"shared_name", attrs.shared_name}}},
94          {{"out"}, "Identity", {attrs.name}, {{"T", attrs.dtype}}}});
95   }
96 
97   // Add function definition to the library.
98   TF_RETURN_IF_ERROR(lib_def->AddFunctionDef(fdef));
99 
100   // Instantiate function.
101   FunctionLibraryRuntime::Handle func_handle;
102   FunctionLibraryRuntime::InstantiateOptions inst_ops;
103   inst_ops.state_handle = "";
104   inst_ops.target = ctx->device()->name();
105   AttrValueMap attr_list;
106   TF_RETURN_IF_ERROR(lib_clone->Instantiate(func_name, AttrSlice(&attr_list),
107                                             inst_ops, &func_handle));
108 
109   FunctionLibraryRuntime::Options opts;
110   opts.rendezvous = ctx->rendezvous();
111   opts.cancellation_manager = ctx->cancellation_manager();
112   opts.runner = ctx->runner();
113 
114   std::vector<Tensor>* rets = new std::vector<Tensor>();
115   std::unique_ptr<std::vector<Tensor>> outputs_wrapper(rets);
116 
117   // Run the new function synchronously.
118   TF_RETURN_IF_ERROR(lib_clone->RunSync(opts, func_handle, args, rets));
119 
120   TRT_ENSURE(ctx->op_device_context() != nullptr);
121   TRT_ENSURE(ctx->op_device_context()->stream() != nullptr);
122 
123   // Copy tensor.
124   const cudaStream_t* stream = CHECK_NOTNULL(
125       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
126                                                 ->stream()
127                                                 ->implementation()
128                                                 ->GpuStreamMemberHack()));
129 
130   TRT_ENSURE(stream != nullptr);
131 
132   auto ret = cudaMemcpyAsync(tensor_flat.data(), rets->at(0).flat<T>().data(),
133                              rets->at(0).NumElements() * sizeof(T),
134                              cudaMemcpyDeviceToHost, *stream);
135   if (ret != 0) {
136     return errors::Internal("Could not copy the variable ", attrs.name);
137   }
138   cudaStreamSynchronize(*stream);
139 
140   TF_RETURN_IF_ERROR(
141       TfTensorToTrtWeights(tensor, params->weight_store, weights));
142 
143   return Status::OK();
144 }
145 
146 class ConvertVariableV2 : public OpConverterBase<ConvertVariableV2> {
147  public:
ConvertVariableV2(OpConverterParams * params)148   ConvertVariableV2(OpConverterParams* params)
149       : OpConverterBase<ConvertVariableV2>(params) {}
150 
InputSpec()151   static constexpr std::array<InputArgSpec, 0> InputSpec() { return {}; }
152 
AllowedDataTypes()153   static constexpr std::array<DataType, 2> AllowedDataTypes() {
154     return {DataType::DT_FLOAT, DataType::DT_HALF};
155   }
156 
NodeDefDataTypeAttributeName()157   static constexpr const char* NodeDefDataTypeAttributeName() {
158     return "dtype";
159   }
160 
161   template <typename T>
ValidateImpl()162   Status ValidateImpl() {
163     const auto& node_def = params_->node_def;
164 
165     // Verify and consume node attributes.
166     StatusOr<TensorShapeProto> shape_proto =
167         GetAttrValue<TensorShapeProto>("shape");
168     StatusOr<string> shared_name = GetAttrValue<string>("shared_name");
169     StatusOr<string> container = GetAttrValue<string>("container");
170     TRT_ENSURE_OK(shape_proto);
171     TRT_ENSURE_OK(shared_name);
172     TRT_ENSURE_OK(container);
173 
174     attrs_.shape_proto = *shape_proto;
175     attrs_.shape = TensorShape(*shape_proto);
176     attrs_.name = node_def.name();
177     attrs_.shared_name = *shared_name;
178     attrs_.container = *container;
179 
180     Tensor tensor(attrs_.dtype, attrs_.shape);
181     auto tensor_flat = tensor.flat<T>();
182     for (int64_t i = 0; i < tensor_flat.size(); i++) {
183       tensor_flat(i) = T(0.0f);
184     }
185 
186     TRT_ShapedWeights weights;
187     TF_RETURN_IF_ERROR(
188         TfTensorToTrtWeights(tensor, params_->weight_store, &weights));
189 
190     // Only push outputs during validation and when outputs are expected.
191     if (params_->validation_only && params_->outputs != nullptr) {
192       AddOutput(TRT_TensorOrWeights(weights));
193     }
194     return Status::OK();
195   }
196 
Validate()197   Status Validate() {
198     const auto& node_def = params_->node_def;
199     StatusOr<DataType> dtype = GetAttrValue<DataType>("dtype");
200     TRT_ENSURE_OK(dtype);
201     attrs_.dtype = *dtype;
202 
203     switch (attrs_.dtype) {
204       case DT_FLOAT:
205         return ValidateImpl<float>();
206       case DT_HALF:
207         return ValidateImpl<Eigen::half>();
208       default:
209         // Note: this should have been caught by ValidateNodeDefDataType, but
210         // the compiler expects that all paths be handled in switch.
211         return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
212                                      " is not supported for ", node_def.op(),
213                                      ", at ", node_def.name());
214     }
215   }
216 
217   template <typename T>
ConvertImpl()218   Status ConvertImpl() {
219     TRT_ShapedWeights weights;
220     TF_RETURN_IF_ERROR(ReadVariableHelper<T, false>(params_, attrs_, &weights));
221     AddOutput(TRT_TensorOrWeights(weights));
222     return Status::OK();
223   }
224 
Convert()225   Status Convert() {
226     const auto& node_def = params_->node_def;
227 
228     switch (attrs_.dtype) {
229       case DT_FLOAT:
230         return ConvertImpl<float>();
231       case DT_HALF:
232         return ConvertImpl<Eigen::half>();
233       default:
234         // Note: this should have been caught by ValidateNodeDefDataType, but
235         // the compiler expects that all paths be handled in switch.
236         return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
237                                      " is not supported for ", node_def.op(),
238                                      ", at ", node_def.name());
239     }
240   }
241 
242  private:
243   VarAttributes attrs_{};
244 };
245 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertVariableV2>(),
246                                   {"VariableV2"});
247 
248 class ConvertReadVariableOp : public OpConverterBase<ConvertReadVariableOp> {
249  public:
ConvertReadVariableOp(OpConverterParams * params)250   ConvertReadVariableOp(OpConverterParams* params)
251       : OpConverterBase<ConvertReadVariableOp>(params) {}
252 
InputSpec()253   static constexpr std::array<InputArgSpec, 1> InputSpec() {
254     return {InputArgSpec::Create("resource", TrtInputArg::kResource)};
255   }
256 
AllowedDataTypes()257   static constexpr std::array<DataType, 2> AllowedDataTypes() {
258     return {DataType::DT_FLOAT, DataType::DT_HALF};
259   }
260 
NodeDefDataTypeAttributeName()261   static constexpr const char* NodeDefDataTypeAttributeName() {
262     return "dtype";
263   }
264 
265   template <typename T>
ValidateImpl()266   Status ValidateImpl() {
267     const auto& node_def = params_->node_def;
268 
269     // Verify and consume node attributes.
270     StatusOr<TensorShapeProto> shape_proto =
271         GetAttrValue<TensorShapeProto>("_shape");
272     TRT_ENSURE_OK(shape_proto);
273 
274     attrs_.shape_proto = *shape_proto;
275     attrs_.shape = TensorShape(*shape_proto);
276     attrs_.name = node_def.name();
277 
278     Tensor tensor(attrs_.dtype, attrs_.shape);
279     auto tensor_flat = tensor.flat<T>();
280     for (int64_t i = 0; i < tensor_flat.size(); i++) {
281       tensor_flat(i) = T(0.0f);
282     }
283 
284     TRT_ShapedWeights weights;
285     TF_RETURN_IF_ERROR(
286         TfTensorToTrtWeights(tensor, params_->weight_store, &weights));
287 
288     // Only push outputs during validation and when outputs are expected.
289     if (params_->validation_only && params_->outputs != nullptr) {
290       AddOutput(TRT_TensorOrWeights(weights));
291     }
292     return Status::OK();
293   }
294 
Validate()295   Status Validate() {
296     const auto& node_def = params_->node_def;
297     if (params_->use_implicit_batch) {
298       return errors::Unimplemented("Implicit batch mode not supported, at ",
299                                    node_def.name());
300     }
301 
302     StatusOr<DataType> dtype = GetAttrValue<DataType>("dtype");
303     TRT_ENSURE_OK(dtype);
304     attrs_.dtype = *dtype;
305 
306     switch (attrs_.dtype) {
307       case DT_FLOAT:
308         return ValidateImpl<float>();
309       case DT_HALF:
310         return ValidateImpl<Eigen::half>();
311       default:
312         // Note: this should have been caught by ValidateNodeDefDataType, but
313         // the compiler expects that all paths be handled in switch.
314         return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
315                                      " is not supported for ", node_def.op(),
316                                      ", at ", node_def.name());
317     }
318   }
319 
320   template <typename T>
ConvertImpl()321   Status ConvertImpl() {
322     TRT_ShapedWeights weights;
323     TF_RETURN_IF_ERROR(ReadVariableHelper<T, true>(params_, attrs_, &weights));
324     AddOutput(TRT_TensorOrWeights(weights));
325     return Status::OK();
326   }
327 
Convert()328   Status Convert() {
329     const auto& node_def = params_->node_def;
330 
331     switch (attrs_.dtype) {
332       case DT_FLOAT:
333         return ConvertImpl<float>();
334       case DT_HALF:
335         return ConvertImpl<Eigen::half>();
336       default:
337         // Note: this should have been caught by ValidateNodeDefDataType, but
338         // the compiler expects that all paths be handled in switch.
339         return errors::Unimplemented("Data type ", DataTypeString(attrs_.dtype),
340                                      " is not supported for ", node_def.op(),
341                                      ", at ", node_def.name());
342     }
343   }
344 
345  private:
346   VarAttributes attrs_{};
347 };
348 REGISTER_DEFAULT_TRT_OP_CONVERTER(
349     MakeConverterFunction<ConvertReadVariableOp>(), {"ReadVariableOp"});
350 
351 }  // namespace convert
352 }  // namespace tensorrt
353 }  // namespace tensorflow
354 
355 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
356