• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/match.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
31 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
32 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
33 #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
34 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/tensor.pb.h"        // NOLINT
37 #include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/graph/graph.h"
41 #include "tensorflow/core/graph/graph_constructor.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/strings/numbers.h"
45 #include "tensorflow/core/lib/strings/str_util.h"
46 #include "tensorflow/core/lib/strings/strcat.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/protobuf.h"
49 #include "tensorflow/core/platform/tensor_coding.h"
50 #include "tensorflow/core/platform/types.h"
51 
52 #if GOOGLE_CUDA
53 #if GOOGLE_TENSORRT
54 #include "tensorrt/include/NvInfer.h"
55 
// Check if the types are equal. Both operands are converted to int first so
// that the failure log message would work. static_cast<int>(...) wraps the
// whole macro argument, so a compound expression passed as `val1`/`val2` is
// converted as a unit instead of only its first token.
#define TFTRT_CHECK_EQ_TYPE(val1, val2) \
  CHECK_EQ(static_cast<int>(val1), static_cast<int>(val2))
59 
// Returns an Internal error from the enclosing function. The message names
// the enclosing function (via __FUNCTION__) and `node`.
#define TFTRT_INTERNAL_ERROR_AT_NODE(node)                             \
  do {                                                                 \
    return errors::Internal("TFTRT::", __FUNCTION__,                   \
                            " failed to add TRT layer, at: ", (node)); \
  } while (0)

// Returns an Internal error from the enclosing function when `status`
// compares equal to false. `status` is parenthesized so that an expression
// such as `a && b` is evaluated as a whole before the comparison.
#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
  do {                                            \
    if ((status) == false) {                      \
      TFTRT_INTERNAL_ERROR_AT_NODE(node);         \
    }                                             \
  } while (0)

// Returns an Internal error from the enclosing function when `ptr` is null.
// `ptr` is parenthesized so that a conditional or other compound expression
// is evaluated as a whole before the comparison.
#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
  do {                                           \
    if ((ptr) == nullptr) {                      \
      TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
    }                                            \
  } while (0)
79 
80 namespace tensorflow {
81 namespace tensorrt {
// TODO(aaroey): put these constants into some class.
// Name prefixes tested by IsEngineInput() / IsEngineOutput() below.
const char* const kInputPHName{"TensorRTInputPH_"};
const char* const kOutputPHName{"TensorRTOutputPH_"};
85 
IsEngineInput(absl::string_view name)86 bool IsEngineInput(absl::string_view name) {
87   return absl::StartsWith(name, kInputPHName);
88 }
IsEngineOutput(absl::string_view name)89 bool IsEngineOutput(absl::string_view name) {
90   return absl::StartsWith(name, kOutputPHName);
91 }
92 
93 namespace convert {
94 using absl::StrAppend;
95 using absl::StrCat;
96 
ConvertDType(DataType tf_dtype,nvinfer1::DataType * trt_dtype)97 inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
98   switch (tf_dtype) {
99     case DataType::DT_FLOAT:
100       *trt_dtype = nvinfer1::DataType::kFLOAT;
101       break;
102     // TODO(aaroey): this should be DT_QINT8 which is not a well supported type.
103     case DataType::DT_INT8:
104       *trt_dtype = nvinfer1::DataType::kINT8;
105       break;
106     case DataType::DT_HALF:
107       *trt_dtype = nvinfer1::DataType::kHALF;
108       break;
109     case DataType::DT_INT32:
110       *trt_dtype = nvinfer1::DataType::kINT32;
111       break;
112     default:
113       return errors::InvalidArgument("Unsupported data type ",
114                                      DataTypeString(tf_dtype));
115   }
116   return Status::OK();
117 }
118 
119 class TFAttrs {
120  public:
TFAttrs(const NodeDef & tf_node)121   explicit TFAttrs(const NodeDef& tf_node) {
122     for (const auto& attr : tf_node.attr()) {
123       attrs_.insert({attr.first, &attr.second});
124     }
125   }
126 
count(const string & key) const127   bool count(const string& key) const { return attrs_.count(key); }
128 
at(const string & key) const129   AttrValue const* at(const string& key) const {
130     if (!attrs_.count(key)) {
131       LOG(FATAL) << "Attribute not found: " << key;
132     }
133     return attrs_.at(key);
134   }
135 
136   template <typename T>
137   T get(const string& key) const;
138 
139   template <typename T>
get(const string & key,const T & default_value) const140   T get(const string& key, const T& default_value) const {
141     return attrs_.count(key) ? this->get<T>(key) : default_value;
142   }
143 
GetAllAttrKeys() const144   std::vector<string> GetAllAttrKeys() const {
145     std::vector<string> attr_list;
146     for (const auto& attr_item : attrs_) {
147       attr_list.emplace_back(attr_item.first);
148     }
149     return attr_list;
150   }
151 
152  private:
153   typedef std::map<string, AttrValue const*> AttrMap;
154   AttrMap attrs_;
155 };
156 
157 template <>
get(const string & key) const158 string TFAttrs::get<string>(const string& key) const {
159   return this->at(key)->s();
160 }
161 
162 template <>
get(const string & key) const163 std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
164   auto attr = this->at(key)->list().i();
165   return std::vector<int64>(attr.begin(), attr.end());
166 }
167 
168 template <>
get(const string & key) const169 std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
170   auto attr = this->at(key)->list().f();
171   return std::vector<float>(attr.begin(), attr.end());
172 }
173 
174 template <>
get(const string & key) const175 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
176   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
177   TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
178   return trt_dtype;
179 }
180 
181 template <>
get(const string & key) const182 DataType TFAttrs::get<DataType>(const string& key) const {
183   return this->at(key)->type();
184 }
185 
186 template <>
get(const string & key) const187 float TFAttrs::get<float>(const string& key) const {
188   return this->at(key)->f();
189 }
190 
191 template <>
get(const string & key) const192 bool TFAttrs::get<bool>(const string& key) const {
193   return this->at(key)->b();
194 }
195 
196 template <>
get(const string & key) const197 int64 TFAttrs::get<int64>(const string& key) const {
198   return this->at(key)->i();
199 }
200 
201 template <typename TensorShapeType>
TensorShapeToTrtDims(const TensorShapeType & shape,bool ignore_first_dim)202 inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
203                                            bool ignore_first_dim) {
204   nvinfer1::Dims trt_dims;
205   const int offset = (ignore_first_dim ? 1 : 0);
206   for (int i = offset; i < shape.dims(); i++) {
207     trt_dims.d[i - offset] = shape.dim_size(i);
208   }
209   trt_dims.nbDims = shape.dims() - offset;
210   return trt_dims;
211 }
212 
TensorShapeArrayToTrtDims(const std::vector<int> & shape,nvinfer1::Dims * out,bool ignore_first_dim=false)213 Status TensorShapeArrayToTrtDims(const std::vector<int>& shape,
214                                  nvinfer1::Dims* out,
215                                  bool ignore_first_dim = false) {
216   PartialTensorShape tensor_shape;
217   TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
218   *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim);
219   return Status::OK();
220 }
221 
222 // TODO(laigd): use this utility function in more places.
RemoveBatchDimension(nvinfer1::Dims * dims)223 Status RemoveBatchDimension(nvinfer1::Dims* dims) {
224   if (dims->nbDims < 2) {
225     return errors::InvalidArgument(
226         "Dropping batch dimension requires dims with rank>=2.");
227   }
228   std::copy(dims->d + 1, dims->d + dims->nbDims, dims->d);
229   dims->nbDims--;
230   return Status::OK();
231 }
232 
GetOutputProperties(const grappler::GraphProperties & graph_properties,const Node * node,const int out_port,PartialTensorShape * shape,DataType * dtype)233 void GetOutputProperties(const grappler::GraphProperties& graph_properties,
234                          const Node* node, const int out_port,
235                          PartialTensorShape* shape, DataType* dtype) {
236   if (graph_properties.HasOutputProperties(node->name())) {
237     auto output_params = graph_properties.GetOutputProperties(node->name());
238     auto out_shape = output_params.at(out_port);
239     *dtype = out_shape.dtype();
240     *shape = out_shape.shape();
241   } else {
242     LOG(INFO) << "Unknown output shape" << node->name();
243     *dtype = node->output_type(out_port);
244   }
245 }
246 
GetInputProperties(const grappler::GraphProperties & graph_properties,const Node * node,const int in_port,PartialTensorShape * shape,DataType * dtype)247 void GetInputProperties(const grappler::GraphProperties& graph_properties,
248                         const Node* node, const int in_port,
249                         PartialTensorShape* shape, DataType* dtype) {
250   if (graph_properties.HasInputProperties(node->name())) {
251     auto input_params = graph_properties.GetInputProperties(node->name());
252     auto in_shape = input_params.at(in_port);
253     *dtype = in_shape.dtype();
254     *shape = in_shape.shape();
255   } else {
256     *dtype = node->input_type(in_port);
257   }
258 }
259 
ValidateTensorProperties(const string & producer_node_type,const DataType dtype,const PartialTensorShape & shape,bool validation_only,nvinfer1::DataType * trt_dtype,nvinfer1::Dims * trt_dims,int * batch_size)260 Status ValidateTensorProperties(const string& producer_node_type,
261                                 const DataType dtype,
262                                 const PartialTensorShape& shape,
263                                 bool validation_only,
264                                 nvinfer1::DataType* trt_dtype,
265                                 nvinfer1::Dims* trt_dims, int* batch_size) {
266   // Convert data type.
267   TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
268 
269   // Convert shape.
270   if (shape.dims() < 0) {
271     return errors::InvalidArgument("Input tensor rank is unknown.");
272   }
273   if (shape.dims() > nvinfer1::Dims::MAX_DIMS + 1) {  // +1 for batch dim
274     return errors::OutOfRange("Input tensor rank is greater than ",
275                               nvinfer1::Dims::MAX_DIMS + 1);
276   }
277   if (producer_node_type != "Const" && shape.dims() < 2) {
278     return errors::InvalidArgument(
279         "Input tensor with rank<2 is not supported since the first dimension "
280         "is treated as batch dimension by TRT");
281   }
282   *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true);
283   *batch_size = shape.dim_size(0);
284 
285   // Don't convert empty tensors (dim value of 0).
286   for (int d = 1; d < shape.dims(); ++d) {
287     if (shape.dim_size(d) == 0) {
288       return errors::Unimplemented(
289           "Input tensor with shape ", shape.DebugString(),
290           " is an empty tensor, which is not supported by TRT");
291     }
292   }
293 
294   if (validation_only) return Status::OK();
295   // Following are validations at runtime.
296 
297   for (int d = 1; d < shape.dims(); ++d) {
298     if (shape.dim_size(d) < 0) {
299       return errors::InvalidArgument(
300           "Input tensor with shape ", shape.DebugString(),
301           " has an unknown non-batch dimension at dim ", d);
302     }
303   }
304   return Status::OK();
305 }
306 
DebugString(const nvinfer1::DimensionType type)307 string DebugString(const nvinfer1::DimensionType type) {
308   switch (type) {
309     case nvinfer1::DimensionType::kSPATIAL:
310       return "kSPATIAL";
311     case nvinfer1::DimensionType::kCHANNEL:
312       return "kCHANNEL";
313     case nvinfer1::DimensionType::kINDEX:
314       return "kINDEX";
315     case nvinfer1::DimensionType::kSEQUENCE:
316       return "kSEQUENCE";
317     default:
318       return StrCat(static_cast<int>(type), "=unknown");
319   }
320 }
321 
DebugString(const nvinfer1::DataType trt_dtype)322 string DebugString(const nvinfer1::DataType trt_dtype) {
323   switch (trt_dtype) {
324     case nvinfer1::DataType::kFLOAT:
325       return "kFLOAT";
326     case nvinfer1::DataType::kHALF:
327       return "kHALF";
328     case nvinfer1::DataType::kINT8:
329       return "kINT8";
330     case nvinfer1::DataType::kINT32:
331       return "kINT32";
332     default:
333       return "Invalid TRT data type";
334   }
335 }
336 
DebugString(const nvinfer1::Dims & dims)337 string DebugString(const nvinfer1::Dims& dims) {
338   string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
339   for (int i = 0; i < dims.nbDims; ++i) {
340     StrAppend(&out, dims.d[i], "[", DebugString(dims.type[i]), "],");
341   }
342   StrAppend(&out, ")");
343   return out;
344 }
345 
DebugString(const nvinfer1::Permutation & permutation,int len)346 string DebugString(const nvinfer1::Permutation& permutation, int len) {
347   string out = "nvinfer1::Permutation(";
348   for (int i = 0; i < len; ++i) {
349     StrAppend(&out, permutation.order[i], ",");
350   }
351   StrAppend(&out, ")");
352   return out;
353 }
354 
DebugString(const nvinfer1::ITensor & tensor)355 string DebugString(const nvinfer1::ITensor& tensor) {
356   return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
357                 ", name=", tensor.getName(),
358                 ", dtype=", DebugString(tensor.getType()),
359                 ", dims=", DebugString(tensor.getDimensions()), ")");
360 }
361 
GetTrtBroadcastShape(const TRT_TensorOrWeights & operand_l,const TRT_TensorOrWeights & operand_r,nvinfer1::Dims * operand_l_new_dims,nvinfer1::Dims * operand_r_new_dims) const362 Status Converter::GetTrtBroadcastShape(
363     const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
364     nvinfer1::Dims* operand_l_new_dims,
365     nvinfer1::Dims* operand_r_new_dims) const {
366   // ***************************************************************************
367   // TensorRT Elementwise op supports broadcast but requires both tensor to be
368   // of Identical rank
369   //
370   // We consider case of:
371   //   1. operand_l to be a Tensor & operand_r to be a Const;
372   //   2. operand_l to be a Tensor & operand_r to be a Tensor;
373   // note: const op const (constant folding) should fallback to TensorFlow
374   //
375   // broadcast scheme:
376   //       T:  1 3 5    (tensor would not have batch dimension)
377   //       W:  1 1 3 1  (weight would have all explicit dimensions)
378   // i. fill in explicit dimensions
379   //    -> T: -1 1 3 5  (we put a -1 for batch dimension)
380   //    -> W:  1 1 3 1
381   // ii. compare broadcast feasibility
382   //
383   // We cannot support the following since TensorRT does not allow manipulation
384   // on batch dimension, we cannot generate output with proper shape
385   //    T: 3 5 1
386   //    W: 1 1 1  1 3 5 1
387   // -> T: 1 1 1 -1 3 5 1
388   // -> W: 1 1 1  1 3 5 1
389   // ***************************************************************************
390   if (!operand_l.is_tensor() && !operand_r.is_tensor()) {
391     return errors::InvalidArgument(
392         "Broadcasting requires at least one of the operands be tensors");
393   }
394 
395   const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
396   auto compute_output_dims = [](const TRT_TensorOrWeights& input,
397                                 int broadcast_num_dims, int* output_dims_array,
398                                 nvinfer1::Dims* output_dims) {
399     const nvinfer1::Dims input_dims = input.GetTrtDims();
400     std::fill(output_dims_array, output_dims_array + max_nb_dims, 1);
401     std::copy(input_dims.d, input_dims.d + input_dims.nbDims,
402               output_dims_array + broadcast_num_dims - input_dims.nbDims);
403     if (input.is_tensor()) {
404       const int true_input_dims = input_dims.nbDims + 1;
405       if (true_input_dims < broadcast_num_dims) {
406         return errors::InvalidArgument(
407             "Broadcasting beyond batch dimension is not supported ",
408             "(tensor #dims ", true_input_dims, " vs broadcast #dims ",
409             broadcast_num_dims, ")");
410       }
411       // Set the batch dimension to -1, since batch size is not supposed to
412       // be broadcasted.
413       output_dims_array[0] = -1;
414     }
415     // Copy to output dimensions (stripping the batch dimension).
416     output_dims->nbDims = broadcast_num_dims - 1;
417     std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims,
418               output_dims->d);
419     return Status::OK();
420   };
421 
422   // Compute the output dimensions.
423   const int broadcast_num_dims =
424       std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0),
425                operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0));
426   int output_l[max_nb_dims], output_r[max_nb_dims];
427   TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims,
428                                          output_l, operand_l_new_dims));
429   TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims,
430                                          output_r, operand_r_new_dims));
431 
432   // Compare broadcast feasibility
433   for (int i = 0; i < broadcast_num_dims; ++i) {
434     if ((output_l[i] != output_r[i]) && (output_l[i] != 1) &&
435         (output_r[i] != 1)) {
436       return errors::InvalidArgument(
437           "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ",
438           DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0],
439           ", ", DebugString(*operand_r_new_dims), ")");
440     }
441   }
442   return Status::OK();
443 }
444 
CreateConstantLayer(const TRT_ShapedWeights & weights,const nvinfer1::Dims & dims)445 nvinfer1::ITensor* Converter::CreateConstantLayer(
446     const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) {
447   nvinfer1::Weights trt_weights = weights.GetTrtWeights();
448   nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
449   if (!layer) return nullptr;
450   const nvinfer1::DataType trt_dtype = trt_weights.type;
451   nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
452 #if !IS_TRT_VERSION_GE(5, 1, 3)
453   // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set
454   // the data type below, it will always be kFLOAT regardless what the data type
455   // of the weights is. Once NVIDIA fixes this bug, we should remove the data
456   // type setting logic below and test should still pass.
457   trt_tensor->setType(trt_dtype);
458 #endif
459   return trt_tensor;
460 }
461 
CreateBroadcastableScalarConstant(OpConverterParams * params,float value,const nvinfer1::Dims & dims,const nvinfer1::ITensor ** tensor,const char * dtype_attr_name="T")462 Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
463                                          const nvinfer1::Dims& dims,
464                                          const nvinfer1::ITensor** tensor,
465                                          const char* dtype_attr_name = "T") {
466   TFAttrs attrs(params->node_def);
467   DataType dtype;
468   if (attrs.count(dtype_attr_name)) {
469     dtype = attrs.get<DataType>(dtype_attr_name);
470   } else {
471     dtype = DT_FLOAT;  // Default to FP32.
472   }
473 
474   // In order to be broadcastable, the number of dims has to match.
475   nvinfer1::Dims broadcastable_dims(dims);
476   for (int i = 0; i < broadcastable_dims.nbDims; i++) {
477     broadcastable_dims.d[i] = 1;
478   }
479   TRT_ShapedWeights weights =
480       params->weight_store->GetTempWeights(dtype, broadcastable_dims);
481   void* raw_ptr = const_cast<void*>(weights.GetValues());
482   switch (dtype) {
483     case DataType::DT_FLOAT:
484       static_cast<float*>(raw_ptr)[0] = value;
485       break;
486     case DataType::DT_HALF:
487       static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
488       break;
489     default:
490       return errors::InvalidArgument("Unsupported data type ",
491                                      DataTypeString(dtype));
492   }
493   *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
494   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
495   params->converter->ProvideQuantizationRange(
496       const_cast<nvinfer1::ITensor*>(*tensor), value, value);
497   return Status::OK();
498 }
499 
500 // Convert an axis from TF format to TRT format while validating. TF format
501 // includes the batch dimension, while TRT does not. TF can also use negative
502 // indices.
503 // TODO(tmorris): Use this method in more ops.
ConvertAxis(int tf_axis,int trt_nb_dims,absl::string_view node_name,int * trt_axis)504 Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
505                    int* trt_axis) {
506   const int tf_nb_dims = trt_nb_dims + 1;
507   // Check bounds.
508   if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
509     return errors::InvalidArgument(
510         "Axis value of ", tf_axis, " is out of bounds, must be in range [",
511         -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name);
512   }
513   // Make negative axis positive.
514   if (tf_axis < 0) tf_axis += tf_nb_dims;
515   // Don't allow axis to be the batch dimension.
516   if (tf_axis == 0) {
517     return errors::Unimplemented(
518         "TensorRT does not allow manipulation of the batch dimension, at ",
519         node_name);
520   }
521   // Remove batch dimension.
522   *trt_axis = tf_axis - 1;
523   return Status::OK();
524 }
525 
DimsEqual(const nvinfer1::Dims & dim_l,const nvinfer1::Dims & dim_r)526 inline bool DimsEqual(const nvinfer1::Dims& dim_l,
527                       const nvinfer1::Dims& dim_r) {
528   if (dim_l.nbDims != dim_r.nbDims) {
529     return false;
530   }
531   for (int i = 0; i < dim_l.nbDims; i++) {
532     if (dim_l.d[i] != dim_r.d[i]) {
533       return false;
534     }
535   }
536   return true;
537 }
538 
// Returns true iff every inner vector of `inputs` has the same length.
// An empty `inputs` is considered to satisfy this.
bool AllLengthsEqual(const std::vector<std::vector<int>>& inputs) {
  if (inputs.empty()) return true;
  // Keep the length in the container's own unsigned size type so that it is
  // neither narrowed to int nor compared across signedness.
  const auto expected_length = inputs.front().size();
  return std::all_of(inputs.begin() + 1, inputs.end(),
                     [expected_length](const std::vector<int>& input) {
                       return input.size() == expected_length;
                     });
}
547 
GetTrtDimsForTensor(const Tensor & tensor)548 inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) {
549   nvinfer1::Dims dims;
550   dims.nbDims = tensor.dims();
551   for (int i = 0; i < dims.nbDims; i++) {
552     dims.d[i] = tensor.dim_size(i);
553   }
554   return dims;
555 }
556 
HasStaticShape(const nvinfer1::Dims & dims)557 inline bool HasStaticShape(const nvinfer1::Dims& dims) {
558   if (dims.nbDims < 0) return false;
559   for (int d = 0; d < dims.nbDims; ++d) {
560     if (dims.d[d] < 0) return false;
561   }
562   return true;
563 }
564 
565 // Returns total number of elements in dims. Returning 0 means either some dim
566 // is 0 or the number of dims is 0.
567 // Note that for TF scalar constant, we always convert to dims [1].
TrtDimsNumElements(const nvinfer1::Dims & dims)568 int64_t TrtDimsNumElements(const nvinfer1::Dims& dims) {
569   if (dims.nbDims == 0) return 0;
570   int64_t count = 1;
571   for (int d = 0; d < dims.nbDims; ++d) {
572     count *= dims.d[d];
573   }
574   return count;
575 }
576 
CreateSamePadding(const nvinfer1::DimsHW & stride,const nvinfer1::DimsHW & kernel,const std::vector<int64_t> & input_dims)577 static std::vector<std::pair<int, int>> CreateSamePadding(
578     const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
579     const std::vector<int64_t>& input_dims) {
580   std::vector<std::pair<int, int>> padding(input_dims.size());
581   CHECK_EQ(stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?
582 
583   for (size_t i = 0; i < input_dims.size(); ++i) {
584     // Formula to calculate the padding
585     int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
586             input_dims[i];
587     p = (p > 0) ? p : 0;
588 
589     // Right precedence padding, like in TensorFlow
590     int left = p / 2;
591     int right = p - left;
592 
593     VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
594             << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
595             << "kernel: " << kernel.d[i];
596     padding[i] = {left, right};
597   }
598   return padding;
599 }
600 
// Returns the longest prefix shared by both op names that ends right after a
// '/' separator, or an empty string when no such prefix exists.
string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
  const size_t limit = std::min(op_name_a.size(), op_name_b.size());
  size_t scope_end = 0;
  size_t pos = 0;
  // Walk the common prefix, remembering the position after the last '/'.
  while (pos < limit && op_name_a[pos] == op_name_b[pos]) {
    if (op_name_a[pos] == '/') scope_end = pos + 1;
    ++pos;
  }
  return op_name_a.substr(0, scope_end);
}
610 
TRT_ShapedWeights(DataType type)611 TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) {
612   shape_.nbDims = 0;
613 }
614 
TRT_ShapedWeights(DataType type,nvinfer1::Dims dims,Tensor tensor)615 TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims,
616                                      Tensor tensor)
617     : shape_(dims), type_(type), tensor_(tensor) {}
618 
TRT_ShapedWeights(const TRT_ShapedWeights & rhs)619 TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
620     : shape_(rhs.shape_), type_(rhs.type_), tensor_(rhs.tensor_) {}
621 
count() const622 int64_t TRT_ShapedWeights::count() const { return TrtDimsNumElements(shape_); }
623 
GetTrtWeights() const624 nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
625   nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
626   TF_CHECK_OK(ConvertDType(type_, &trt_type));
627   return nvinfer1::Weights{trt_type, GetValues(), count()};
628 }
629 
size_bytes() const630 size_t TRT_ShapedWeights::size_bytes() const {
631   return this->count() * DataTypeSize(this->type_);
632 }
633 
DebugString() const634 string TRT_ShapedWeights::DebugString() const {
635   return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
636                 ", type=", DataTypeString(type_),
637                 ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
638 }
639 
640 // A fake ITensor implementation used to check whether the TF-TRT converter can
641 // handle specific node. We only need shape and type information, and the
642 // converter won't (and shouldn't) use this to build the TRT network.
643 class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
644  public:
SimpleITensor(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims)645   SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
646       : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}
647 
setName(const char * name)648   void setName(const char* name) override {}
649 
getName() const650   const char* getName() const override { return ""; }
651 
setDimensions(nvinfer1::Dims dimensions)652   void setDimensions(nvinfer1::Dims dimensions) override {
653     trt_dims_ = dimensions;
654   }
655 
getDimensions() const656   nvinfer1::Dims getDimensions() const override { return trt_dims_; }
657 
setType(nvinfer1::DataType trt_dtype)658   void setType(nvinfer1::DataType trt_dtype) override {
659     trt_dtype_ = trt_dtype;
660   }
661 
getType() const662   nvinfer1::DataType getType() const override { return trt_dtype_; }
663 
isNetworkInput() const664   bool isNetworkInput() const override { return false; }
665 
isNetworkOutput() const666   bool isNetworkOutput() const override { return false; }
667 
setBroadcastAcrossBatch(bool broadcastAcrossBatch)668   void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}
669 
getBroadcastAcrossBatch() const670   bool getBroadcastAcrossBatch() const override { return false; }
671 
getLocation() const672   nvinfer1::TensorLocation getLocation() const override {
673     // This is arbitrary, since we don't use it.
674     return nvinfer1::TensorLocation::kDEVICE;
675   }
676 
setLocation(nvinfer1::TensorLocation location)677   void setLocation(nvinfer1::TensorLocation location) override {}
678 
679 #if IS_TRT_VERSION_GE(5, 0, 0)
setDynamicRange(float min,float max)680   bool setDynamicRange(float min, float max) override { return true; }
681 
getDynamicRange() const682   float getDynamicRange() const override { return 0; }
683 #endif
684 
685 #if IS_TRT_VERSION_GE(5, 1, 0)
dynamicRangeIsSet() const686   bool dynamicRangeIsSet() const override { return true; }
687 
resetDynamicRange()688   void resetDynamicRange() override {}
689 
getDynamicRangeMin() const690   float getDynamicRangeMin() const override { return 0.f; }
691 
getDynamicRangeMax() const692   float getDynamicRangeMax() const override { return 0.f; }
693 #endif
694 
695  private:
696   nvinfer1::DataType trt_dtype_;
697   nvinfer1::Dims trt_dims_;
698 };
699 
TRT_TensorOrWeights(nvinfer1::ITensor * tensor,int batch_size)700 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor,
701                                          int batch_size)
702     : tensor_(tensor),
703       batch_size_(batch_size),
704       initialized_(true),
705       is_tensor_(true) {}
706 
TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,const nvinfer1::Dims & trt_dims,int batch_size)707 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
708                                          const nvinfer1::Dims& trt_dims,
709                                          int batch_size)
710     : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)),
711       batch_size_(batch_size),
712       initialized_(true),
713       is_tensor_(true) {}
714 
TRT_TensorOrWeights(const TRT_ShapedWeights & weights)715 TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
716     : weights_(weights), initialized_(true), is_tensor_(false) {}
717 
TRT_TensorOrWeights(const TRT_TensorOrWeights & rhs)718 TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
719     : tensor_(rhs.tensor_),
720       simple_itensor_(rhs.simple_itensor_),
721       batch_size_(rhs.batch_size_),
722       weights_(rhs.weights_),
723       initialized_(rhs.initialized_),
724       is_tensor_(rhs.is_tensor_) {}
725 
operator =(const TRT_TensorOrWeights & rhs)726 void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) {
727   tensor_ = rhs.tensor_;
728   simple_itensor_ = rhs.simple_itensor_;
729   batch_size_ = rhs.batch_size_;
730   weights_ = rhs.weights_;
731   initialized_ = rhs.initialized_;
732   is_tensor_ = rhs.is_tensor_;
733 }
734 
tensor()735 nvinfer1::ITensor* TRT_TensorOrWeights::tensor() {
736   CHECK(is_tensor());
737   return tensor_ == nullptr ? simple_itensor_.get() : tensor_;
738 }
739 
tensor() const740 const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const {
741   CHECK(is_tensor());
742   return tensor_ == nullptr ? simple_itensor_.get() : tensor_;
743 }
744 
GetTrtDims() const745 nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
746   if (is_tensor()) {
747     return tensor()->getDimensions();
748   } else {
749     return weights().shape_;
750   }
751 }
752 
DebugString() const753 string TRT_TensorOrWeights::DebugString() const {
754   string output = "TRT_TensorOrWeights(type=";
755   if (is_tensor()) {
756     StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
757               ", batch_size=", batch_size_);
758   } else {
759     StrAppend(&output, "weights=", weights_.DebugString());
760   }
761   StrAppend(&output, ")");
762   return output;
763 }
764 
765 // TODO(jie): reorder4 & reorder2 should be merged?
766 // TODO(aaroey): fix the order of parameters.
767 template <typename T>
Reorder4(const nvinfer1::DimsNCHW & shape,const T * idata,const nvinfer1::DimsNCHW & istrides,T * odata,const nvinfer1::DimsNCHW & ostrides)768 void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata,
769               const nvinfer1::DimsNCHW& istrides, T* odata,
770               const nvinfer1::DimsNCHW& ostrides) {
771   for (int n = 0; n < shape.n(); ++n) {
772     for (int c = 0; c < shape.c(); ++c) {
773       for (int h = 0; h < shape.h(); ++h) {
774         for (int w = 0; w < shape.w(); ++w) {
775           odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
776                 w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
777                                           h * istrides.h() + w * istrides.w()];
778         }
779       }
780     }
781   }
782 }
783 
784 template <typename T>
Reorder2(const nvinfer1::DimsHW & shape,const T * idata,const nvinfer1::DimsHW & istrides,T * odata,const nvinfer1::DimsHW & ostrides)785 void Reorder2(const nvinfer1::DimsHW& shape, const T* idata,
786               const nvinfer1::DimsHW& istrides, T* odata,
787               const nvinfer1::DimsHW& ostrides) {
788   for (int h = 0; h < shape.h(); ++h) {
789     for (int w = 0; w < shape.w(); ++w) {
790       odata[h * ostrides.h() + w * ostrides.w()] =
791           idata[h * istrides.h() + w * istrides.w()];
792     }
793   }
794 }
795 
796 // TODO(jie): fallback to tensorflow!!
ReorderCKtoKC(const TRT_ShapedWeights & iweights,TRT_ShapedWeights * oweights)797 void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
798                    TRT_ShapedWeights* oweights) {
799   const int c = iweights.shape_.d[0];
800   const int k = iweights.shape_.d[1];
801   oweights->shape_.d[0] = k;
802   oweights->shape_.d[1] = c;
803   const nvinfer1::DimsHW istrides = {1, k};
804   const nvinfer1::DimsHW ostrides = {c, 1};
805   switch (iweights.type_) {
806     case DataType::DT_FLOAT: {
807       Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
808                istrides,
809                // TODO(aaroey): get rid of all the const_cast like this.
810                static_cast<float*>(const_cast<void*>(oweights->GetValues())),
811                ostrides);
812       break;
813     }
814     case DataType::DT_HALF: {
815       Reorder2(
816           {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
817           istrides,
818           static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
819           ostrides);
820       break;
821     }
822     default:
823       LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
824                  << DataTypeString(iweights.type_);
825   }
826 }
827 
ReorderRSCKToKCRS(const TRT_ShapedWeights & iweights,TRT_ShapedWeights * oweights,const int num_groups)828 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
829                        TRT_ShapedWeights* oweights, const int num_groups) {
830   CHECK_EQ(iweights.type_, oweights->type_);
831   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
832   // K indexes over output channels, C over input channels, and R and S over the
833   // height and width of the convolution
834   const int r = iweights.shape_.d[0];
835   const int s = iweights.shape_.d[1];
836   // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
837   const int c = iweights.shape_.d[2] / num_groups;
838   const int k = iweights.shape_.d[3] * num_groups;
839   VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2]
840           << " then " << c << "k" << iweights.shape_.d[3] << " then " << k
841           << "r" << iweights.shape_.d[0] << " then " << r << "s"
842           << iweights.shape_.d[1] << " then " << s;
843   oweights->shape_.d[0] = k / num_groups;
844   oweights->shape_.d[1] = c * num_groups;
845   oweights->shape_.d[2] = r;
846   oweights->shape_.d[3] = s;
847   const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
848   const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
849   switch (iweights.type_) {
850     case DataType::DT_FLOAT: {
851       Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
852                istrides,
853                static_cast<float*>(const_cast<void*>(oweights->GetValues())),
854                ostrides);
855       break;
856     }
857     case DataType::DT_HALF: {
858       Reorder4(
859           {k, c, r, s}, static_cast<Eigen::half const*>(iweights.GetValues()),
860           istrides,
861           static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
862           ostrides);
863       break;
864     }
865 
866     default:
867       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
868                  << DataTypeString(iweights.type_);
869   }
870 }
871 
GetTempWeights(DataType type,const nvinfer1::Dims & dims)872 TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type,
873                                                  const nvinfer1::Dims& dims) {
874   TensorShape shape;
875   // TODO(laigd): make it return a status.
876   TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
877   // TODO(jie): check weights size_bytes. 0 means type error
878   Tensor tensor(type, shape);
879   TRT_ShapedWeights weights(type, dims, tensor);
880   store_.emplace_back(std::move(tensor));
881   return weights;
882 }
883 
884 const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
885     "QuantizeAndDequantizeV2",
886     "QuantizeAndDequantizeV3",
887     "FakeQuantWithMinMaxVars",
888     "FakeQuantWithMinMaxArgs",
889 };
890 
TrtNodeValidator()891 TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); }
892 
ConvertToTensorOrWeights(const NodeDef & node_def,int output_port,const grappler::GraphProperties & graph_properties,TRT_TensorOrWeights * tensor_or_weights)893 Status TrtNodeValidator::ConvertToTensorOrWeights(
894     const NodeDef& node_def, int output_port,
895     const grappler::GraphProperties& graph_properties,
896     TRT_TensorOrWeights* tensor_or_weights) {
897   if (node_def.op() == "Const") {
898     if (output_port != 0) {
899       return errors::InvalidArgument("Const node should only have one output.");
900     }
901     // The output of the conversion will be used as input to other nodes to
902     // determine whether TRT supports those nodes. If it cannot convert the
903     // Const, it's very likely we cannot treat it as a tensor and make it an
904     // input to the TRT network, since TRT removes the first dimension and
905     // treats it as batch size. Also, it's not likely that the converter can
906     // support the op, and performance may suffer even if it can, so we just
907     // simply return error if the conversion fails.
908     std::vector<TRT_TensorOrWeights> inputs;
909     return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
910   }
911   if (!graph_properties.HasOutputProperties(node_def.name())) {
912     return errors::InvalidArgument("Shape and data type are unknown");
913   }
914 
915   // Validate and convert shape and dtype.
916   const auto& output_params =
917       graph_properties.GetOutputProperties(node_def.name());
918   const auto& tensor_properties = output_params.at(output_port);
919   const DataType dtype = tensor_properties.dtype();
920   const PartialTensorShape shape = tensor_properties.shape();
921   nvinfer1::DataType trt_dtype;
922   nvinfer1::Dims trt_dims;
923   int batch_size = -1;
924   TF_RETURN_IF_ERROR(ValidateTensorProperties(
925       node_def.op(), dtype, shape, /*validation_only_=*/true, &trt_dtype,
926       &trt_dims, &batch_size));
927 
928   // Adds a fake ITensor. This is fine since op converter operates in
929   // validation-only mode and it won't (and shouldn't) use the tensor to do
930   // any TRT network operations.
931   *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size);
932   return Status::OK();
933 }
934 
ValidateNode(const NodeDef & node_def,const std::vector<std::pair<const NodeDef *,int>> & input_node_and_ports,const TrtPrecisionMode precision_mode,const grappler::GraphProperties & graph_properties)935 Status TrtNodeValidator::ValidateNode(
936     const NodeDef& node_def,
937     const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
938     const TrtPrecisionMode precision_mode,
939     const grappler::GraphProperties& graph_properties) {
940   const string& op = node_def.op();
941   // It doesn't support validation of plugins.
942   if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) return Status::OK();
943 
944   // In INT8 mode, we will always apply the quantization ranges provided by
945   // these ops to the relevant tensors. This happens regardless of the value of
946   // use_calibration.
947   bool is_supported_op = false;
948   if (quantize_ops->count(op)) {
949     is_supported_op = (precision_mode == TrtPrecisionMode::INT8);
950   } else {
951     is_supported_op = op_validators_.count(node_def.op());
952   }
953   if (!is_supported_op) {
954     return errors::Unimplemented("Op type ", op, " is not supported.");
955   }
956 
957   // Convert input NodeDef and corresponding output ports to
958   // TRT_TensorOrWeights.
959   std::vector<TRT_TensorOrWeights> inputs;
960   for (int i = 0; i < input_node_and_ports.size(); ++i) {
961     const auto& pair = input_node_and_ports[i];
962     TRT_TensorOrWeights tensor_or_weights;
963     Status status = ConvertToTensorOrWeights(
964         *pair.first, pair.second, graph_properties, &tensor_or_weights);
965     if (!status.ok()) {
966       return errors::Internal(
967           "Failed to convert input with index ", i,
968           " to a TRT_TensorOrWeights: ", status.error_message());
969     }
970     inputs.push_back(tensor_or_weights);
971   }
972 
973   OpConverter validator = op_validators_[node_def.op()];
974   OpConverterParams params(
975       /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr,
976       /*arg_validation_only=*/true, &weight_store_);
977   return validator(&params);
978 }
979 
ConvertConstToWeights(const NodeDef & const_node_def,const std::vector<TRT_TensorOrWeights> & inputs,TRT_TensorOrWeights * output)980 Status TrtNodeValidator::ConvertConstToWeights(
981     const NodeDef& const_node_def,
982     const std::vector<TRT_TensorOrWeights>& inputs,
983     TRT_TensorOrWeights* output) {
984   std::vector<TRT_TensorOrWeights> outputs;
985   OpConverterParams params(
986       /*arg_converter=*/nullptr, const_node_def, inputs, &outputs,
987       /*arg_validation_only=*/true, &weight_store_);
988   Status status = op_validators_["Const"](&params);
989   if (status.ok() && output) *output = outputs[0];
990   return status;
991 }
992 
Converter(nvinfer1::INetworkDefinition * trt_network,TrtPrecisionMode precision_mode,bool use_calibration)993 Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
994                      TrtPrecisionMode precision_mode, bool use_calibration)
995     : trt_network_(trt_network),
996       precision_mode_(precision_mode),
997       use_calibration_(use_calibration) {
998   this->RegisterOpConverters();
999 }
1000 
ConvertNode(const NodeDef & node_def)1001 Status Converter::ConvertNode(const NodeDef& node_def) {
1002   std::vector<TRT_TensorOrWeights> inputs, outputs;
1003   TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
1004 
1005   OpConverterParams params(this, node_def, inputs, &outputs,
1006                            /*arg_validation_only=*/false, &weight_store_);
1007   const string& op = node_def.op();
1008   if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
1009     TF_RETURN_IF_ERROR(plugin_converter_(&params));
1010   } else {
1011     if (!op_registry_.count(op)) {
1012       return errors::Unimplemented("No converter registered for op: ", op);
1013     }
1014     OpConverter op_converter = op_registry_.at(op);
1015     TF_RETURN_IF_ERROR(op_converter(&params));
1016   }
1017 
1018   for (size_t i = 0; i < outputs.size(); ++i) {
1019     TRT_TensorOrWeights& output = outputs[i];
1020     string output_name = node_def.name();
1021     if (i != 0) absl::StrAppend(&output_name, ":", i);
1022     // We need to check the name before setting it. If the input is one of the
1023     // engine input, setting the name here will overwrite engine input
1024     // bindings which will cause runtime error.
1025     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1026     // in ConvertIdentity.
1027     if (output.is_tensor()) {
1028       const char* tensor_name = output.tensor()->getName();
1029       if (!IsEngineInput(tensor_name)) {
1030         // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename
1031         // them to match their corresponding TensorFlow name.
1032         // Note: ITensors that we create internally within TF-TRT which are
1033         // not inputs or outputs of a node will not be renamed. This is a
1034         // potential cause of confusion if an error message or warning
1035         // mentions the unnamed tensor.
1036         output.tensor()->setName(output_name.c_str());
1037       }
1038     }
1039     VLOG(2) << "Adding out tensor " << output_name << ": "
1040             << output.DebugString();
1041     Status status = AddTensorOrWeights(output_name, output);
1042     if (!status.ok()) {
1043       return Status(status.code(),
1044                     StrCat("Failed to add output for node ", node_def.name(),
1045                            ": ", status.error_message()));
1046     }
1047   }
1048   return Status::OK();
1049 }
1050 
AddInputTensor(const string & name,nvinfer1::DataType dtype,const nvinfer1::Dims & dims,int batch_size)1051 Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype,
1052                                  const nvinfer1::Dims& dims, int batch_size) {
1053   // We verify the batch size only for the input nodes, and rely on individual
1054   // op converter to ensure the batch size of the outputs is not changed.
1055   // TODO(laigd): we need to test this properties.
1056   Status status = MaybeUpdateBatchSize(batch_size);
1057   if (!status.ok()) {
1058     return Status(status.code(), StrCat("Batch size doesn't match for tensor ",
1059                                         name, ": ", status.error_message()));
1060   }
1061   nvinfer1::ITensor* tensor = network()->addInput(name.c_str(), dtype, dims);
1062   if (tensor == nullptr) {
1063     return errors::InvalidArgument("Failed to create Input layer tensor ", name,
1064                                    " rank=", dims.nbDims);
1065   }
1066   status = AddTensorOrWeights(name, TRT_TensorOrWeights(tensor));
1067   if (!status.ok()) {
1068     return Status(status.code(), StrCat("Failed to add input tensor ", name,
1069                                         ": ", status.error_message()));
1070   }
1071   return Status::OK();
1072 }
1073 
RenameAndMarkOutputTensors(const std::vector<Converter::EngineOutputInfo> & output_tensors)1074 Status Converter::RenameAndMarkOutputTensors(
1075     const std::vector<Converter::EngineOutputInfo>& output_tensors) {
1076   for (const auto& output : output_tensors) {
1077     TRT_TensorOrWeights tensor_or_weights;
1078     TF_RETURN_IF_ERROR(
1079         GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights));
1080     if (!tensor_or_weights.is_tensor()) {
1081       return errors::InvalidArgument("Output ", output.source_tensor_name,
1082                                      " is weights not tensor");
1083     }
1084     nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
1085     if (tensor == nullptr) {
1086       return errors::NotFound("Output tensor not found: ",
1087                               output.source_tensor_name);
1088     }
1089     // Check if this tensor has already been marked as an input or output.
1090     //
1091     // ConvertIdentity can cause the same tensor to be repeated in
1092     // output_tensors, which can cause us to overwrite the name of the output
1093     // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then
1094     // we won't be able to locate OutputPH_0 during runtime. To fix this,
1095     // duplicate the tensor using no-op shuffle.
1096     //
1097     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1098     // in ConvertIdentity.
1099     if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) {
1100       // Using shuffle layer for identity by not setting reshape or transpose.
1101       nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor);
1102       TFTRT_RETURN_ERROR_IF_NULLPTR(
1103           layer, StrCat("Output Copy for ", tensor->getName()));
1104       MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0));
1105       tensor = layer->getOutput(0);
1106     }
1107     tensor->setName(output.dest_node_name.c_str());
1108     network()->markOutput(*tensor);
1109     // Set type after marking as output. TRT only supports setType for engine
1110     // outputs and inputs (type is inferred otherwise).
1111     tensor->setType(output.trt_dtype);
1112     VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name
1113             << ", which feeds TF node " << output.dest_node_name;
1114   }
1115   return Status::OK();
1116 }
1117 
MaybeUpdateBatchSize(int batch_size)1118 Status Converter::MaybeUpdateBatchSize(int batch_size) {
1119   // OK iff either is unknown or they equal to each other.
1120   if (this->batch_size_ < 0 || batch_size < 0 ||
1121       this->batch_size_ == batch_size) {
1122     if (this->batch_size_ < 0 && batch_size >= 0) {
1123       this->batch_size_ = batch_size;
1124     }
1125     return Status::OK();
1126   }
1127   return errors::InvalidArgument(
1128       "Provided batch size does not match converter batch size: ", batch_size,
1129       " vs ", batch_size_);
1130 }
1131 
AddTensorOrWeights(const string & name,TRT_TensorOrWeights input)1132 Status Converter::AddTensorOrWeights(const string& name,
1133                                      TRT_TensorOrWeights input) {
1134   // Set the batch size of the tensor, using batch size collected from the
1135   // input tensors to the TRT subgraph at the beginning of the conversion.
1136   // We rely on the individual op converter to understand the semantics of the
1137   // TF node, and make sure it doesn't change the batch size nor introduce
1138   // intra-element dependency inside the batch.
1139   if (input.is_tensor()) input.set_batch_size(batch_size_);
1140   if (trt_tensors_.insert({name, std::move(input)}).second) return Status::OK();
1141   return errors::AlreadyExists("tensor/weights ", name, " already exist.");
1142 }
1143 
GetTensorOrWeights(const string & name,TRT_TensorOrWeights * output)1144 Status Converter::GetTensorOrWeights(const string& name,
1145                                      TRT_TensorOrWeights* output) {
1146   if (!trt_tensors_.count(name)) {
1147     return errors::NotFound("Tensor or weights with name ", name,
1148                             " could not be found.");
1149   }
1150   *output = trt_tensors_.at(name);
1151   return Status::OK();
1152 }
1153 
TransposeTensor(nvinfer1::ITensor * input_tensor,const std::vector<int> & order_with_batch_dim,const nvinfer1::ITensor ** output_tensor)1154 Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor,
1155                                   const std::vector<int>& order_with_batch_dim,
1156                                   const nvinfer1::ITensor** output_tensor) {
1157   const auto dims = input_tensor->getDimensions();
1158 
1159   if (order_with_batch_dim.size() - 1 != size_t(dims.nbDims)) {
1160     return errors::InvalidArgument(
1161         "Rank of perm for transpose does not match with that of the input.");
1162   }
1163   if (order_with_batch_dim[0] != 0) {
1164     return errors::Unimplemented(
1165         "Transpose at batch dimension is not supported.");
1166   }
1167 
1168   nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
1169   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose");
1170   MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0));
1171 
1172   nvinfer1::Permutation permutation;
1173   for (int32_t i = 0; i < dims.nbDims; ++i) {
1174     permutation.order[i] = order_with_batch_dim[i + 1] - 1;
1175   }
1176   VLOG(1) << "TransposeTensor permutation: "
1177           << DebugString(permutation, dims.nbDims);
1178   layer->setFirstTranspose(permutation);
1179 
1180   nvinfer1::Dims reshape_dims;
1181   reshape_dims.nbDims = dims.nbDims;
1182   for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
1183     reshape_dims.d[i] = 0;
1184     // TODO(aaroey): why not transposing the types as well?
1185     reshape_dims.type[i] = dims.type[i];
1186   }
1187   layer->setReshapeDimensions(reshape_dims);
1188 
1189   *output_tensor = layer->getOutput(0);
1190   return Status::OK();
1191 }
1192 
GetWeightRange(const TRT_ShapedWeights & weights,float * out_min,float * out_max) const1193 Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
1194                                  float* out_min, float* out_max) const {
1195   switch (weights.type_) {
1196     case DataType::DT_FLOAT: {
1197       auto inp = static_cast<float const*>(weights.GetValues());
1198       auto result = std::minmax_element(inp, inp + weights.count());
1199       *out_min = *result.first;
1200       *out_max = *result.second;
1201       break;
1202     }
1203     case DataType::DT_HALF: {
1204       auto inp = static_cast<Eigen::half const*>(weights.GetValues());
1205       auto result = std::minmax_element(inp, inp + weights.count());
1206       *out_min = Eigen::half_impl::half_to_float(*result.first);
1207       *out_max = Eigen::half_impl::half_to_float(*result.second);
1208       break;
1209     }
1210     case DataType::DT_INT32: {
1211       auto inp = static_cast<int const*>(weights.GetValues());
1212       auto result = std::minmax_element(inp, inp + weights.count());
1213       *out_min = static_cast<float>(*result.first);
1214       *out_max = static_cast<float>(*result.second);
1215       break;
1216     }
1217     default:
1218       return errors::Unimplemented(
1219           "Data type not supported for GetWeightRange: ",
1220           DataTypeString(weights.type_));
1221   }
1222   return Status::OK();
1223 }
1224 
PrepareTensorForShape(const TRT_TensorOrWeights & input,const nvinfer1::Dims & dims,const bool validation_only,const nvinfer1::ITensor ** tensor)1225 Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
1226                                         const nvinfer1::Dims& dims,
1227                                         const bool validation_only,
1228                                         const nvinfer1::ITensor** tensor) {
1229   // If -1 is not used for one of the dims, we can check if the shapes are
1230   // compatible.
1231   bool can_check_shapes = true;
1232   for (int i = 0; i < dims.nbDims; i++) {
1233     if (dims.d[i] == -1) {
1234       can_check_shapes = false;
1235       break;
1236     }
1237   }
1238   if (can_check_shapes &&
1239       TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) {
1240     return errors::InvalidArgument("Reshape shapes are not compatible (",
1241                                    DebugString(input.GetTrtDims()), " vs ",
1242                                    DebugString(dims), ")");
1243   }
1244   if (validation_only) {
1245     *tensor = nullptr;
1246     return Status::OK();
1247   }
1248 
1249   if (input.is_tensor()) {
1250     if (DimsEqual(input.GetTrtDims(), dims)) {
1251       *tensor = input.tensor();
1252     } else {
1253       nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(
1254           *const_cast<nvinfer1::ITensor*>(input.tensor()));
1255       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape");
1256       layer->setReshapeDimensions(dims);
1257       MarkQuantizationRangesAsInferrable(
1258           const_cast<nvinfer1::ITensor*>(input.tensor()), layer->getOutput(0));
1259       *tensor = layer->getOutput(0);
1260     }
1261   } else {
1262     *tensor = CreateConstantLayer(input.weights(), dims);
1263     TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape");
1264     if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) {
1265       // If we are in int8 mode and not calibrating, we need to explicitly set a
1266       // quantization range for the output tensor of the IConstantLayer. Here we
1267       // set the range to [min(weights), max(weights)].
1268       float min_range = 0.0f;
1269       float max_range = 0.0f;
1270       TF_RETURN_IF_ERROR(
1271           GetWeightRange(input.weights(), &min_range, &max_range));
1272       // Avoid setting range to 0 because TRT will throw an error. If the
1273       // weights are zero then the range doesn't matter: using 127.0f should
1274       // ensure the quantized weight will be exactly zero.
1275       if (min_range == 0.0f && max_range == 0.0f) {
1276         min_range = -127.0f;
1277         max_range = 127.0f;
1278       }
1279       ProvideQuantizationRange(const_cast<nvinfer1::ITensor*>(*tensor),
1280                                min_range, max_range);
1281     }
1282   }
1283   return Status::OK();
1284 }
1285 
MarkQuantizationRangesAsInferrable(nvinfer1::ITensor * input,nvinfer1::ITensor * output)1286 void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
1287                                                    nvinfer1::ITensor* output) {
1288   quantization_infer_.push_back({input, output});
1289   quantization_infer_.push_back({output, input});
1290 }
1291 
ProvideQuantizationRange(nvinfer1::ITensor * tensor,float min_range,float max_range)1292 void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor,
1293                                          float min_range, float max_range) {
1294   float symmetric_range = std::max(std::abs(min_range), std::abs(max_range));
1295   quantization_ranges_[tensor] = symmetric_range;
1296 }
1297 
MaybeApplyQuantizationRanges()1298 void Converter::MaybeApplyQuantizationRanges() {
1299   if (precision_mode() != TrtPrecisionMode::INT8) return;
1300 
1301   // Infer ranges across marked ops.
1302   PropagateQuantizationRanges();
1303   // Apply ranges.
1304 #if IS_TRT_VERSION_GE(5, 0, 0)
1305   for (auto pair : quantization_ranges_) {
1306     nvinfer1::ITensor* tensor = pair.first;
1307     const float range = pair.second;
1308     VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range;
1309     // TODO(laigd): if 'tensor' already has a range set which doesn't match
1310     // 'range', it should report error.
1311     tensor->setDynamicRange(-range, range);
1312   }
1313 #endif
1314 
1315   // Warn user about tensors that are missing ranges. If TRT fuses some layers
1316   // then these tensors may not actually be required, which is why this is
1317   // just a warning. If we are still missing ranges even after fusion,
1318   // Builder::buildCudaEngine() will return nullptr and we will catch the
1319   // error at that point.
1320   if (!use_calibration()) {
1321     // Get all tensors from network
1322     std::set<nvinfer1::ITensor*> all_tensors;
1323     for (int i = 0; i < this->network()->getNbLayers(); i++) {
1324       nvinfer1::ILayer* layer = this->network()->getLayer(i);
1325       for (int j = 0; j < layer->getNbInputs(); j++) {
1326         all_tensors.insert(layer->getInput(j));
1327       }
1328       for (int j = 0; j < layer->getNbOutputs(); j++) {
1329         all_tensors.insert(layer->getOutput(j));
1330       }
1331     }
1332     // Find tensors with no ranges
1333     for (auto tensor : all_tensors) {
1334       if (!quantization_ranges_.count(tensor)) {
1335         // Note: there may be some warnings for "(Unnamed ITensor* N)". These
1336         // are tensors which are created internally by TF-TRT. The ranges for
1337         // these unnamed ITensors are always inferred from user provided ranges,
1338         // thus there will also be a warning for the range(s) the user missed.
1339         LOG(WARNING) << "Quantization range was not found for "
1340                      << tensor->getName() << ". "
1341                      << "This is okay if TensorRT does not need the range "
1342                      << "(e.g. due to node fusion).";
1343       }
1344     }
1345   }
1346 }
1347 
PropagateQuantizationRanges()1348 void Converter::PropagateQuantizationRanges() {
1349   // Propagate ranges across edges in quantization_infer_ until no new
1350   // information is added.
1351   // Note: this function modifies quantization_infer_, it might be better to
1352   // modify a copy instead if we for some reason need quantization_infer_
1353   // later.
1354   bool information_added = true;
1355   while (information_added) {
1356     information_added = false;
1357     for (auto it = quantization_infer_.begin();
1358          it != quantization_infer_.end();) {
1359       auto input_tensor_range = quantization_ranges_.find(it->first);
1360       auto output_tensor_range = quantization_ranges_.find(it->second);
1361       if (input_tensor_range != quantization_ranges_.end() &&
1362           output_tensor_range == quantization_ranges_.end()) {
1363         // Input has range but output doesn't: copy range
1364         // TODO(laigd): consider reporting error if it a different range is
1365         // already set.
1366         quantization_ranges_[it->second] = input_tensor_range->second;
1367         information_added = true;
1368         VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> "
1369                 << it->second->getName();
1370       }
1371       // We can remove edges when the output range is known
1372       if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) {
1373         it = quantization_infer_.erase(it);
1374       } else {
1375         ++it;
1376       }
1377     }
1378   }
1379 }
1380 
GetInputs(const NodeDef & node_def,std::vector<TRT_TensorOrWeights> * inputs) const1381 Status Converter::GetInputs(const NodeDef& node_def,
1382                             std::vector<TRT_TensorOrWeights>* inputs) const {
1383   for (auto const& input_name : node_def.input()) {
1384     /*************************************************************************
1385      * TODO(jie): handle case 1) here.
1386      * Normalizes the inputs and extracts associated metadata:
1387      * 1) Inputs can contain a colon followed by a suffix of characters.
1388      *    That suffix may be a single number (e.g. inputName:1) or several
1389      *    word characters separated from a number by a colon
1390      *    (e.g. inputName:foo:1). The
1391      *    latter case is used to denote inputs and outputs of functions.
1392      * 2) Control dependency inputs contain caret at the beginning and we
1393      *    remove this and annotate the edge as a control dependency.
1394      ************************************************************************/
1395     // skip control nodes
1396     if (input_name[0] == '^') continue;
1397     string name = input_name;
1398     auto last = name.find_last_of(':');
1399     // TODO(aaroey): use TensorId
1400     if (last != string::npos && last + 2 == name.size() &&
1401         name[last + 1] == '0') {
1402       name.erase(last);
1403     }
1404 
1405     if (trt_tensors_.count(name)) {
1406       TRT_TensorOrWeights input = trt_tensors_.at(name);
1407       inputs->push_back(input);
1408       VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
1409     } else {
1410       // TODO(aaroey): this should not happen, make it a CHECK.
1411       // TODO(aaroey): use StrCat for pattern like this.
1412       string msg("Node ");
1413       StrAppend(&msg, node_def.name(), " should have an input named '", name,
1414                 "' but it is not available");
1415       LOG(ERROR) << msg;
1416       return errors::InvalidArgument(msg);
1417     }
1418   }
1419   return Status::OK();
1420 }
1421 
1422 // Checks that the number of inputs match, and enforces that the inputs marked
1423 // as true are constant weights. true means that the input must be a weight,
1424 // while false means the input must be a tensor. In the future, false will mean
1425 // the input can be a tensor or weight.
CheckInputsWeights(const OpConverterParams & params,const std::vector<std::pair<string,bool>> & inputs_is_weight)1426 Status CheckInputsWeights(
1427     const OpConverterParams& params,
1428     const std::vector<std::pair<string, bool>>& inputs_is_weight) {
1429   const auto& inputs = params.inputs;
1430   const auto& node_def = params.node_def;
1431   if (inputs.size() != inputs_is_weight.size()) {
1432     return errors::InvalidArgument(
1433         node_def.op(), " got ", inputs.size(), " inputs but expected ",
1434         inputs_is_weight.size(), ", at ", node_def.name());
1435   }
1436   for (int i = 0; i < inputs.size(); i++) {
1437     if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) {
1438       return errors::Unimplemented("The input \"", inputs_is_weight[i].first,
1439                                    "\" for ", node_def.op(),
1440                                    " must be a constant, at ", node_def.name());
1441     }
1442     // TODO(tmorris): Remove this check and provide a method to automatically
1443     // retrive an input as a tensor, converting via CreateConstantLayer if it
1444     // was originally a weight. We will want a caching mechanism to prevent many
1445     // duplicate constants from being created.
1446     if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) {
1447       return errors::Unimplemented("The input \"", inputs_is_weight[i].first,
1448                                    "\" for ", node_def.op(),
1449                                    " must be a tensor, at ", node_def.name());
1450     }
1451   }
1452   return Status::OK();
1453 }
1454 
AllowDataTypes(const OpConverterParams & params,const std::set<DataType> & allowed_dtypes,const char * dtype_attr_name="T")1455 Status AllowDataTypes(const OpConverterParams& params,
1456                       const std::set<DataType>& allowed_dtypes,
1457                       const char* dtype_attr_name = "T") {
1458   const auto& node_def = params.node_def;
1459   TFAttrs attrs(node_def);
1460   if (!attrs.count(dtype_attr_name)) {
1461     return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
1462                                    " not found.");
1463   }
1464   const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
1465   if (!allowed_dtypes.count(op_dtype)) {
1466     // Build string list of allowed types.
1467     std::ostringstream ss;
1468     for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
1469       if (it != allowed_dtypes.begin()) ss << ", ";
1470       ss << DataTypeString(*it);
1471     }
1472     return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
1473                                  " is not supported for ", node_def.op(),
1474                                  ", must be one of [", ss.str(), "], at ",
1475                                  node_def.name());
1476   }
1477   return Status::OK();
1478 }
1479 
ConvertFP32ToFP16(TrtWeightStore * store,const TRT_ShapedWeights & weights_src)1480 TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
1481                                     const TRT_ShapedWeights& weights_src) {
1482   auto dtype_new = DataType::DT_HALF;
1483   TRT_ShapedWeights weights =
1484       store->GetTempWeights(dtype_new, weights_src.shape_);
1485   const float* src = static_cast<const float*>(weights_src.GetValues());
1486   Eigen::half* dst = const_cast<Eigen::half*>(
1487       static_cast<Eigen::half const*>(weights.GetValues()));
1488   for (int64_t i = 0; i < weights_src.count(); i++) {
1489     dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
1490   }
1491   return weights;
1492 }
1493 
1494 // ****************************************************************************
1495 // Constant folding functions for weights.
1496 // TODO(laigd): we should probably use eigen directly.
1497 // *****************************************************************************
1498 struct LambdaFactory {
1499   enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP };
1500   OP_CATEGORY op;
1501 
1502   template <typename T>
unarytensorflow::tensorrt::convert::LambdaFactory1503   std::function<T(T)> unary() {
1504     switch (op) {
1505       case OP_CATEGORY::RSQRT: {
1506         VLOG(2) << "RSQRT GETS DONE";
1507         return [](T t) -> T { return 1.0 / sqrt(t); };
1508       }
1509       case OP_CATEGORY::NEG:
1510         return [](T t) -> T { return -t; };
1511       case OP_CATEGORY::RECIP:
1512         return [](T t) -> T { return 1.0 / t; };
1513       default:
1514         LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
1515         return nullptr;
1516     }
1517   }
1518 };
1519 
1520 template <>
unary()1521 std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
1522   switch (op) {
1523     case OP_CATEGORY::RSQRT: {
1524       VLOG(2) << "RSQRT GETS DONE";
1525       return [](Eigen::half t) {
1526         return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
1527       };
1528     }
1529     case OP_CATEGORY::NEG:
1530       return [](Eigen::half t) { return -t; };
1531     case OP_CATEGORY::RECIP:
1532       return [](Eigen::half t) {
1533         return Eigen::half(1.0 / static_cast<float>(t));
1534       };
1535     default:
1536       LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
1537       return nullptr;
1538   }
1539 }
1540 
UnaryCompute(const TRT_ShapedWeights & iweights,TRT_ShapedWeights * oweights,LambdaFactory unary_op)1541 Status UnaryCompute(const TRT_ShapedWeights& iweights,
1542                     TRT_ShapedWeights* oweights, LambdaFactory unary_op) {
1543   CHECK_EQ(iweights.type_, oweights->type_);
1544   switch (iweights.type_) {
1545     case DataType::DT_FLOAT: {
1546       auto inp = static_cast<float const*>(iweights.GetValues());
1547       auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
1548       std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
1549       break;
1550     }
1551     case DataType::DT_HALF: {
1552       auto inp = static_cast<Eigen::half const*>(iweights.GetValues());
1553       auto oup =
1554           static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues()));
1555       std::transform(inp, inp + iweights.count(), oup,
1556                      unary_op.unary<Eigen::half>());
1557       break;
1558     }
1559     default:
1560       return errors::Unimplemented("Data type not supported: " +
1561                                    DataTypeString(iweights.type_));
1562   }
1563   return Status::OK();
1564 }
1565 
1566 // If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
1567 // right operand. If swapped_inputs is true, those two are swapped.
1568 //
1569 // TODO(jie): broadcast is needed yet not implemented.
1570 // Only implemented channel wise for the time being.
BinaryTensorOpWeight(OpConverterParams * params,const nvinfer1::ITensor * tensor,TRT_ShapedWeights weights,bool swapped_inputs)1571 Status BinaryTensorOpWeight(OpConverterParams* params,
1572                             const nvinfer1::ITensor* tensor,
1573                             TRT_ShapedWeights weights, bool swapped_inputs) {
1574   static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
1575                                                            "Div", "RealDiv"};
1576   const auto& node_def = params->node_def;
1577   if (!supported_ops.count(node_def.op())) {
1578     return errors::Unimplemented(node_def.op(), " is not supported, at ",
1579                                  node_def.name());
1580   }
1581 
1582   // Check type consistency.
1583   nvinfer1::DataType trt_dtype;
1584   TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype));
1585 
1586   // Check scale mode.
1587   auto dims_w = weights.shape_;
1588   const auto dims_t = tensor->getDimensions();
1589 
1590   // TODO(jie): addScale checks for input tensor dimension
1591   if (dims_t.nbDims != 3) {
1592     return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
1593                                    node_def.name());
1594   }
1595 
1596   // Default to element-wise
1597   auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
1598 
1599   // TODO(jie): maybe use a permutation instead to support more cases;
1600   bool need_to_permute = false;
1601 
1602   if (weights.count() == 1) {
1603     scale_mode = nvinfer1::ScaleMode::kUNIFORM;
1604   } else {
1605     VLOG(2) << "weights dims: " << DebugString(dims_w)
1606             << "; tensor dims: " << DebugString(dims_t);
1607     // Make sure no broadcasting on batch dimension.
1608     if (dims_w.nbDims == dims_t.nbDims + 1) {
1609       if (dims_w.d[0] == 1) {
1610         for (int i = 1; i < dims_w.nbDims; i++) {
1611           dims_w.d[i - 1] = dims_w.d[i];
1612         }
1613         dims_w.nbDims--;
1614       } else {
1615         return errors::InvalidArgument("Binary op cannot operate on batch, at ",
1616                                        node_def.name());
1617       }
1618     }
1619 
1620     if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
1621       scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
1622       // Default is element-wise
1623       for (int i = 1; i < dims_w.nbDims; i++) {
1624         if (dims_w.d[i] != dims_t.d[i]) {
1625           // If dimension does not match, switch back to per-channel
1626           scale_mode = nvinfer1::ScaleMode::kCHANNEL;
1627           break;
1628         }
1629       }
1630       // If the mode is per-channel, since channel dimension is assumed to be
1631       // the third to last dimension, we need to make sure all other dimensions
1632       // have size 1.
1633       if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
1634         for (int i = 1; i < dims_w.nbDims; i++) {
1635           if (dims_w.d[i] != 1)
1636             return errors::InvalidArgument(
1637                 "Weight dims not compatible for channel-wise broadcast at ",
1638                 node_def.name());
1639         }
1640       }
1641     } else if (dims_w.nbDims == 1 &&
1642                dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
1643       // Channel wise and broadcast required. We compare the last dimension of
1644       // the tensor shape because of tensorflow default broadcasting rules.
1645       need_to_permute = true;
1646       scale_mode = nvinfer1::ScaleMode::kCHANNEL;
1647     } else {
1648       return errors::InvalidArgument("Weight dims not compatible at ",
1649                                      node_def.name());
1650     }
1651   }
1652   // TODO(laigd): we should add validation_only support in TransposeTensor() and
1653   // PrepareTensorForShape().
1654   if (params->validation_only) return Status::OK();
1655 
1656   // Transpose last dimension.
1657   std::vector<int> permutation(dims_t.nbDims + 1);
1658   if (need_to_permute) {
1659     // We swap the last dimension into channel for trt, because of tensorflow
1660     // default broadcasting rules.
1661     for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
1662       permutation[i] = i;
1663     }
1664     permutation[1] = dims_t.nbDims;
1665     permutation[dims_t.nbDims] = 1;
1666     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
1667         const_cast<nvinfer1::ITensor*>(tensor), permutation, &tensor));
1668   }
1669 
1670   if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
1671     weights = ConvertFP32ToFP16(params->weight_store, weights);
1672   }
1673 
1674   // Prepare weights
1675   TRT_ShapedWeights shift_weights(weights.type_);
1676   TRT_ShapedWeights scale_weights(weights.type_);
1677   TRT_ShapedWeights power_weights(weights.type_);
1678 
1679   if (node_def.op() == "Sub") {
1680     if (swapped_inputs) {
1681       shift_weights = weights;
1682       nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
1683           *const_cast<nvinfer1::ITensor*>(tensor),
1684           nvinfer1::UnaryOperation::kNEG);
1685       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1686       // Since quantization ranges are symmetric, the same range as the input
1687       // will work for the negation of the input.
1688       params->converter->MarkQuantizationRangesAsInferrable(
1689           const_cast<nvinfer1::ITensor*>(tensor), layer->getOutput(0));
1690       tensor = layer->getOutput(0);
1691     } else {
1692       TRT_ShapedWeights neg_weights =
1693           params->weight_store->GetTempWeights(weights);
1694       LambdaFactory unary_op;
1695       unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
1696       TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
1697       shift_weights = neg_weights;
1698     }
1699   } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
1700     if (swapped_inputs) {
1701       // We need to infer the quantization range for this intermediate tensor.
1702       //
1703       //   x -> [Recip] -> 1/x -> [Scale] -> s/x
1704       //                    ^
1705       //            need range for this
1706       //
1707       // We have the quantization scales for x and s/x - can we divide the scale
1708       // for s/x by s? Only if it is a scalar.
1709       //
1710       // Because of this issue, fall back to BinaryTensorOpTensor if we are
1711       // doing INT8 with no calibration. There is most likely no performance
1712       // penalty by falling back here.
1713       if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
1714           !params->converter->use_calibration()) {
1715         return errors::Unimplemented(
1716             "Intermediate quantization range cannot be determined without"
1717             " calibration. Falling back to BinaryTensorOpTensor for ",
1718             node_def.op(), ", at ", node_def.name());
1719       }
1720       scale_weights = weights;
1721       nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
1722           *const_cast<nvinfer1::ITensor*>(tensor),
1723           nvinfer1::UnaryOperation::kRECIP);
1724       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1725       tensor = layer->getOutput(0);
1726     } else {
1727       TRT_ShapedWeights recip_weights =
1728           params->weight_store->GetTempWeights(weights);
1729       LambdaFactory unary_op;
1730       unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
1731       TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
1732       scale_weights = recip_weights;
1733     }
1734   } else if (node_def.op() == "Mul") {
1735     scale_weights = weights;
1736   } else if (node_def.op() == "Add") {
1737     shift_weights = weights;
1738   } else {
1739     // This should not happen.
1740     return errors::Unimplemented("Binary op not supported at ", node_def.op());
1741   }
1742 
1743   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
1744       *const_cast<nvinfer1::ITensor*>(tensor), scale_mode,
1745       shift_weights.GetTrtWeights(), scale_weights.GetTrtWeights(),
1746       power_weights.GetTrtWeights());
1747   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1748 
1749   const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
1750   // Transpose back dimension
1751   if (need_to_permute) {
1752     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
1753         const_cast<nvinfer1::ITensor*>(output_tensor), permutation,
1754         &output_tensor));
1755   }
1756 
1757   // Pass the output
1758   params->outputs->push_back(
1759       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
1760   return Status::OK();
1761 }
1762 
ConvertConv2DHelper(OpConverterParams * params,int group,bool is_conv2d_backprop_input)1763 Status ConvertConv2DHelper(OpConverterParams* params, int group,
1764                            bool is_conv2d_backprop_input) {
1765   const auto& inputs = params->inputs;
1766   const auto& node_def = params->node_def;
1767   TRT_TensorOrWeights backprop_output_size;
1768   const nvinfer1::ITensor* tensor = nullptr;
1769   if (is_conv2d_backprop_input) {
1770     // In the case when Conv2dBackpropInput is used for conv2d_transpose, these
1771     // inputs correspond to: output size, filter, and input.
1772     TF_RETURN_IF_ERROR(CheckInputsWeights(
1773         *params,
1774         {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
1775     backprop_output_size = inputs.at(0);
1776     tensor = inputs.at(2).tensor();
1777   } else {
1778     TF_RETURN_IF_ERROR(
1779         CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
1780     tensor = inputs.at(0).tensor();
1781   }
1782   TF_RETURN_IF_ERROR(
1783       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
1784   TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
1785   if (weights_rsck.shape_.nbDims != 4) {
1786     return errors::InvalidArgument("Conv2D expects kernel of dimension 4, at " +
1787                                    node_def.name());
1788   }
1789   TFAttrs attrs(node_def);
1790   auto data_format = attrs.get<string>("data_format");
1791   int c_index = (data_format == "NHWC") ? 3 : 1;
1792   int h_index = (data_format == "NHWC") ? 1 : 2;
1793   int w_index = (data_format == "NHWC") ? 2 : 3;
1794   auto tf_dilations = attrs.get<std::vector<int64>>("dilations");
1795   if (tf_dilations.size() != 4) {
1796     return errors::InvalidArgument(
1797         "Convolution dilations field must specify 4 dimensions, at ",
1798         node_def.name());
1799   }
1800   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
1801     return errors::Unimplemented(
1802         "Dilation rate must be 1 for batch and channel dimensions, at ",
1803         node_def.name());
1804   }
1805   const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
1806   if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) {
1807     return errors::Unimplemented(
1808         "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported",
1809         ", at ", node_def.name());
1810   }
1811 
1812   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
1813   if (tf_stride.size() != 4) {
1814     return errors::InvalidArgument(
1815         "Convolution strides field must specify 4 dimensions, at ",
1816         node_def.name());
1817   }
1818   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
1819     return errors::Unimplemented(
1820         "Stride must be 1 for batch and channel dimensions, at ",
1821         node_def.name());
1822   }
1823   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
1824   if (params->validation_only) return Status::OK();
1825 
1826   // Transpose to NCHW (NCHW is required for IConvLayer).
1827   const bool need_transpose = (data_format == "NHWC");
1828   if (need_transpose) {
1829     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
1830         const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
1831   }
1832   // Dimensions of transposed tensor.
1833   const auto tensor_dim = tensor->getDimensions();
1834 
1835   // group == 0 signifies that this is a depthwise convolution, so set
1836   // num_groups to size of input's channel dim. For a non-depthwise conv,
1837   // num_groups will be 1.
1838   const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
1839 
1840   if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
1841     weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck);
1842   }
1843   // For conv, TF weights are RSCK, and TRT expects KCRS.
1844   // For backprop, TF weights are RSKC, and TRT expects CKRS.
1845   // Therefore, this reorder will work for both cases.
1846   TRT_ShapedWeights weights =
1847       params->weight_store->GetTempWeights(weights_rsck);
1848   ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
1849   TRT_ShapedWeights biases(weights.type_);
1850   const int output_axis = is_conv2d_backprop_input ? 1 : 0;
1851   const int noutput = weights.shape_.d[output_axis] * num_groups;
1852   nvinfer1::DimsHW kernel_size;
1853   kernel_size.h() = weights.shape_.d[2];
1854   kernel_size.w() = weights.shape_.d[3];
1855 
1856   // Add padding.
1857   std::vector<std::pair<int, int>> padding;
1858   if (attrs.get<string>("padding") == "SAME") {
1859     nvinfer1::DimsHW effective_kernel_size = kernel_size;
1860     effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1);
1861     effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1);
1862     std::vector<int64_t> input_dims;
1863     if (is_conv2d_backprop_input) {
1864       // For backprop, calculate padding based on "input_sizes" input, which
1865       // actually corresponds to output size. ("input_sizes" makes sense in the
1866       // context of Conv2DBackpropInput).
1867       // We use h_index and w_index instead of 1 and 2 because we havent
1868       // transposed backprop_output_size along with the input.
1869       auto output_size_weights = static_cast<int*>(
1870           const_cast<void*>(backprop_output_size.weights().GetValues()));
1871       input_dims = {output_size_weights[h_index], output_size_weights[w_index]};
1872     } else {
1873       // Use 1 and 2 because tensor_dim has the dimensions of the transposed
1874       // input.
1875       input_dims = {static_cast<int>(tensor_dim.d[1]),
1876                     static_cast<int>(tensor_dim.d[2])};
1877     }
1878     padding = CreateSamePadding(stride, effective_kernel_size, input_dims);
1879   } else {
1880     padding = {{0, 0}, {0, 0}};
1881   }
1882   if (padding[0].first != padding[0].second ||
1883       padding[1].first != padding[1].second) {
1884     // Handle asymmetric padding.
1885     auto pad_layer = params->converter->network()->addPadding(
1886         *const_cast<nvinfer1::ITensor*>(tensor),
1887         nvinfer1::DimsHW(padding[0].first, padding[1].first),
1888         nvinfer1::DimsHW(padding[0].second, padding[1].second));
1889     TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
1890     params->converter->MarkQuantizationRangesAsInferrable(
1891         const_cast<nvinfer1::ITensor*>(tensor), pad_layer->getOutput(0));
1892     padding = {{0, 0}, {0, 0}};
1893     tensor = pad_layer->getOutput(0);
1894   }
1895 
1896   // Add convolution.
1897   nvinfer1::ILayer* conv_layer = nullptr;
1898   if (is_conv2d_backprop_input) {
1899     nvinfer1::IDeconvolutionLayer* layer =
1900         params->converter->network()->addDeconvolution(
1901             *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
1902             weights.GetTrtWeights(), biases.GetTrtWeights());
1903     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1904     layer->setStride(stride);
1905     layer->setPadding({padding[0].first, padding[1].first});
1906     layer->setName(node_def.name().c_str());
1907     layer->setNbGroups(num_groups);
1908     conv_layer = layer;
1909   } else {
1910     nvinfer1::IConvolutionLayer* layer =
1911         params->converter->network()->addConvolution(
1912             *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
1913             weights.GetTrtWeights(), biases.GetTrtWeights());
1914     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1915     layer->setStride(stride);
1916     layer->setPadding({padding[0].first, padding[1].first});
1917     layer->setName(node_def.name().c_str());
1918     layer->setNbGroups(num_groups);
1919     layer->setDilation(dilation);
1920     conv_layer = layer;
1921   }
1922   const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0);
1923 
1924   // Restore transpose.
1925   if (need_transpose) {
1926     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
1927         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 2, 3, 1},
1928         &output_tensor));
1929   }
1930   params->outputs->push_back(
1931       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
1932   return Status::OK();
1933 }
1934 
BinaryTensorOpTensor(OpConverterParams * params,const TRT_TensorOrWeights & operand_l,const TRT_TensorOrWeights & operand_r)1935 Status BinaryTensorOpTensor(OpConverterParams* params,
1936                             const TRT_TensorOrWeights& operand_l,
1937                             const TRT_TensorOrWeights& operand_r) {
1938   const auto& node_def = params->node_def;
1939   static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
1940       {"Add", nvinfer1::ElementWiseOperation::kSUM},
1941       {"Mul", nvinfer1::ElementWiseOperation::kPROD},
1942       {"Sub", nvinfer1::ElementWiseOperation::kSUB},
1943       {"Div", nvinfer1::ElementWiseOperation::kDIV},
1944       {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
1945       {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
1946       {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
1947       {"Pow", nvinfer1::ElementWiseOperation::kPOW},
1948   };
1949   auto op_pair = ops.find(node_def.op());
1950   if (op_pair == ops.end()) {
1951     return errors::Unimplemented("Binary op ", node_def.op(),
1952                                  " not supported at: ", node_def.name());
1953   }
1954 
1955   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
1956   Status status = params->converter->GetTrtBroadcastShape(
1957       operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
1958   if (!status.ok()) {
1959     return errors::InvalidArgument(
1960         "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
1961         status.error_message());
1962   }
1963   TFAttrs attrs(node_def);
1964   nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
1965   if (dtype == nvinfer1::DataType::kINT32) {
1966     return errors::Unimplemented("Binary op ", node_def.op(),
1967                                  " does not support INT32, at ",
1968                                  node_def.name());
1969   }
1970   if (params->validation_only) return Status::OK();
1971 
1972   const nvinfer1::ITensor* tensor_l = nullptr;
1973   const nvinfer1::ITensor* tensor_r = nullptr;
1974   status = params->converter->PrepareTensorForShape(
1975       operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l);
1976   if (status.ok()) {
1977     status = params->converter->PrepareTensorForShape(
1978         operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r);
1979   }
1980   if (!status.ok()) {
1981     return errors::Internal("Failed to convert binary op ", node_def.name(),
1982                             ": ", status.error_message());
1983   }
1984 
1985   // Check type consistency.
1986   TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
1987       << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
1988   TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
1989       << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);
1990 
1991   // Add ElementWise layer.
1992   nvinfer1::IElementWiseLayer* layer =
1993       params->converter->network()->addElementWise(
1994           *const_cast<nvinfer1::ITensor*>(tensor_l),
1995           *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
1996   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1997   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
1998 
1999   // Pass the output
2000   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2001   return Status::OK();
2002 }
2003 
ConvertPlugin(OpConverterParams * params)2004 Status ConvertPlugin(OpConverterParams* params) {
2005   const auto& inputs = params->inputs;
2006   const auto& node_def = params->node_def;
2007   // prepare input
2008   std::vector<nvinfer1::ITensor*> all_inputs;
2009   all_inputs.reserve(inputs.size());
2010   for (auto input : inputs) {
2011     all_inputs.emplace_back(const_cast<nvinfer1::ITensor*>(input.tensor()));
2012   }
2013 
2014   // plugin is owned by PluginFactory
2015   // TODO(jie): destroy plugins later (resource management)
2016   PluginTensorRT* plugin =
2017       PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op());
2018 
2019   // passing attributes
2020   // TODO(jie): support more general attribute
2021   TFAttrs attrs(node_def);
2022   auto attr_key_vector = attrs.GetAllAttrKeys();
2023   for (auto attr_key : attr_key_vector) {
2024     // TODO(jie): support only list of float for toy example here.
2025     auto data = attrs.get<std::vector<float>>(attr_key);
2026     size_t size_data = data.size() * sizeof(float);
2027     if (!plugin->SetAttribute(attr_key, static_cast<void*>(data.data()),
2028                               size_data)) {
2029       return errors::InvalidArgument("plugin SetAttribute failed");
2030     }
2031   }
2032 
2033   nvinfer1::IPluginLayer* layer = params->converter->network()->addPlugin(
2034       &all_inputs[0], static_cast<int>(inputs.size()), *plugin);
2035 
2036   for (int i = 0; i < layer->getNbOutputs(); i++) {
2037     nvinfer1::ITensor* output_tensor = layer->getOutput(i);
2038     params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2039   }
2040   return Status::OK();
2041 }
2042 
ConvertTranspose(OpConverterParams * params)2043 Status ConvertTranspose(OpConverterParams* params) {
2044   const auto& inputs = params->inputs;
2045   TF_RETURN_IF_ERROR(
2046       CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
2047   TF_RETURN_IF_ERROR(AllowDataTypes(
2048       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2049   // Get the permutation from weights.
2050   TRT_ShapedWeights weights = inputs.at(1).weights();
2051   const int* weights_ptr =
2052       static_cast<int*>(const_cast<void*>(weights.GetValues()));
2053   std::vector<int> perm(weights_ptr, weights_ptr + weights.count());
2054 
2055   // Verify the permutation.
2056   nvinfer1::ITensor* input_tensor =
2057       const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
2058   if (perm.size() - 1 != size_t(input_tensor->getDimensions().nbDims)) {
2059     return errors::InvalidArgument(
2060         "Rank of perm for transpose does not match with that of the input.");
2061   }
2062   if (perm[0] != 0) {
2063     return errors::Unimplemented(
2064         "Transpose at batch dimension is not supported.");
2065   }
2066 
2067   if (params->validation_only) return Status::OK();
2068 
2069   // Start conversion.
2070   const nvinfer1::ITensor* output_tensor = nullptr;
2071   TF_RETURN_IF_ERROR(
2072       params->converter->TransposeTensor(input_tensor, perm, &output_tensor));
2073   params->outputs->push_back(
2074       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
2075   return Status::OK();
2076 }
2077 
ConvertReshape(OpConverterParams * params)2078 Status ConvertReshape(OpConverterParams* params) {
2079   const auto& inputs = params->inputs;
2080   const auto& node_def = params->node_def;
2081   TF_RETURN_IF_ERROR(
2082       CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}}));
2083   TF_RETURN_IF_ERROR(AllowDataTypes(
2084       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2085   TRT_TensorOrWeights input_tensor = inputs.at(0);
2086   TRT_ShapedWeights weights = inputs.at(1).weights();
2087   if (weights.count() == 0) {
2088     return errors::Unimplemented("Reshape to shape=[] is not supported, at ",
2089                                  node_def.name());
2090   }
2091 
2092   const int* weights_ptr =
2093       static_cast<int*>(const_cast<void*>(weights.GetValues()));
2094 
2095   // Check that it doesn't change the batch dimension. This check is
2096   // conservative, for example, when the first dim of the shape is -1 and input
2097   // tensor shape is not fixed, it is still possible that the reshape doesn't
2098   // change the batch dim, but as long as there is a possibility that it could
2099   // change the batch dim, it reject the conversion. The parameters are:
2100   //
2101   // * reshape_batch_dim: the value of the first dim of the input shape constant
2102   // * reshape_dims: all other dims of the input shape constant
2103   // * input_batch_dim: the value of the first dim of the input tensor to
2104   //   reshape
2105   // * input_dims: all other dims of the input tensor to reshape
2106   //
2107   // The validation logic is:
2108   //
2109   // if input_batch_dim is fixed:
2110   //   if reshape_batch_dim == input_batch_dim:
2111   //     ok
2112   //   elif reshape_batch_dim == -1 (meaning reshape_dims are fixed) and
2113   //        input_dims are fixed and
2114   //        prod(input_dims) == prod(reshape_dims)
2115   //     ok
2116   //   else:
2117   //     not ok
2118   // elif input_dims are fixed:
2119   //   if reshape_dims are fixed and
2120   //      prod(input_dims) == prod(reshape_dims):
2121   //     ok
2122   //   else:
2123   //     not ok
2124   // else:
2125   //   not ok
2126 
2127   const int input_batch_dim = input_tensor.batch_size();
2128   const int reshape_batch_dim = weights_ptr[0];
2129   const nvinfer1::Dims input_dims = input_tensor.GetTrtDims();
2130 
2131   nvinfer1::Dims reshape_dims;
2132   reshape_dims.nbDims = weights.count() - 1;
2133   for (int i = 1; i < weights.count(); i++) {
2134     reshape_dims.d[i - 1] = weights_ptr[i];
2135   }
2136 
2137   // Check that it doesn't change the batch dimension according to the logic
2138   // mentioned above.
2139   bool reshape_may_change_batch_dim = false;
2140   if (input_batch_dim > 0) {        // Batch size is fixed.
2141     if (reshape_batch_dim == -1) {  // Other dims of the shape must be fixed.
2142       if (!HasStaticShape(input_dims) ||
2143           TrtDimsNumElements(reshape_dims) != TrtDimsNumElements(input_dims)) {
2144         reshape_may_change_batch_dim = true;
2145       }
2146     } else if (reshape_batch_dim != input_batch_dim) {
2147       reshape_may_change_batch_dim = true;
2148     }
2149   } else if (HasStaticShape(input_dims)) {
2150     if (!HasStaticShape(reshape_dims) ||
2151         TrtDimsNumElements(reshape_dims) != TrtDimsNumElements(input_dims)) {
2152       reshape_may_change_batch_dim = true;
2153     }
2154   } else {
2155     reshape_may_change_batch_dim = true;
2156   }
2157   VLOG(1) << "input_batch_dim=" << input_batch_dim
2158           << ", input_dims=" << DebugString(input_dims)
2159           << "\nreshape_batch_dim=" << reshape_batch_dim
2160           << ", reshape_dims=" << DebugString(reshape_dims);
2161   if (reshape_may_change_batch_dim) {
2162     const string msg = StrCat(
2163         "Reshape on batch dimension is not supported, at ", node_def.name());
2164     return errors::Unimplemented(msg);
2165   }
2166   if (params->validation_only) return Status::OK();
2167 
2168   // Start conversion.
2169   const nvinfer1::ITensor* output_tensor = nullptr;
2170   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
2171       input_tensor, reshape_dims, /*validation_only=*/false, &output_tensor));
2172   params->outputs->push_back(
2173       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
2174   return Status::OK();
2175 }
2176 
ConvertExpandDims(OpConverterParams * params)2177 Status ConvertExpandDims(OpConverterParams* params) {
2178   const auto& inputs = params->inputs;
2179   const auto& node_def = params->node_def;
2180   TF_RETURN_IF_ERROR(
2181       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
2182   TF_RETURN_IF_ERROR(AllowDataTypes(
2183       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2184   // Get input shape as vector.
2185   TRT_TensorOrWeights input_tensor = inputs.at(0);
2186   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2187   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2188   // Add batch dim back.
2189   input_dims.insert(input_dims.begin(), -1);
2190   const int input_rank = input_dims.size();
2191   // Get axis to expand on.
2192   TRT_ShapedWeights weights = inputs.at(1).weights();
2193   if (weights.count() != 1) {
2194     return errors::InvalidArgument("ExpandDims axis must be a scalar, at ",
2195                                    node_def.name());
2196   }
2197   const int* weights_ptr =
2198       static_cast<int*>(const_cast<void*>(weights.GetValues()));
2199   int axis = weights_ptr[0];
2200   // Make sure axis is valid.
2201   if ((axis < (-input_rank - 1)) || (axis > input_rank)) {
2202     return errors::InvalidArgument(
2203         "Axis for ExpandDims is invalid, must be in the range "
2204         "[-rank(input) - 1, rank(input)], at ",
2205         node_def.name());
2206   }
2207   // Convert negative axis to corresponding positive axis.
2208   if (axis < 0) axis += input_rank + 1;
2209   if (axis == 0) {
2210     return errors::Unimplemented(
2211         "Modifying batch dimension is not supported for ExpandDims, at ",
2212         node_def.name());
2213   }
2214   if (params->validation_only) return Status::OK();
2215 
2216   // ExpandDims: Insert new dim of size 1.
2217   input_dims.insert(input_dims.begin() + axis, 1);
2218   // Reshape tensor.
2219   nvinfer1::Dims new_dims;
2220   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims,
2221                                                /*ignore_first_dim=*/true));
2222   const nvinfer1::ITensor* output_tensor = nullptr;
2223   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
2224       input_tensor, new_dims, /*validation_only=*/false, &output_tensor));
2225   params->outputs->push_back(
2226       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
2227   return Status::OK();
2228 }
2229 
ConvertSqueeze(OpConverterParams * params)2230 Status ConvertSqueeze(OpConverterParams* params) {
2231   const auto& inputs = params->inputs;
2232   const auto& node_def = params->node_def;
2233   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2234   TF_RETURN_IF_ERROR(AllowDataTypes(
2235       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2236   // Get input shape.
2237   TRT_TensorOrWeights input_tensor = inputs.at(0);
2238   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2239   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2240   // Add batch dim back.
2241   input_dims.insert(input_dims.begin(), -1);
2242   const int input_rank = input_dims.size();
2243   // Mark axes to remove by setting them to 0.
2244   TFAttrs attrs(node_def);
2245   auto squeeze_dims = attrs.get<std::vector<int64>>("squeeze_dims");
2246   if (squeeze_dims.empty()) {
2247     return errors::Unimplemented(
2248         "Squeeze is only implemented for explicit dims, at ", node_def.name());
2249   }
2250   for (int axis : squeeze_dims) {
2251     // Make sure axis is valid.
2252     if ((axis < -input_rank) || (axis >= input_rank)) {
2253       return errors::InvalidArgument(
2254           "Axis for Squeeze is invalid, must be in the range "
2255           "[-rank(input), rank(input)), at ",
2256           node_def.name());
2257     }
2258     // Convert negative axis to corresponding positive axis.
2259     if (axis < 0) axis += input_rank;
2260     // Don't squeeze batch dim.
2261     if (axis == 0) {
2262       return errors::Unimplemented("Cannot squeeze batch dimension, at ",
2263                                    node_def.name());
2264     }
2265     // Make sure target dimension is size 1.
2266     if (input_dims[axis] != 1) {
2267       return errors::InvalidArgument(
2268           "Cannot squeeze ", axis, "th dimension ", input_dims[axis],
2269           " which isn't size 1, at ", node_def.name());
2270     }
2271     // Mark dim for removal by setting to 0.
2272     input_dims[axis] = 0;
2273   }
2274   if (params->validation_only) return Status::OK();
2275 
2276   // Remove all dims which are equal to 0.
2277   input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0),
2278                    input_dims.end());
2279   // Reshape tensor.
2280   nvinfer1::Dims new_dims;
2281   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims,
2282                                                /*ignore_first_dim=*/true));
2283   const nvinfer1::ITensor* output_tensor = nullptr;
2284   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
2285       input_tensor, new_dims, /*validation_only=*/false, &output_tensor));
2286   params->outputs->push_back(
2287       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
2288   return Status::OK();
2289 }
2290 
ConvertStridedSliceHelper(OpConverterParams * params,const TRT_TensorOrWeights & input,std::vector<int> begin,std::vector<int> size,const std::vector<int> & stride)2291 Status ConvertStridedSliceHelper(OpConverterParams* params,
2292                                  const TRT_TensorOrWeights& input,
2293                                  std::vector<int> begin, std::vector<int> size,
2294                                  const std::vector<int>& stride) {
2295   const auto& node_def = params->node_def;
2296   // Get input dims.
2297   nvinfer1::Dims dims = input.GetTrtDims();
2298   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2299   // Temporarily add batch dimension so that indexes line up properly.
2300   input_dims.insert(input_dims.begin(), -1);
2301   // Check bounds.
2302   for (int i = 1; i < input_dims.size(); i++) {
2303     if (begin[i] < 0 || begin[i] > input_dims[i]) {
2304       return errors::InvalidArgument("\"begin\" for dimension ",
2305                                      std::to_string(i), " in ", node_def.op(),
2306                                      " is out of range, at ", node_def.name());
2307     }
2308     const int end = begin[i] + size[i];
2309     if (end < 0 || end > input_dims[i]) {
2310       return errors::InvalidArgument("\"begin\" + \"size\" for dimension ",
2311                                      std::to_string(i), " in ", node_def.op(),
2312                                      " is out of range, at ", node_def.name());
2313     }
2314     if (size[i] <= 0) {
2315       return errors::InvalidArgument("\"size\" cannot be negative or zero for ",
2316                                      node_def.op(), ", at ", node_def.name());
2317     }
2318   }
2319 // TRT 5.1 adds a slice layer. For older versions, we attempt to use the
2320 // padding layer with negative padding.
2321 #if IS_TRT_VERSION_GE(5, 1, 0) && 0
2322   // TODO(laigd): TRT 5.1 RC has a bug when ISliceLayer is used along with
2323   // IConcatenationLayer, so disable ISliceLayer for now until it's fixed.
2324   // Use ISliceLayer.
2325   nvinfer1::Dims begin_dims, size_dims, stride_dims;
2326   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims,
2327                                                /*ignore_first_dim=*/true));
2328   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims,
2329                                                /*ignore_first_dim=*/true));
2330   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims,
2331                                                /*ignore_first_dim=*/true));
2332   if (params->validation_only) return Status::OK();
2333 
2334   nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
2335       *const_cast<nvinfer1::ITensor*>(input.tensor()), begin_dims, size_dims,
2336       stride_dims);
2337   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
2338   return Status::OK();
2339 #else
2340   // Use IPaddingLayer.
2341   // Strides must be 1 in this case.
2342   for (int x : stride) {
2343     if (x != 1) {
2344       return errors::Unimplemented(
2345           "Strides other than 1 are not supported with this version of TRT, "
2346           "at ",
2347           node_def.name());
2348     }
2349   }
2350   // Rank must be 2, 3 or 4.
2351   if (input_dims.size() > 4) {
2352     return errors::Unimplemented(node_def.op(),
2353                                  " for tensors with rank > 4 is "
2354                                  "not supported in this version of "
2355                                  "TRT, at ",
2356                                  node_def.name());
2357   }
2358   // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input.
2359   const bool need_reshape = (input_dims.size() != 4);
2360   int reshape_dims_added = 0;
2361   nvinfer1::Dims reshape_dims;
2362   if (need_reshape) {
2363     // Add new dims after batch dim until tensor is 4D.
2364     while (input_dims.size() < 4) {
2365       input_dims.insert(input_dims.begin() + 1, 1);
2366       begin.insert(begin.begin() + 1, 0);
2367       size.insert(size.begin() + 1, 1);
2368       reshape_dims_added++;
2369     }
2370     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims,
2371                                                  /*ignore_first_dim=*/true));
2372   }
2373   // Find dimensions which need to be sliced.
2374   std::vector<int> pad_dims;
2375   for (int i = 1; i < input_dims.size(); i++) {
2376     if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) {
2377       pad_dims.push_back(i);
2378     }
2379   }
2380   if (pad_dims.empty()) {
2381     // No dimensions are changed, so this is a no-op. We could just return the
2382     // input without creating a new layer. TRT will crash if an empty engine
2383     // with no layers is attempted to be created, so we add a no-op shuffle to
2384     // prevent our unit tests from breaking.
2385     // TODO(tmorris): Allow empty engines in the unit tests and return the input
2386     // as output here.
2387     if (params->validation_only) return Status::OK();
2388     nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle(
2389         *const_cast<nvinfer1::ITensor*>(input.tensor()));
2390     params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
2391     return Status::OK();
2392   } else if (pad_dims.size() == 1) {
2393     // Only one dim is modified but we have to have 2, mark a second dim which
2394     // will have padding of 0. The dim we add is chosen to avoid an unecessary
2395     // transpose.
2396     if (pad_dims[0] != 2) {
2397       pad_dims.push_back(2);
2398     } else {
2399       pad_dims.push_back(3);
2400     }
2401   } else if (pad_dims.size() > 2) {
2402     return errors::Unimplemented(
2403         node_def.op(),
2404         " can only modify up to 2 dimensions in this version of TRT, at ",
2405         node_def.name());
2406   }
2407   std::sort(pad_dims.begin(), pad_dims.end());
2408   // Convert to pre/post padding values. Since TRT does not have a StridedSlice
2409   // or Slice layer prior to 5.1, we instead create an IPaddingLayer with
2410   // negative padding.
2411   nvinfer1::DimsHW pre_padding, post_padding;
2412   for (int i = 0; i < pad_dims.size(); i++) {
2413     const int axis = pad_dims[i];
2414     pre_padding.d[i] = -begin[axis];
2415     post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis];
2416   }
2417 
2418   // IPaddingLayer will always apply the padding to dims 2,3 (input format is
2419   // NCHW).
2420   const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3);
2421   std::vector<int> transpose_order(input_dims.size());
2422   std::vector<int> inv_transpose_order(input_dims.size());
2423   if (need_transpose) {
2424     if (pad_dims[0] == 1 && pad_dims[1] == 3) {
2425       transpose_order = {0, 2, 1, 3};
2426       inv_transpose_order = {0, 2, 1, 3};
2427     } else if (pad_dims[0] == 1 && pad_dims[1] == 2) {
2428       transpose_order = {0, 3, 1, 2};
2429       inv_transpose_order = {0, 2, 3, 1};
2430     }
2431   }
2432   if (params->validation_only) return Status::OK();
2433 
2434   // Start conversion.
2435   nvinfer1::ITensor* tensor = const_cast<nvinfer1::ITensor*>(input.tensor());
2436   if (need_reshape) {
2437     const nvinfer1::ITensor* output_tensor = nullptr;
2438     TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
2439         input, reshape_dims, /*validation_only=*/false, &output_tensor));
2440     tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
2441   }
2442   if (need_transpose) {
2443     const nvinfer1::ITensor* output_tensor = nullptr;
2444     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2445         tensor, transpose_order, &output_tensor));
2446     tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
2447   }
2448   // Add padding layer
2449   nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
2450       *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
2451   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2452   params->converter->MarkQuantizationRangesAsInferrable(tensor,
2453                                                         layer->getOutput(0));
2454   tensor = layer->getOutput(0);
2455   // Restore transpose
2456   if (need_transpose) {
2457     const nvinfer1::ITensor* output_tensor = nullptr;
2458     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2459         tensor, inv_transpose_order, &output_tensor));
2460     tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
2461   }
2462   // Restore reshape
2463   if (need_reshape) {
2464     // Calculate output dimensions
2465     for (int i = 0; i < pad_dims.size(); i++) {
2466       const int axis = pad_dims[i];
2467       input_dims[axis] = size[axis];
2468     }
2469     // Remove added 1 dimensions
2470     for (int i = 0; i < reshape_dims_added; i++) {
2471       int value = input_dims[1];
2472       if (value != 1) {
2473         return errors::Internal("StridedSlice error when reshaping, at ",
2474                                 node_def.name());
2475       }
2476       input_dims.erase(input_dims.begin() + 1);
2477     }
2478 
2479     nvinfer1::Dims new_dims;
2480     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims,
2481                                                  /*ignore_first_dim=*/true));
2482     const nvinfer1::ITensor* output_tensor = nullptr;
2483     TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
2484         TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false,
2485         &output_tensor));
2486     tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
2487   }
2488 
2489   params->outputs->push_back(
2490       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(tensor)));
2491   return Status::OK();
2492 #endif
2493 }
2494 
ConvertSlice(OpConverterParams * params)2495 Status ConvertSlice(OpConverterParams* params) {
2496   const auto& inputs = params->inputs;
2497   const auto& node_def = params->node_def;
2498   TF_RETURN_IF_ERROR(CheckInputsWeights(
2499       *params, {{"input", false}, {"begin", true}, {"size", true}}));
2500   TF_RETURN_IF_ERROR(AllowDataTypes(
2501       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2502   std::vector<int> begin = inputs.at(1).weights().ToVector<int>();
2503   std::vector<int> size = inputs.at(2).weights().ToVector<int>();
2504   // Get input dims.
2505   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
2506   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2507   // Add batch dimension so that indexes line up properly.
2508   input_dims.insert(input_dims.begin(), inputs.at(0).batch_size());
2509   if (!AllLengthsEqual({input_dims, begin, size})) {
2510     return errors::InvalidArgument(
2511         "Length of begin and size arguments must equal rank of input for "
2512         "Slice, at ",
2513         node_def.name());
2514   }
2515   // Check that batch dimension is unmodified.
2516   const bool begin_is_modified = begin[0] != 0;
2517   // If size[0]s is not -1, we can only know if the batch dimension is
2518   // unmodified when the batch size is defined. When the batch size is
2519   // undefined, we don't convert to be safe.
2520   const bool batch_size_is_defined = input_dims[0] > 0;
2521   const bool size_is_modified =
2522       size[0] != -1 && (!batch_size_is_defined ||
2523                         (batch_size_is_defined && size[0] != input_dims[0]));
2524   if (begin_is_modified || size_is_modified) {
2525     return errors::Unimplemented(
2526         "TensorRT does not allow modifications to the batch dimension, at ",
2527         node_def.name());
2528   }
2529   // Size of -1 signifies to take all remaining elements.
2530   for (int i = 1; i < input_dims.size(); i++) {
2531     if (size[i] == -1) {
2532       size[i] = input_dims[i] - begin[i];
2533     }
2534   }
2535   // Stride is 1 for all dims.
2536   std::vector<int> stride(begin.size(), 1);
2537   return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride);
2538 }
2539 
ConvertStridedSlice(OpConverterParams * params)2540 Status ConvertStridedSlice(OpConverterParams* params) {
2541   const auto& inputs = params->inputs;
2542   const auto& node_def = params->node_def;
2543   TF_RETURN_IF_ERROR(CheckInputsWeights(
2544       *params,
2545       {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
2546   TF_RETURN_IF_ERROR(AllowDataTypes(
2547       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2548   // Get input dims.
2549   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
2550   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2551   // Add batch dimension so that indexes line up properly.
2552   input_dims.insert(input_dims.begin(), inputs.at(0).batch_size());
2553   // Get begin and end bounds per axis.
2554   std::vector<int> begin = inputs.at(1).weights().ToVector<int>();
2555   std::vector<int> end = inputs.at(2).weights().ToVector<int>();
2556   std::vector<int> stride = inputs.at(3).weights().ToVector<int>();
2557   if (!AllLengthsEqual({input_dims, begin, end, stride})) {
2558     return errors::InvalidArgument(
2559         "Length of begin, end, and stride arguments must equal rank of input "
2560         "for StridedSlice, at ",
2561         node_def.name());
2562   }
2563   // Unsupported mask options.
2564   TFAttrs attrs(node_def);
2565   for (const string& attr :
2566        {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) {
2567     int attr_val = attrs.get<int64>(attr);
2568     if (attr_val != 0) {
2569       return errors::Unimplemented(
2570           attr, " is not supported for StridedSlice, at ", node_def.name());
2571     }
2572   }
2573   const int begin_mask = attrs.get<int64>("begin_mask");
2574   const int end_mask = attrs.get<int64>("end_mask");
2575   // Check that batch dimension is unmodified.
2576   const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0;
2577   const bool stride_is_modified = stride[0] != 1;
2578   // If the batch size is -1 and the end mask is not set, we can only know if
2579   // the batch dimension is unmodified when the batch size is defined. When the
2580   // batch size is undefined, we don't convert to be safe.
2581   const bool batch_size_is_defined = input_dims[0] > 0;
2582   const bool end_is_modified =
2583       !(end_mask & 1) && (!batch_size_is_defined ||
2584                           (batch_size_is_defined && end[0] != input_dims[0]));
2585   if (begin_is_modified || stride_is_modified || end_is_modified) {
2586     return errors::Unimplemented(
2587         "TensorRT does not allow modifications to the batch dimension, at ",
2588         node_def.name());
2589   }
2590   // Standarize begin and end bounds by applying masks, making negative values
2591   // positive, and correcting out of bounds ranges (StridedSlice does this
2592   // silently).
2593   for (int i = 1; i < input_dims.size(); i++) {
2594     // Begin
2595     if ((1 << i) & begin_mask) {
2596       begin[i] = 0;
2597     } else if (begin[i] < 0) {
2598       begin[i] += input_dims[i];
2599     }
2600     begin[i] = std::max(0, std::min(begin[i], input_dims[i]));
2601     // End
2602     if ((1 << i) & end_mask) {
2603       end[i] = input_dims[i];
2604     } else if (end[i] < 0) {
2605       end[i] += input_dims[i];
2606     }
2607     end[i] = std::max(0, std::min(end[i], input_dims[i]));
2608   }
2609   // Negative or zero strides currently not supported.
2610   for (int i = 0; i < input_dims.size(); i++) {
2611     if (stride[i] <= 0) {
2612       return errors::Unimplemented(
2613           "Negative or zero stride values are not supported for StridedSlice, "
2614           "at ",
2615           node_def.name());
2616     }
2617   }
2618   // TRT Slice layer uses (begin, size) instead of (begin, end)
2619   std::vector<int> size(input_dims.size());
2620   for (int i = 0; i < input_dims.size(); i++) {
2621     // Divide by stride (round up)
2622     size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i];
2623   }
2624   return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride);
2625 }
2626 
ConvertConv2D(OpConverterParams * params)2627 Status ConvertConv2D(OpConverterParams* params) {
2628   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false);
2629 }
2630 
ConvertConv2DDepthwise(OpConverterParams * params)2631 Status ConvertConv2DDepthwise(OpConverterParams* params) {
2632   return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false);
2633 }
2634 
ConvertConv2DBackpropInput(OpConverterParams * params)2635 Status ConvertConv2DBackpropInput(OpConverterParams* params) {
2636   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true);
2637 }
2638 
ConvertPool(OpConverterParams * params)2639 Status ConvertPool(OpConverterParams* params) {
2640   const auto& inputs = params->inputs;
2641   const auto& node_def = params->node_def;
2642   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2643   TF_RETURN_IF_ERROR(
2644       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2645   nvinfer1::PoolingType type;
2646   if (node_def.op() == "MaxPool") {
2647     type = nvinfer1::PoolingType::kMAX;
2648   } else if (node_def.op() == "AvgPool") {
2649     type = nvinfer1::PoolingType::kAVERAGE;
2650   } else {
2651     return errors::Unimplemented("Unsupported pooling type: ", node_def.op(),
2652                                  ", at ", node_def.name());
2653   }
2654   TFAttrs attrs(node_def);
2655   const string padding_type = attrs.get<string>("padding");
2656   if ((padding_type != "SAME") && (padding_type != "VALID")) {
2657     return errors::Unimplemented("Unsupported padding type: ", padding_type,
2658                                  ", at ", node_def.name());
2659   }
2660   if (params->validation_only) return Status::OK();
2661 
2662   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
2663   int h_index = 2;
2664   int w_index = 3;
2665   const auto data_format = attrs.get<string>("data_format");
2666   if (data_format == "NHWC") {
2667     h_index = 1;
2668     w_index = 2;
2669     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2670         const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
2671   }
2672 
2673   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
2674   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
2675 
2676   const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
2677   const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
2678 
2679   auto tensor_dim = tensor->getDimensions();
2680   std::vector<std::pair<int, int>> padding;
2681   if (padding_type == "SAME") {
2682     // This is NCHW tensor with no batch dimension.
2683     //  1 -> h
2684     //  2 -> w
2685     padding = CreateSamePadding(
2686         stride, ksize,
2687         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
2688   } else if (padding_type == "VALID") {
2689     padding = {{0, 0}, {0, 0}};
2690   }
2691 
2692   if (padding[0].first != padding[0].second ||
2693       padding[1].first != padding[1].second) {
2694     VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
2695             << padding[1].first << padding[1].second;
2696     auto pad_layer = params->converter->network()->addPadding(
2697         *const_cast<nvinfer1::ITensor*>(tensor),
2698         nvinfer1::DimsHW(padding[0].first, padding[1].first),
2699         nvinfer1::DimsHW(padding[0].second, padding[1].second));
2700     TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
2701     params->converter->MarkQuantizationRangesAsInferrable(
2702         const_cast<nvinfer1::ITensor*>(tensor), pad_layer->getOutput(0));
2703     padding = {{0, 0}, {0, 0}};
2704     tensor = pad_layer->getOutput(0);
2705   }
2706 
2707   nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling(
2708       *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
2709   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2710   // TODO(tmorris): Average pooling may not be entirely safe to infer
2711   // quantization range through (at least forwards - backwards should be fine).
2712   // Max pooling is okay.
2713   params->converter->MarkQuantizationRangesAsInferrable(
2714       const_cast<nvinfer1::ITensor*>(tensor), layer->getOutput(0));
2715 
2716   layer->setStride(stride);
2717   layer->setPadding({padding[0].first, padding[1].first});
2718   layer->setName(node_def.name().c_str());
2719   const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
2720 
2721   if (data_format == "NHWC") {
2722     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2723         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 2, 3, 1},
2724         &output_tensor));
2725   }
2726   params->outputs->push_back(
2727       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
2728   return Status::OK();
2729 }
2730 
2731 // TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf
2732 // improves.
ConvertLeakyRelu(OpConverterParams * params)2733 Status ConvertLeakyRelu(OpConverterParams* params) {
2734   const auto& inputs = params->inputs;
2735   const auto& node_def = params->node_def;
2736   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2737   TF_RETURN_IF_ERROR(
2738       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2739 
2740   TFAttrs attrs(node_def);
2741   const float alpha = attrs.get<float>("alpha");
2742   if (alpha < 0.0f || alpha > 1.0f) {
2743     return errors::Unimplemented(
2744         "Alpha value for LeakyRelu must be between 0 and 1, at ",
2745         node_def.name());
2746   }
2747   if (params->validation_only) return Status::OK();
2748 
2749   // Input Tensor
2750   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
2751   // Create const for alpha.
2752   const nvinfer1::ITensor* const_alpha_tensor = nullptr;
2753   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
2754       params, alpha, tensor->getDimensions(), &const_alpha_tensor));
2755   // alpha * x
2756   nvinfer1::IElementWiseLayer* mul_layer =
2757       params->converter->network()->addElementWise(
2758           *const_cast<nvinfer1::ITensor*>(tensor),
2759           *const_cast<nvinfer1::ITensor*>(const_alpha_tensor),
2760           nvinfer1::ElementWiseOperation::kPROD);
2761   TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name());
2762   // max(x, alpha * x)
2763   nvinfer1::IElementWiseLayer* max_layer =
2764       params->converter->network()->addElementWise(
2765           *const_cast<nvinfer1::ITensor*>(tensor),
2766           *const_cast<nvinfer1::ITensor*>(mul_layer->getOutput(0)),
2767           nvinfer1::ElementWiseOperation::kMAX);
2768   TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name());
2769   nvinfer1::ITensor* output_tensor = max_layer->getOutput(0);
2770   params->converter->MarkQuantizationRangesAsInferrable(
2771       output_tensor, const_cast<nvinfer1::ITensor*>(mul_layer->getOutput(0)));
2772 
2773   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2774   return Status::OK();
2775 }
2776 
ConvertActivation(OpConverterParams * params)2777 Status ConvertActivation(OpConverterParams* params) {
2778   const auto& inputs = params->inputs;
2779   const auto& node_def = params->node_def;
2780   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2781   TF_RETURN_IF_ERROR(
2782       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2783   static const std::unordered_map<string, nvinfer1::ActivationType> ops{
2784       {"Relu", nvinfer1::ActivationType::kRELU},
2785       {"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
2786       {"Tanh", nvinfer1::ActivationType::kTANH},
2787   };
2788   auto op_pair = ops.find(node_def.op());
2789   if (op_pair == ops.end()) {
2790     return errors::Unimplemented("Activation op: ", node_def.op(),
2791                                  " not supported at: ", node_def.name());
2792   }
2793   if (params->validation_only) return Status::OK();
2794 
2795   // Start conversion.
2796   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
2797   nvinfer1::IActivationLayer* layer =
2798       params->converter->network()->addActivation(
2799           *const_cast<nvinfer1::ITensor*>(tensor), op_pair->second);
2800   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2801   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
2802   // Set quantization range for output of Sigmoid, Tanh.
2803   if (node_def.op() == "Sigmoid") {
2804     params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f);
2805   } else if (node_def.op() == "Tanh") {
2806     params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f);
2807   }
2808   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2809   return Status::OK();
2810 }
2811 
ConvertQuantize(OpConverterParams * params)2812 Status ConvertQuantize(OpConverterParams* params) {
2813   const auto& inputs = params->inputs;
2814   const auto& node_def = params->node_def;
2815   if (node_def.op() == "FakeQuantWithMinMaxArgs") {
2816     TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2817   } else if (node_def.op() == "FakeQuantWithMinMaxVars") {
2818     TF_RETURN_IF_ERROR(CheckInputsWeights(
2819         *params, {{"input", false}, {"min", true}, {"max", true}}));
2820   } else if (node_def.op() == "QuantizeAndDequantizeV2") {
2821     TF_RETURN_IF_ERROR(CheckInputsWeights(
2822         *params, {{"input", false}, {"input_min", true}, {"input_max", true}}));
2823   } else if (node_def.op() == "QuantizeAndDequantizeV3") {
2824     TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false},
2825                                                     {"input_min", true},
2826                                                     {"input_max", true},
2827                                                     {"num_bits", true}}));
2828   }
2829   float min_range = 0.0f;
2830   float max_range = 0.0f;
2831   if (node_def.op() == "FakeQuantWithMinMaxArgs") {
2832     // Get ranges via node attributes.
2833     TFAttrs attrs(node_def);
2834     if (attrs.count("min") == 0 || attrs.count("max") == 0) {
2835       return errors::InvalidArgument("Min or max attribute not found for ",
2836                                      node_def.op(), " at ", node_def.name());
2837     }
2838     min_range = attrs.get<float>("min");
2839     max_range = attrs.get<float>("max");
2840   } else if (node_def.op() == "FakeQuantWithMinMaxVars" ||
2841              node_def.op() == "QuantizeAndDequantizeV2" ||
2842              node_def.op() == "QuantizeAndDequantizeV3") {
2843     // Get ranges via inputs.
2844     auto get_weights_value = [&inputs](int index) {
2845       auto raw_weights = static_cast<float*>(
2846           const_cast<void*>(inputs.at(index).weights().GetValues()));
2847       return raw_weights[0];
2848     };
2849     min_range = get_weights_value(1);
2850     max_range = get_weights_value(2);
2851   } else {
2852     return errors::InvalidArgument("Unknown quantization op ", node_def.op(),
2853                                    ", at ", node_def.name());
2854   }
2855   if (params->validation_only) return Status::OK();
2856 
2857   // Store ranges for tensor
2858   params->converter->ProvideQuantizationRange(
2859       const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()), min_range,
2860       max_range);
2861   // Sometimes, TRT may not quantize a tensor, either because it chooses to
2862   // execute a higher precision kernel or because of op fusion. In these cases,
2863   // accuracy will suffer if the model was trained to expect quantization at
2864   // that tensor. We should consider adding a clip(tensor, min_range, max_range)
2865   // operation here to ensure that any arbitrarily placed quantize node will
2866   // execute as expected. However, this will negatively affect performance. If
2867   // users train their models in a way which models inference as close as
2868   // possible (i.e. not quantizing in place where fusion will occur), then there
2869   // is no problem with the current implementation.
2870   params->outputs->push_back(inputs.at(0));
2871   return Status::OK();
2872 }
2873 
2874 // TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves.
ConvertRelu6(OpConverterParams * params)2875 Status ConvertRelu6(OpConverterParams* params) {
2876   const auto& inputs = params->inputs;
2877   const auto& node_def = params->node_def;
2878   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2879   TF_RETURN_IF_ERROR(
2880       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2881   if (params->validation_only) return Status::OK();
2882   // ***************************************************************************
2883   // TensorRT does not implement Relu6 natively. This function converts Relu6 op
2884   // to available TensorRT ops: Relu6(x) = min(Relu(x), 6)
2885   // ***************************************************************************
2886 
2887   // Input Tensor
2888   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
2889 
2890   // Relu operation i.e. Relu(x) = max(0, x)
2891   nvinfer1::IActivationLayer* relu_layer =
2892       params->converter->network()->addActivation(
2893           *const_cast<nvinfer1::ITensor*>(tensor),
2894           nvinfer1::ActivationType::kRELU);
2895   TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name());
2896 
2897   // Large range of relu is problematic during quantization in INT8 precision
2898   // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization.
2899   // TRT only uses dynamic ranges in INT8 precision mode,
2900   // and this does not affect the FP32 path.
2901   params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f,
2902                                               6.0f);
2903 
2904   // Create a constant layer to store the floating point weight i.e. 6.0f
2905   const nvinfer1::ITensor* const6_tensor = nullptr;
2906   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
2907       params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor));
2908 
2909   // ElementWise Min Operation
2910   // Min op is a nop for INT8 execution path, as the input tensor
2911   // to this layer will only have values in range [0.f, 6.0f].
2912   nvinfer1::IElementWiseLayer* relu6_layer =
2913       params->converter->network()->addElementWise(
2914           *const_cast<nvinfer1::ITensor*>(relu_layer->getOutput(0)),
2915           *const_cast<nvinfer1::ITensor*>(const6_tensor),
2916           nvinfer1::ElementWiseOperation::kMIN);
2917   TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name());
2918   nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0);
2919   params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f);
2920 
2921   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2922   return Status::OK();
2923 }
2924 
ConvertBiasAdd(OpConverterParams * params)2925 Status ConvertBiasAdd(OpConverterParams* params) {
2926   const auto& inputs = params->inputs;
2927   const auto& node_def = params->node_def;
2928   TF_RETURN_IF_ERROR(
2929       CheckInputsWeights(*params, {{"value", false}, {"bias", true}}));
2930   TF_RETURN_IF_ERROR(
2931       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2932   if (params->validation_only) return Status::OK();
2933 
2934   nvinfer1::ITensor* tensor =
2935       const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
2936   const nvinfer1::Dims original_dims = tensor->getDimensions();
2937   TFAttrs attrs(node_def);
2938   const string data_format = attrs.get<string>("data_format");
2939   const int channel_index =
2940       (data_format == "NHWC" ? original_dims.nbDims - 1 : 0);
2941 
2942   nvinfer1::Permutation permutation;
2943   if (channel_index != 0) {
2944     // Permute the dimensions so that the channel dimension is the first
2945     // dimension.
2946     for (int i = 0; i < original_dims.nbDims; ++i) {
2947       permutation.order[i] = i;
2948     }
2949     permutation.order[0] = channel_index;
2950     permutation.order[channel_index] = 0;
2951     VLOG(1) << "ConvertBiasAdd permutation: "
2952             << DebugString(permutation, original_dims.nbDims);
2953   }
2954 
2955   // TensorRT addScale requires input to be of rank 3, we need to apply
2956   // transpose as well as reshape.
2957   // TODO(laigd): this doesn't match what the TRT doc says, fix the doc?
2958   if (channel_index != 0 || original_dims.nbDims != 3) {
2959     nvinfer1::IShuffleLayer* shuffle_layer =
2960         params->converter->network()->addShuffle(*tensor);
2961     TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
2962     params->converter->MarkQuantizationRangesAsInferrable(
2963         tensor, shuffle_layer->getOutput(0));
2964 
2965     // NOTE(laigd): for some reason we need to apply the reshape
2966     // unconditionally. The default shape has nbDims==-1 and it seems the
2967     // behavior is undefined in some cases.
2968     nvinfer1::Dims reshape_dims;
2969     reshape_dims.nbDims = 3;
2970     // 0 means copying from input; -1 means inferring from the rest.
2971     reshape_dims.d[0] = 0;
2972     reshape_dims.d[1] = original_dims.nbDims >= 2 ? 0 : 1;
2973     reshape_dims.d[2] = original_dims.nbDims >= 3 ? -1 : 1;
2974     shuffle_layer->setReshapeDimensions(reshape_dims);
2975 
2976     if (channel_index != 0) {
2977       shuffle_layer->setFirstTranspose(permutation);
2978     }
2979     tensor = shuffle_layer->getOutput(0);
2980   }
2981 
2982   TRT_ShapedWeights weights = inputs.at(1).weights();
2983   if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
2984     weights = ConvertFP32ToFP16(params->weight_store, weights);
2985   }
2986   nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
2987   if (weights.shape_.d[0] == 1) {
2988     mode = nvinfer1::ScaleMode::kUNIFORM;
2989   }
2990 
2991   TRT_ShapedWeights empty_weights(weights.type_);
2992   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
2993       *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
2994       empty_weights.GetTrtWeights());
2995   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2996 
2997   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
2998 
2999   // Restore transpose & reshape.
3000   if (channel_index != 0 || original_dims.nbDims != 3) {
3001     nvinfer1::IShuffleLayer* shuffle_layer =
3002         params->converter->network()->addShuffle(*output_tensor);
3003     TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
3004     // NOTE: for same reason as mentioned above we need to apply the reshape
3005     // unconditionally.
3006     nvinfer1::Dims reshape_dims = original_dims;
3007     if (channel_index != 0) {
3008       // NOTE: according to NVIDIA dimension types are deprecated, so we don't
3009       // need to copy them back.
3010       reshape_dims.d[channel_index] = original_dims.d[0];
3011       reshape_dims.d[0] = original_dims.d[channel_index];
3012     }
3013     shuffle_layer->setReshapeDimensions(reshape_dims);
3014 
3015     if (channel_index != 0) {
3016       shuffle_layer->setSecondTranspose(permutation);
3017     }
3018     params->converter->MarkQuantizationRangesAsInferrable(
3019         output_tensor, shuffle_layer->getOutput(0));
3020     output_tensor = shuffle_layer->getOutput(0);
3021   }
3022 
3023   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3024   return Status::OK();
3025 }
3026 
GetTensorDimsWithProtoShape(const Tensor & tensor,nvinfer1::Dims * dims)3027 void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
3028   if (tensor.dims() > 0) {
3029     *dims = GetTrtDimsForTensor(tensor);
3030   } else {
3031     dims->nbDims = 1;
3032     // No dimension provided. Flatten it.
3033     dims->d[0] = tensor.NumElements();
3034     dims->type[0] = nvinfer1::DimensionType::kSPATIAL;
3035     for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) {
3036       dims->d[i] = 0;
3037     }
3038   }
3039 }
3040 
TfTensorToTrtWeights(const Tensor & tensor,TrtWeightStore * weight_store,TRT_ShapedWeights * weights)3041 Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
3042                             TRT_ShapedWeights* weights) {
3043   const DataType dtype = tensor.dtype();
3044 
3045   // We always convert the integer constants to INT32, since TRT INT8 is for
3046   // quantized inference.
3047   //
3048   // TODO(aaroey): FP16 will remain in half format and is not converted to
3049   // FP32, but the converter currently uses all float weights as FP32. Fix
3050   // this.
3051   const DataType converted_dtype =
3052       (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32
3053                                                                   : dtype);
3054 
3055   // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
3056   nvinfer1::DataType trt_dtype;
3057   TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype));
3058 
3059   if (tensor.NumElements() == 0) {
3060     // Return empty weights having converted dtype.
3061     *weights = TRT_ShapedWeights(converted_dtype);
3062     return Status::OK();
3063   }
3064 
3065   nvinfer1::Dims weight_dims;
3066   GetTensorDimsWithProtoShape(tensor, &weight_dims);
3067   *weights = weight_store->GetTempWeights(converted_dtype, weight_dims);
3068 
3069   // Copy the tensor directly if the tensor does not require cast to the
3070   // supported type.
3071   if (converted_dtype == dtype) {
3072     char* dst = static_cast<char*>(const_cast<void*>(weights->GetValues()));
3073     memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes());
3074     return Status::OK();
3075   }
3076 
3077   // Copy tensor elements after casting them to the converted DataType.
3078   int32* dst = static_cast<int32*>(const_cast<void*>(weights->GetValues()));
3079   if (dtype == DT_INT16) {
3080     const int16* src = tensor.flat<int16>().data();
3081     std::copy(src, src + tensor.NumElements(), dst);
3082   } else if (dtype == DT_INT8) {
3083     const int8* src = tensor.flat<int8>().data();
3084     std::copy(src, src + tensor.NumElements(), dst);
3085   } else {
3086     // dtype can only be DT_UINT8 at this point.
3087     TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8);
3088     const uint8* src = tensor.flat<uint8>().data();
3089     std::copy(src, src + tensor.NumElements(), dst);
3090   }
3091   return Status::OK();
3092 }
3093 
3094 // Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter, it
3095 // always ignores the params->validation_only parameter but adds the converted
3096 // weights to params->outputs. We did this since TrtNodeValidator needs the
3097 // weights as input to other nodes, and use it to determine whether those nodes
3098 // are supported by TRT.
ConvertConst(OpConverterParams * params)3099 Status ConvertConst(OpConverterParams* params) {
3100   const auto& inputs = params->inputs;
3101   const auto& node_def = params->node_def;
3102   if (!inputs.empty()) {
3103     return errors::InvalidArgument(
3104         "Constant node is expected to have empty input list: ",
3105         node_def.name());
3106   }
3107 
3108   // Create shaped weights as output
3109   const auto& tensor_proto = node_def.attr().at("value").tensor();
3110   Tensor tensor;
3111   if (!tensor.FromProto(tensor_proto)) {
3112     return errors::Internal("Cannot parse weight tensor proto: ",
3113                             node_def.name());
3114   }
3115 
3116   TFAttrs attrs(node_def);
3117   const DataType dtype = attrs.get<DataType>("dtype");
3118   if (dtype != tensor.dtype()) {
3119     return errors::InvalidArgument("DataType mismatch between attr (",
3120                                    DataTypeString(dtype), ") and tensor (",
3121                                    DataTypeString(tensor.dtype()), ")");
3122   }
3123 
3124   TRT_ShapedWeights weights;
3125   TF_RETURN_IF_ERROR(
3126       TfTensorToTrtWeights(tensor, params->weight_store, &weights));
3127 
3128   if (params->outputs != nullptr) {
3129     params->outputs->push_back(TRT_TensorOrWeights(weights));
3130   }
3131   return Status::OK();
3132 }
3133 
ConvertIdentity(OpConverterParams * params)3134 Status ConvertIdentity(OpConverterParams* params) {
3135   // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT
3136   // 5.0, however once we know that it does it would be nice to use that
3137   // instead.
3138   if (params->validation_only) return Status::OK();
3139   params->outputs->push_back(params->inputs.at(0));
3140   return Status::OK();
3141 }
3142 
ConvertBinary(OpConverterParams * params)3143 Status ConvertBinary(OpConverterParams* params) {
3144   const auto& inputs = params->inputs;
3145   const auto& node_def = params->node_def;
3146   // TODO(tmorris): Enable once false is updated to mean either tensor or weight
3147   // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
3148   // false}}));
3149   if (inputs.size() != 2) {
3150     return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
3151                                    " inputs but expected 2, at ",
3152                                    node_def.name());
3153   }
3154   TF_RETURN_IF_ERROR(
3155       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3156 
3157   // Constant folding should have been done by TensorFlow
3158   if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
3159     return errors::Unimplemented(
3160         "Constant folding is falled back to TensorFlow, binary op received "
3161         "both input as constant at: ",
3162         node_def.name());
3163   }
3164 
3165   // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with
3166   // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For
3167   // now, the performance will be slightly better with IScaleLayer because it
3168   // can be fused in more situations. However, most of the benefits of
3169   // IScaleLayer are when the layer performs both a shift and a scale, which we
3170   // don't do except for convolutions.
3171   //
3172   // Try to convert into Scale layer first (for better performance).
3173   // Since scale layer supports restricted broadcast policy and op types, we
3174   // allow failure and try to handle it through Elementwise op
3175   // (BinaryTensorOpTensor).
3176   Status status = Status::OK();
3177   if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
3178     status = BinaryTensorOpWeight(params, inputs.at(0).tensor(),
3179                                   inputs.at(1).weights(), false);
3180   } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
3181     status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
3182                                   inputs.at(0).weights(), true);
3183   }
3184   // If both input are tensors, or one of them is weights but the conversion
3185   // above failed, try the conversion using BinaryTensorOpTensor.
3186   if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
3187     if (!status.ok()) VLOG(2) << status;
3188     status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
3189   }
3190   return status;
3191 }
3192 
ConvertRsqrt(OpConverterParams * params)3193 Status ConvertRsqrt(OpConverterParams* params) {
3194   const auto& inputs = params->inputs;
3195   const auto& node_def = params->node_def;
3196   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
3197   TF_RETURN_IF_ERROR(
3198       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3199   if (params->validation_only) return Status::OK();
3200 
3201   // TODO(tmorris): params->converter is null during validation. Allow
3202   // precision_mode and use_calibration to be accessed during validation and
3203   // include this check in validation.
3204   // We will need a quantization range for intermediate tensor if not using
3205   // calibration.
3206   //
3207   //   x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x)
3208   //                     ^
3209   //               need range here
3210   if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
3211       !params->converter->use_calibration()) {
3212     return errors::Unimplemented(
3213         "Intermediate quantization range cannot be determined without"
3214         " calibration for Rsqrt, consider replacing with "
3215         "Sqrt -> FakeQuant -> Reciprocal ops, at ",
3216         node_def.name());
3217   }
3218   // Start conversion.
3219   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3220   // Sqrt
3221   nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary(
3222       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::UnaryOperation::kSQRT);
3223   TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name());
3224   // Recip
3225   nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary(
3226       *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP);
3227   TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name());
3228   params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0)));
3229   return Status::OK();
3230 }
3231 
3232 const std::unordered_map<string, nvinfer1::UnaryOperation>*
UnaryOperationMap()3233 UnaryOperationMap() {
3234   static auto* const m =
3235       new std::unordered_map<string, nvinfer1::UnaryOperation>({
3236         {"Neg", nvinfer1::UnaryOperation::kNEG},
3237             {"Exp", nvinfer1::UnaryOperation::kEXP},
3238             {"Log", nvinfer1::UnaryOperation::kLOG},
3239             {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
3240             {"Abs", nvinfer1::UnaryOperation::kABS},
3241             {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
3242 #if IS_TRT_VERSION_GE(5, 1, 0)
3243             {"Sin", nvinfer1::UnaryOperation::kSIN},
3244             {"Cos", nvinfer1::UnaryOperation::kCOS},
3245             {"Tan", nvinfer1::UnaryOperation::kTAN},
3246             {"Sinh", nvinfer1::UnaryOperation::kSINH},
3247             {"Cosh", nvinfer1::UnaryOperation::kCOSH},
3248             {"Asin", nvinfer1::UnaryOperation::kASIN},
3249             {"Acos", nvinfer1::UnaryOperation::kACOS},
3250             {"Atan", nvinfer1::UnaryOperation::kATAN},
3251             {"Asinh", nvinfer1::UnaryOperation::kASINH},
3252             {"Acosh", nvinfer1::UnaryOperation::kACOSH},
3253             {"Atanh", nvinfer1::UnaryOperation::kATANH},
3254             {"Ceil", nvinfer1::UnaryOperation::kCEIL},
3255             {"Floor", nvinfer1::UnaryOperation::kFLOOR},
3256 #endif
3257       });
3258   return m;
3259 }
3260 
ConvertUnary(OpConverterParams * params)3261 Status ConvertUnary(OpConverterParams* params) {
3262   const auto& inputs = params->inputs;
3263   const auto& node_def = params->node_def;
3264   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
3265   TF_RETURN_IF_ERROR(
3266       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3267   auto op_pair = UnaryOperationMap()->find(node_def.op());
3268   if (op_pair == UnaryOperationMap()->end()) {
3269     return errors::Unimplemented("Unary op: ", node_def.op(),
3270                                  " not supported at: ", node_def.name());
3271   }
3272   if (params->validation_only) return Status::OK();
3273 
3274   // Start conversion.
3275   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3276   nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
3277       *const_cast<nvinfer1::ITensor*>(tensor), op_pair->second);
3278   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3279   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3280 
3281   // Set quantization ranges.
3282   if (node_def.op() == "Sin" || node_def.op() == "Cos") {
3283     params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f);
3284   } else if (node_def.op() == "Asin" || node_def.op() == "Atan") {
3285     params->converter->ProvideQuantizationRange(output_tensor, -M_PI_2, M_PI_2);
3286   } else if (node_def.op() == "Acos") {
3287     params->converter->ProvideQuantizationRange(output_tensor, 0.0f, M_PI);
3288   } else if (node_def.op() == "Neg" || node_def.op() == "Abs") {
3289     // Neg and Abs will have same range as input since TRT uses symmetric
3290     // quantization.
3291     // TODO(tmorris): Should we infer ranges for Ceil and Floor as well?
3292     params->converter->MarkQuantizationRangesAsInferrable(
3293         const_cast<nvinfer1::ITensor*>(tensor), output_tensor);
3294   }
3295   params->outputs->push_back(
3296       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
3297   return Status::OK();
3298 }
3299 
ConvertSquare(OpConverterParams * params)3300 Status ConvertSquare(OpConverterParams* params) {
3301   const auto& inputs = params->inputs;
3302   const auto& node_def = params->node_def;
3303   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
3304   TF_RETURN_IF_ERROR(
3305       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3306   if (params->validation_only) return Status::OK();
3307 
3308   // Constant 2 with same rank as input
3309   const nvinfer1::ITensor* const2_tensor = nullptr;
3310   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
3311       params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor));
3312 
3313   // ElementWise Pow Operation
3314   nvinfer1::IElementWiseLayer* layer =
3315       params->converter->network()->addElementWise(
3316           *const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
3317           *const_cast<nvinfer1::ITensor*>(const2_tensor),
3318           nvinfer1::ElementWiseOperation::kPOW);
3319   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3320   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3321 
3322   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3323   return Status::OK();
3324 }
3325 
ConvertReduce(OpConverterParams * params)3326 Status ConvertReduce(OpConverterParams* params) {
3327   const auto& inputs = params->inputs;
3328   const auto& node_def = params->node_def;
3329   TF_RETURN_IF_ERROR(
3330       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
3331   TF_RETURN_IF_ERROR(
3332       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3333 
3334   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3335   TRT_ShapedWeights index_list = inputs.at(1).weights();
3336 
3337   TFAttrs attrs(node_def);
3338   // Only expect to handle INT32 as attributes for now
3339   if (attrs.get<DataType>("Tidx") != DataType::DT_INT32) {
3340     return errors::Unimplemented("Tidx supports only DT_INT32");
3341   }
3342 
3343   int axes = 0;
3344   if (index_list.count() == 0) {
3345     return errors::InvalidArgument(
3346         "TRT cannot support reduce on all (batch) dimensions, at",
3347         node_def.name());
3348   } else {
3349     auto index_list_data =
3350         static_cast<int*>(const_cast<void*>(index_list.GetValues()));
3351     for (int i = 0; i < index_list.count(); i++) {
3352       int axis = index_list_data[i];
3353       if (axis < 0) axis += tensor->getDimensions().nbDims + 1;
3354       if (axis == 0) {
3355         return errors::InvalidArgument(
3356             "TRT cannot reduce at batch dimension, at", node_def.name());
3357       }
3358       axes |= (1 << (axis - 1));
3359     }
3360   }
3361 
3362   nvinfer1::ReduceOperation reduce_operation;
3363   if (node_def.op() == "Sum") {
3364     reduce_operation = nvinfer1::ReduceOperation::kSUM;
3365   } else if (node_def.op() == "Prod") {
3366     reduce_operation = nvinfer1::ReduceOperation::kPROD;
3367   } else if (node_def.op() == "Max") {
3368     reduce_operation = nvinfer1::ReduceOperation::kMAX;
3369   } else if (node_def.op() == "Min") {
3370     reduce_operation = nvinfer1::ReduceOperation::kMIN;
3371   } else if (node_def.op() == "Mean") {
3372     reduce_operation = nvinfer1::ReduceOperation::kAVG;
3373   } else {
3374     return errors::Unimplemented("Op not supported ", node_def.op(), ", at ",
3375                                  node_def.name());
3376   }
3377   if (params->validation_only) return Status::OK();
3378 
3379   const auto keep_dims = attrs.get<bool>("keep_dims");
3380   nvinfer1::ILayer* layer = params->converter->network()->addReduce(
3381       *const_cast<nvinfer1::ITensor*>(tensor), reduce_operation, axes,
3382       keep_dims);
3383   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3384 
3385   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
3386   return Status::OK();
3387 }
3388 
ConvertPad(OpConverterParams * params)3389 Status ConvertPad(OpConverterParams* params) {
3390   const auto& inputs = params->inputs;
3391   const auto& node_def = params->node_def;
3392   TF_RETURN_IF_ERROR(
3393       CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
3394   TF_RETURN_IF_ERROR(
3395       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3396 
3397   // Implement tensor binaryOp weight [channel wise] for now;
3398   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3399   const auto dims = tensor->getDimensions();
3400   // Restore implicit batch dimension
3401   const int nb_dims = dims.nbDims + 1;
3402 
3403   TRT_ShapedWeights pads = inputs.at(1).weights();
3404 
3405   TFAttrs attrs(node_def);
3406   // Padding type here is done through TF type
3407   //   so I can leverage their EnumToDataType for my cast
3408   auto padding_type = attrs.get<DataType>("Tpaddings");
3409   // TODO(jie): handle data type conversion for TRT?
3410 
3411   if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) {
3412     return errors::InvalidArgument(
3413         "Pad only supports explicit padding on 4 dimensional tensor, at ",
3414         node_def.name());
3415   }
3416 
3417   // Only expect to handle INT32 as attributes for now
3418   if (padding_type != DataType::DT_INT32) {
3419     return errors::Unimplemented("Tpaddings supports only DT_INT32");
3420   }
3421   auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
3422 
3423   std::vector<int32_t> pad_index;
3424   for (int i = 0; i < nb_dims; i++) {
3425     if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) {
3426       pad_index.push_back(i);
3427     }
3428   }
3429 
3430   // No padding at all, we should exit
3431   if (pad_index.empty()) {
3432     params->outputs->push_back(inputs.at(0));
3433     return Status::OK();
3434   }
3435 
3436   // Only supports padding on less than 2 axis GIE-2579
3437   if (pad_index.size() > 2) {
3438     return errors::InvalidArgument(
3439         "Padding layer does not support padding on > 2");
3440   }
3441 
3442   // Padding on batch dimension is not supported
3443   if (pad_index[0] == 0) {
3444     return errors::InvalidArgument(
3445         "Padding layer does not support padding on batch dimension");
3446   }
3447 
3448   // Not doing the legit thing here. ignoring padding on dim 1 and 3;
3449   // TODO(jie): implement pad as uff parser
3450   if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) {
3451     return errors::Unimplemented(
3452         "Padding layer does not support padding on dimension 1 and 3 yet");
3453   }
3454   if (params->validation_only) return Status::OK();
3455 
3456   bool legit_pad = true;
3457   nvinfer1::DimsHW pre_padding(0, 0);
3458   nvinfer1::DimsHW post_padding(0, 0);
3459 
3460   std::vector<int32_t> permuted_pad_index(pad_index);
3461   if (pad_index[0] == 1) {
3462     legit_pad = false;
3463     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3464         const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 2, 1}, &tensor));
3465     permuted_pad_index[0] = 3;
3466   }
3467 
3468   for (size_t i = 0; i < pad_index.size(); i++) {
3469     int index = pad_index[i];
3470     if (permuted_pad_index[i] == 2) {
3471       pre_padding.h() = pad_data[index * 2];
3472       post_padding.h() = pad_data[index * 2 + 1];
3473     } else if (permuted_pad_index[i] == 3) {
3474       pre_padding.w() = pad_data[index * 2];
3475       post_padding.w() = pad_data[index * 2 + 1];
3476     }
3477   }
3478 
3479   nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
3480       *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
3481   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3482   const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3483 
3484   if (!legit_pad) {
3485     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3486         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1},
3487         &output_tensor));
3488   }
3489 
3490   params->outputs->push_back(
3491       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
3492   return Status::OK();
3493 }
3494 
ConvertConcat(OpConverterParams * params)3495 Status ConvertConcat(OpConverterParams* params) {
3496   const auto& inputs = params->inputs;
3497   const auto& node_def = params->node_def;
3498   // TODO(tmorris): There is a bug with Concat and INT32 in TRT - it is supposed
3499   // to be supported.
3500   TF_RETURN_IF_ERROR(
3501       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3502   // not including the last input (axis) here
3503   int input_size = static_cast<int>(inputs.size()) - 1;
3504 
3505   if (!inputs.at(0).is_tensor()) {
3506     return errors::InvalidArgument(
3507         "Concat in TRT support only Tensor input, at ", node_def.name());
3508   }
3509 
3510   // We are retrieving the axis
3511   TRT_ShapedWeights axis = inputs.at(input_size).weights();
3512 
3513   TFAttrs attrs(node_def);
3514   auto index_type = attrs.get<DataType>("Tidx");
3515 
3516   // TODO(jie): handle data type
3517   // Only expect to handle INT32 as index attributes for now
3518   if (index_type != DataType::DT_INT32)
3519     return errors::Unimplemented("Tidx supports only DT_INT32, at ",
3520                                  node_def.name());
3521 
3522   int index = *(static_cast<int*>(const_cast<void*>(axis.GetValues())));
3523 
3524   // TODO(jie): early termination with no-op (attr_size==1)
3525 
3526   auto dim = inputs.at(0).tensor()->getDimensions();
3527   // dimension check
3528   if (index > dim.nbDims + 1) {
3529     return errors::InvalidArgument(
3530         "Concatenate on axis out of dimension range, at ", node_def.name());
3531   }
3532   if (index == 0) {
3533     return errors::InvalidArgument(
3534         "Concatenate on batch dimension not supported, at ", node_def.name());
3535   }
3536   if (index < 0) {
3537     index = dim.nbDims + index + 1;
3538   }
3539 
3540   std::vector<nvinfer1::ITensor const*> inputs_vec;
3541   // Shap chack (all input tensor should have same shape)
3542   // starting from 0 since we are probably also doing transpose here;
3543   for (int i = 0; i < input_size; i++) {
3544     auto tensor_i = inputs.at(i).tensor();
3545     auto dim_i = tensor_i->getDimensions();
3546     if (dim_i.nbDims != dim.nbDims) {
3547       return errors::InvalidArgument(
3548           "Concatenate receives inputs with inconsistent dimensions, at ",
3549           node_def.name());
3550     }
3551     for (int j = 0; j < dim.nbDims; j++) {
3552       // check dimension consistency on non-concatenate axis
3553       if (j != index - 1 && dim_i.d[j] != dim.d[j]) {
3554         return errors::InvalidArgument(
3555             "Concatenate receives inputs with inconsistent shape, at",
3556             node_def.name());
3557       }
3558     }
3559 
3560     inputs_vec.push_back(tensor_i);
3561   }
3562   if (params->validation_only) return Status::OK();
3563 
3564   // nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
3565   nvinfer1::IConcatenationLayer* layer =
3566       params->converter->network()->addConcatenation(
3567           const_cast<nvinfer1::ITensor* const*>(inputs_vec.data()),
3568           inputs_vec.size());
3569   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3570   layer->setAxis(index - 1);
3571   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3572   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3573   return Status::OK();
3574 }
3575 
ConvertFusedBatchNorm(OpConverterParams * params)3576 Status ConvertFusedBatchNorm(OpConverterParams* params) {
3577   const auto& inputs = params->inputs;
3578   const auto& node_def = params->node_def;
3579   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false},
3580                                                   {"scale", true},
3581                                                   {"offset", true},
3582                                                   {"mean", true},
3583                                                   {"variance", true}}));
3584   TF_RETURN_IF_ERROR(
3585       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3586   TFAttrs attrs(node_def);
3587   float epsilon = attrs.get<float>("epsilon");
3588   auto data_format = attrs.get<string>("data_format");
3589   if (data_format != "NCHW") {
3590     return errors::Unimplemented(
3591         node_def.op(), " only supports data_format=NCHW, at ", node_def.name());
3592   }
3593   bool is_training = attrs.get<bool>("is_training");
3594   if (is_training) {
3595     // Trying to use batchnorm in training mode is a very common problem.
3596     // Because the error message will only be printed in VLOG(1) by the
3597     // segmenter, we issue a special warning so that users will actually see it.
3598     LOG(WARNING) << node_def.op() << " only supports is_training=false. If you "
3599                  << "are using Keras, please call "
3600                  << "keras.backend.set_learning_phase(0) before constructing "
3601                  << "your model. At " << node_def.name();
3602     return errors::Unimplemented(node_def.op(),
3603                                  " only supports is_training=false, at ",
3604                                  node_def.name());
3605   }
3606   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
3607 
3608   //  Check parameter types
3609   auto parameter_type = inputs.at(1).weights().type_;
3610   if ((parameter_type != DataType::DT_FLOAT) &&
3611       (parameter_type != DataType::DT_HALF)) {
3612     return errors::Unimplemented(
3613         "only float32 or float16 weight data type is supported, for node " +
3614         node_def.name() + " got " + DataTypeString(parameter_type));
3615   }
3616   for (int i = 1; i < 5; i++) {
3617     if (inputs.at(i).weights().type_ != parameter_type) {
3618       return errors::Unimplemented(
3619           "Inconsistent parameter type for batchnorm is not supported, at: " +
3620           node_def.name());
3621     }
3622   }
3623 
3624   TRT_ShapedWeights dummy_power_weights(parameter_type);
3625   size_t nweight = 0;
3626   for (int i = 1; i < 5; i++) {
3627     nweight = std::max<size_t>(nweight, inputs.at(i).weights().count());
3628   }
3629   TRT_ShapedWeights* ptr_shape_weights = nullptr;
3630   for (int i = 1; i < 5; i++) {
3631     if (inputs.at(i).weights().count() == nweight) {
3632       ptr_shape_weights =
3633           const_cast<TRT_ShapedWeights*>(&(inputs.at(i).weights()));
3634     } else if (inputs.at(i).weights().count() != 1) {
3635       return errors::InvalidArgument(
3636           "Inconsistent batchnorm parameter count, at: " + node_def.name());
3637     }
3638   }
3639   if (params->validation_only) return Status::OK();
3640 
3641   //  We could technically have two weights with different shape.
3642   //  that requires two addScale op, arguably less performant
3643   TRT_ShapedWeights combined_scale_weights =
3644       params->weight_store->GetTempWeights(*ptr_shape_weights);
3645   TRT_ShapedWeights combined_offset_weights =
3646       params->weight_store->GetTempWeights(*ptr_shape_weights);
3647 
3648   const Eigen::half* cast_vals_array[4];
3649   const float* vals_array[4];
3650   for (int j = 0; j < 4; j++) {
3651     cast_vals_array[j] =
3652         static_cast<Eigen::half const*>(inputs.at(j + 1).weights().GetValues());
3653     vals_array[j] =
3654         static_cast<float const*>(inputs.at(j + 1).weights().GetValues());
3655   }
3656   Eigen::half* cast_combined_scale_vals = const_cast<Eigen::half*>(
3657       static_cast<Eigen::half const*>(combined_scale_weights.GetValues()));
3658   Eigen::half* cast_combined_offset_vals = const_cast<Eigen::half*>(
3659       static_cast<Eigen::half const*>(combined_offset_weights.GetValues()));
3660   float* combined_scale_vals = const_cast<float*>(
3661       static_cast<float const*>(combined_scale_weights.GetValues()));
3662   float* combined_offset_vals = const_cast<float*>(
3663       static_cast<float const*>(combined_offset_weights.GetValues()));
3664 
3665   for (size_t i = 0; i < nweight; ++i) {
3666     float batchnorm_data[4];
3667     for (int j = 0; j < 4; j++) {
3668       if (inputs.at(j + 1).weights().count() != 1) {
3669         if (parameter_type == DT_FLOAT) {
3670           batchnorm_data[j] = vals_array[j][i];
3671         } else if (parameter_type == DT_HALF) {
3672           batchnorm_data[j] =
3673               Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
3674         }
3675       } else {
3676         if (parameter_type == DT_FLOAT) {
3677           batchnorm_data[j] = vals_array[j][0];
3678         } else if (parameter_type == DT_HALF) {
3679           batchnorm_data[j] =
3680               Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
3681         }
3682       }
3683     }
3684     float scale = batchnorm_data[0];
3685     float offset = batchnorm_data[1];
3686     float mean = batchnorm_data[2];
3687     float variance = batchnorm_data[3];
3688     float combined_scale_val = scale / sqrtf(variance + epsilon);
3689     float combined_offset_val = offset - mean * combined_scale_val;
3690     if (parameter_type == DT_FLOAT) {
3691       combined_scale_vals[i] = combined_scale_val;
3692       combined_offset_vals[i] = combined_offset_val;
3693     } else if (parameter_type == DT_HALF) {
3694       cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
3695       cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
3696     }
3697   }
3698 
3699   nvinfer1::ScaleMode mode = nweight == 1 ? nvinfer1::ScaleMode::kUNIFORM
3700                                           : nvinfer1::ScaleMode::kCHANNEL;
3701   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
3702       *const_cast<nvinfer1::ITensor*>(tensor), mode,
3703       combined_offset_weights.GetTrtWeights(),
3704       combined_scale_weights.GetTrtWeights(),
3705       dummy_power_weights.GetTrtWeights());
3706   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3707   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3708   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3709   return Status::OK();
3710 }
3711 
ConvertGather(OpConverterParams * params)3712 Status ConvertGather(OpConverterParams* params) {
3713   const auto& inputs = params->inputs;
3714   const auto& node_def = params->node_def;
3715   TF_RETURN_IF_ERROR(CheckInputsWeights(
3716       *params, {{"params", false}, {"indices", false}, {"axis", true}}));
3717   TF_RETURN_IF_ERROR(AllowDataTypes(
3718       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
3719       /*dtype_attr_name=*/"Tparams"));
3720   absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
3721   if (axis.size() != 1) {
3722     return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
3723                                    node_def.name());
3724   }
3725   int trt_axis = 0;
3726   TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
3727                                  node_def.name(), &trt_axis));
3728   TRT_TensorOrWeights params_tensor = inputs.at(0);
3729   TRT_TensorOrWeights indices_tensor = inputs.at(1);
3730   if (indices_tensor.batch_size() != 1) {
3731     return errors::InvalidArgument("Only indices with batch 1 are supported.");
3732   }
3733   // Both input are tensors, and the TF gather result will have rank:
3734   // (params.nbDims + 1) + (indices.nbDims + 1) - 1,
3735   // where "+ 1" adds the batch dim.
3736   const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
3737                                     indices_tensor.GetTrtDims().nbDims + 1;
3738   if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
3739     return errors::InvalidArgument(
3740         "Result of gather has dimension greater than ",
3741         nvinfer1::Dims::MAX_DIMS + 1);
3742   }
3743   if (params->validation_only) return Status::OK();
3744 
3745   // Note on how IGatherLayer works: if both the data and indices tensors have
3746   // a batch size dimension of size N, it performs:
3747   // for batchid in xrange(N):
3748   //   output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
3749   //       data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
3750   nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
3751       *const_cast<nvinfer1::ITensor*>(params_tensor.tensor()),
3752       *const_cast<nvinfer1::ITensor*>(indices_tensor.tensor()), trt_axis);
3753   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3754 
3755   nvinfer1::ITensor* gather_output = layer->getOutput(0);
3756   nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
3757   // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
3758   // and the other is for the output dimension that is squeezed by IGatherLayer
3759   // because of the implicit batch dim in the indices (see the above note).
3760   if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
3761     return errors::Internal(
3762         "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
3763         tf_gather_output_rank - 2,
3764         ", actual nbDims: ", trt_gather_output_dims.nbDims);
3765   }
3766   // Reshape the output so after adding the implicit batch dim it'll match the
3767   // output shape of TF GatherV2.
3768   for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
3769     trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
3770   }
3771   trt_gather_output_dims.d[trt_axis] = 1;
3772   ++trt_gather_output_dims.nbDims;
3773 
3774   const nvinfer1::ITensor* output_tensor = nullptr;
3775   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
3776       TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
3777       /*validation_only=*/false, &output_tensor));
3778 
3779   params->outputs->push_back(
3780       TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
3781   return Status::OK();
3782 }
3783 
ConvertMatMulHelper(OpConverterParams * params,TRT_TensorOrWeights tensor_input,TRT_ShapedWeights weights_raw,bool transpose_weight,string node_name)3784 Status ConvertMatMulHelper(OpConverterParams* params,
3785                            TRT_TensorOrWeights tensor_input,
3786                            TRT_ShapedWeights weights_raw, bool transpose_weight,
3787                            string node_name) {
3788   nvinfer1::ITensor* output_tensor;
3789   if (!tensor_input.is_tensor()) {
3790     return errors::InvalidArgument("Input 0 expects tensor");
3791   }
3792   const nvinfer1::ITensor* tensor = tensor_input.tensor();
3793 
3794   TRT_ShapedWeights weights(weights_raw.type_);
3795   if (transpose_weight) {
3796     weights = weights_raw;
3797   } else {
3798     weights = params->weight_store->GetTempWeights(weights_raw);
3799     ReorderCKtoKC(weights_raw, &weights);
3800   }
3801   TRT_ShapedWeights biases(weights.type_);
3802 
3803   int noutput = weights.shape_.d[0];
3804 
3805   auto input_dim = tensor->getDimensions();
3806   while (input_dim.nbDims != 3) {
3807     input_dim.d[input_dim.nbDims++] = 1;
3808   }
3809   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
3810       tensor_input, input_dim, /*validation_only=*/false, &tensor));
3811 
3812   nvinfer1::IFullyConnectedLayer* layer =
3813       params->converter->network()->addFullyConnected(
3814           *const_cast<nvinfer1::ITensor*>(tensor), noutput,
3815           weights.GetTrtWeights(), biases.GetTrtWeights());
3816   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name);
3817   output_tensor = layer->getOutput(0);
3818 
3819   const nvinfer1::ITensor* temp_tensor = nullptr;
3820   auto output_dim = output_tensor->getDimensions();
3821   output_dim.nbDims = 1;
3822   TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
3823       TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false,
3824       &temp_tensor));
3825   output_tensor = const_cast<nvinfer1::ITensor*>(temp_tensor);
3826   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3827   return Status::OK();
3828 }
3829 
3830 // inputs are both two dimensional (ops::MatMul)
ConvertMatMul(OpConverterParams * params)3831 Status ConvertMatMul(OpConverterParams* params) {
3832   const auto& inputs = params->inputs;
3833   const auto& node_def = params->node_def;
3834   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}}));
3835   TF_RETURN_IF_ERROR(
3836       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3837 
3838   TFAttrs attrs(node_def);
3839   bool transpose_a = attrs.get<bool>("transpose_a");
3840   bool transpose_b = attrs.get<bool>("transpose_b");
3841 
3842   // FullyConnected:
3843   if (transpose_a) {
3844     return errors::InvalidArgument(
3845         "transpose_a is not supported for TensorRT FullyConnected (op: ",
3846         node_def.op(), "), at: ", node_def.name());
3847   }
3848   if (params->validation_only) return Status::OK();
3849   return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(),
3850                              transpose_b, node_def.name());
3851 }
3852 
ConvertBatchMatMul(OpConverterParams * params)3853 Status ConvertBatchMatMul(OpConverterParams* params) {
3854   const auto& inputs = params->inputs;
3855   const auto& node_def = params->node_def;
3856   // TODO(tmorris): Enable once false is updated to mean either tensor or weight
3857   // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
3858   // false}}));
3859   TF_RETURN_IF_ERROR(
3860       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3861   if (inputs.size() != 2) {
3862     return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
3863                                    " inputs but expected 2, at ",
3864                                    node_def.name());
3865   }
3866   if (inputs[0].is_weights() && inputs[1].is_weights()) {
3867     return errors::InvalidArgument(
3868         "All inputs are weights, but Grappler is expected to fold them.");
3869   }
3870   TFAttrs attrs(node_def);
3871   const bool transpose_a = attrs.get<bool>("adj_x");
3872   const bool transpose_b = attrs.get<bool>("adj_y");
3873   const auto dims = inputs.at(0).GetTrtDims();
3874   if (dims.nbDims == 1) {  // NC * CK is only supported through fully connected
3875     if (transpose_a == false && inputs.at(0).is_tensor() &&
3876         inputs.at(1).is_weights()) {
3877       return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(),
3878                                  transpose_b, node_def.name());
3879     } else {
3880       return errors::InvalidArgument("Invalid configuration for MatMul, at: ",
3881                                      node_def.name());
3882     }
3883   }
3884 
3885   auto get_tensor_with_proper_dims = [params](
3886                                          const TRT_TensorOrWeights& input,
3887                                          const nvinfer1::ITensor** tensor) {
3888     auto dims = input.GetTrtDims();
3889     if (input.is_weights()) {
3890       // The other operand must be a tensor, this is ensured by earlier checks.
3891       // Checks that the batch dimension is not changed by broadcasting.
3892       if (dims.d[0] != 1) {
3893         return errors::InvalidArgument(
3894             "Input weight attempts to broadcast across batch dimension for "
3895             "BatchMatMul, at ",
3896             params->node_def.name());
3897       }
3898       // Remove the batch dimension from the weights.
3899       TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims));
3900     }
3901     // Create tensor and reshape if necessary.
3902     TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
3903         input, dims, params->validation_only, tensor));
3904     return Status::OK();
3905   };
3906   const nvinfer1::ITensor* tensor_l;
3907   const nvinfer1::ITensor* tensor_r;
3908   TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(0), &tensor_l));
3909   TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(1), &tensor_r));
3910   if (params->validation_only) return Status::OK();
3911 
3912   nvinfer1::IMatrixMultiplyLayer* layer =
3913       params->converter->network()->addMatrixMultiply(
3914           *const_cast<nvinfer1::ITensor*>(tensor_l), transpose_a,
3915           *const_cast<nvinfer1::ITensor*>(tensor_r), transpose_b);
3916   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3917   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3918   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3919   return Status::OK();
3920 }
3921 
ConvertSoftmax(OpConverterParams * params)3922 Status ConvertSoftmax(OpConverterParams* params) {
3923   const auto& inputs = params->inputs;
3924   const auto& node_def = params->node_def;
3925   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
3926   TF_RETURN_IF_ERROR(
3927       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3928   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3929 
3930   int nbDims = tensor->getDimensions().nbDims;
3931   if (nbDims == 0) {
3932     return errors::InvalidArgument(
3933         "TensorRT Softmax cannot apply on batch dimension, at",
3934         node_def.name());
3935   }
3936   if (params->validation_only) return Status::OK();
3937 
3938   nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax(
3939       *const_cast<nvinfer1::ITensor*>(tensor));
3940   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3941   // Tensorflow SoftMax assumes applying softmax on the last dimension.
3942   layer->setAxes(1 << (nbDims - 1));
3943 
3944   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3945   // Quantization range for SoftMax is always (0, 1)
3946   params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f);
3947   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3948   return Status::OK();
3949 }
3950 
ConvertTopK(OpConverterParams * params)3951 Status ConvertTopK(OpConverterParams* params) {
3952   const auto& inputs = params->inputs;
3953   const auto& node_def = params->node_def;
3954   TF_RETURN_IF_ERROR(
3955       CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
3956   TF_RETURN_IF_ERROR(AllowDataTypes(
3957       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
3958   const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3959   const int num_dims = tensor->getDimensions().nbDims;
3960   if (num_dims == 0) {
3961     return errors::InvalidArgument(
3962         "TensorRT TopK cannot apply on batch dimension, at", node_def.name());
3963   }
3964 
3965   TRT_ShapedWeights k_w = inputs.at(1).weights();
3966   if (k_w.count() != 1) {
3967     return errors::InvalidArgument("k value of TopK should be a scalar, at",
3968                                    node_def.name());
3969   }
3970   // Note that ITopKLayer always have sorted outputs, so we don't need to handle
3971   // the 'sorted' attribute of the node.
3972   if (params->validation_only) return Status::OK();
3973 
3974   const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX;
3975   const int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues())));
3976   const uint32_t reduce_axes = 1 << (num_dims - 1);
3977   nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(
3978       *const_cast<nvinfer1::ITensor*>(tensor), op, k, reduce_axes);
3979   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3980 
3981   nvinfer1::ITensor* output_value_tensor = layer->getOutput(0);
3982   nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
3983   params->outputs->push_back(TRT_TensorOrWeights(output_value_tensor));
3984   params->outputs->push_back(TRT_TensorOrWeights(output_indices_tensor));
3985   return Status::OK();
3986 }
3987 
RegisterValidatableOpConverters(std::unordered_map<string,OpConverter> * registration)3988 static void RegisterValidatableOpConverters(
3989     std::unordered_map<string, OpConverter>* registration) {
3990   (*registration)["BiasAdd"] = ConvertBiasAdd;
3991   (*registration)["ConcatV2"] = ConvertConcat;
3992   (*registration)["Const"] = ConvertConst;
3993   (*registration)["Conv2D"] = ConvertConv2D;
3994   (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
3995   (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
3996   (*registration)["ExpandDims"] = ConvertExpandDims;
3997   (*registration)["GatherV2"] = ConvertGather;
3998   (*registration)["LeakyRelu"] = ConvertLeakyRelu;
3999   (*registration)["MatMul"] = ConvertMatMul;
4000   (*registration)["Pad"] = ConvertPad;
4001   (*registration)["Relu6"] = ConvertRelu6;
4002   (*registration)["Reshape"] = ConvertReshape;
4003   (*registration)["Rsqrt"] = ConvertRsqrt;
4004   (*registration)["Slice"] = ConvertSlice;
4005   (*registration)["Square"] = ConvertSquare;
4006   (*registration)["Squeeze"] = ConvertSqueeze;
4007   (*registration)["StridedSlice"] = ConvertStridedSlice;
4008   (*registration)["Transpose"] = ConvertTranspose;
4009   (*registration)["TopKV2"] = ConvertTopK;
4010 
4011   // TODO(ben,jie): this is a temp hack.
4012   (*registration)["Identity"] = ConvertIdentity;  // Identity should be removed
4013   (*registration)["Snapshot"] = ConvertIdentity;  // Snapshot should be removed
4014 
4015   (*registration)["Sum"] = ConvertReduce;
4016   (*registration)["Prod"] = ConvertReduce;
4017   (*registration)["Max"] = ConvertReduce;
4018   (*registration)["Min"] = ConvertReduce;
4019   (*registration)["Mean"] = ConvertReduce;
4020   (*registration)["Softmax"] = ConvertSoftmax;
4021   (*registration)["BatchMatMul"] = ConvertBatchMatMul;
4022 
4023   for (auto quantization_op_type :
4024        {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
4025         "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) {
4026     (*registration)[quantization_op_type] = ConvertQuantize;
4027   }
4028   for (auto binary_op_type :
4029        {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum", "Pow"}) {
4030     (*registration)[binary_op_type] = ConvertBinary;
4031   }
4032   for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) {
4033     (*registration)[activation_op_type] = ConvertActivation;
4034   }
4035   for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
4036     (*registration)[pool_op_type] = ConvertPool;
4037   }
4038   for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) {
4039     (*registration)[normalization_op_type] = ConvertFusedBatchNorm;
4040   }
4041   for (auto unary_op_pair : *UnaryOperationMap()) {
4042     (*registration)[unary_op_pair.first] = ConvertUnary;
4043   }
4044 }
4045 
RegisterOpValidators()4046 void TrtNodeValidator::RegisterOpValidators() {
4047   RegisterValidatableOpConverters(&op_validators_);
4048 }
4049 
RegisterOpConverters()4050 void Converter::RegisterOpConverters() {
4051   RegisterValidatableOpConverters(&op_registry_);
4052   plugin_converter_ = ConvertPlugin;
4053 }
4054 
ConvertGraphDefToEngine(const GraphDef & gdef,TrtPrecisionMode precision_mode,int max_batch_size,size_t max_workspace_size_bytes,const std::vector<PartialTensorShape> & input_shapes,Logger * logger,nvinfer1::IGpuAllocator * allocator,TRTInt8Calibrator * calibrator,TrtUniquePtrType<nvinfer1::ICudaEngine> * engine,bool use_calibration,bool * convert_successfully)4055 Status ConvertGraphDefToEngine(
4056     const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
4057     size_t max_workspace_size_bytes,
4058     const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
4059     nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
4060     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
4061     bool* convert_successfully) {
4062   engine->reset();
4063   if (convert_successfully) *convert_successfully = false;
4064 
4065   // Create the builder.
4066   TrtUniquePtrType<nvinfer1::IBuilder> builder(
4067       nvinfer1::createInferBuilder(*logger));
4068   builder->setMaxBatchSize(max_batch_size);
4069   builder->setMaxWorkspaceSize(max_workspace_size_bytes);
4070   builder->setGpuAllocator(allocator);
4071   if (precision_mode == TrtPrecisionMode::FP16) {
4072     builder->setFp16Mode(true);
4073   } else if (precision_mode == TrtPrecisionMode::INT8) {
4074     // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
4075     // use them in situations where they are faster than INT8 or where INT8 is
4076     // not supported for a given layer.
4077     builder->setFp16Mode(true);
4078     builder->setInt8Mode(true);
4079     if (use_calibration) {
4080       builder->setInt8Calibrator(calibrator);
4081     } else {
4082       builder->setInt8Calibrator(nullptr);
4083     }
4084   }
4085 
4086   // Create the network.
4087   auto trt_network =
4088       TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
4089   if (!trt_network) {
4090     return errors::Internal("Failed to create TensorRT network object");
4091   }
4092 
4093   // Build the network
4094   VLOG(1) << "Starting engine conversion ";
4095   Converter converter(trt_network.get(), precision_mode, use_calibration);
4096   std::vector<Converter::EngineOutputInfo> output_tensors;
4097   // Graph nodes are already topologically sorted during construction
4098   for (const auto& node_def : gdef.node()) {
4099     string node_name = node_def.name();
4100     VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
4101     if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
4102       int32 slot_number = -1;
4103       if (!strings::safe_strto32(  // non-absl ok
4104               node_name.c_str() + strlen(kInputPHName), &slot_number)) {
4105         return errors::InvalidArgument("Failed to parse slot number from ",
4106                                        node_name);
4107       }
4108       nvinfer1::DataType trt_dtype;
4109       nvinfer1::Dims trt_dims;
4110       int batch_size = -1;
4111       auto shape = input_shapes.at(slot_number);
4112       auto status = ValidateTensorProperties(
4113           node_def.op(), node_def.attr().at("dtype").type(), shape,
4114           /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
4115       if (!status.ok()) {
4116         const string error_message =
4117             StrCat("Validation failed for ", node_name, " and input slot ",
4118                    slot_number, ": ", status.error_message());
4119         LOG(WARNING) << error_message;
4120         return Status(status.code(), error_message);
4121       }
4122       VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
4123               << DebugString(trt_dims);
4124       // TODO(laigd): the conversion should always happen at runtime where all
4125       // the shapes are known, and we can provide a mode to generate the
4126       // engines offline, by calling sess.run() and cache/serialize the engines.
4127       TF_RETURN_IF_ERROR(
4128           converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
4129     } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
4130       int32 slot_number = -1;
4131       if (!strings::safe_strto32(  // non-absl ok
4132               node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
4133         return errors::InvalidArgument("Failed to parse slot number from ",
4134                                        node_name);
4135       }
4136       // Get output type that TensorFlow expects
4137       TFAttrs attrs(node_def);
4138       DataType tf_dtype = attrs.get<DataType>("T");
4139       nvinfer1::DataType trt_dtype;
4140       TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
4141       if (output_tensors.size() <= slot_number) {
4142         output_tensors.resize(slot_number + 1);
4143       }
4144       output_tensors.at(slot_number) = {node_def.input(0), node_name,
4145                                         trt_dtype};
4146     } else {
4147       VLOG(2) << "Converting node: " << node_def.name() << " , "
4148               << node_def.op();
4149       TF_RETURN_IF_ERROR(converter.ConvertNode(node_def));
4150     }
4151   }
4152   TF_RETURN_IF_ERROR(converter.RenameAndMarkOutputTensors(output_tensors));
4153   if (convert_successfully) *convert_successfully = true;
4154 
4155   // Apply user provided quantization ranges to tensors
4156   converter.MaybeApplyQuantizationRanges();
4157 
4158   // Build the engine.
4159   VLOG(1) << "Starting engine creation";
4160   engine->reset(builder->buildCudaEngine(*converter.network()));
4161   if (engine->get() == nullptr) {
4162     return errors::Internal("Failed to build TensorRT engine");
4163   }
4164   VLOG(1) << "Finished conversion";
4165   return Status::OK();
4166 }
4167 
ConvertSegmentToGraphDef(const Graph * graph,const grappler::GraphProperties & graph_properties,const std::vector<const Node * > & subgraph_nodes,std::vector<EngineConnection> * connections,GraphDef * segment_def,string * scope_name)4168 Status ConvertSegmentToGraphDef(
4169     const Graph* graph, const grappler::GraphProperties& graph_properties,
4170     const std::vector<const Node*>& subgraph_nodes,  // In topological order
4171     std::vector<EngineConnection>* connections, GraphDef* segment_def,
4172     string* scope_name) {
4173   std::set<string> marker_nodes;
4174   // Update connection shapes/data types and add corresponding input/output
4175   // nodes in the segment graphdef.
4176   for (size_t i = 0; i < connections->size(); ++i) {
4177     auto& connection = connections->at(i);
4178     if (connection.is_control_edge()) continue;
4179     auto outside_node = graph->FindNodeId(connection.outside_id);
4180     if (!outside_node) {
4181       // This should never happen, unless the original graph is problematic.
4182       return errors::NotFound("Cannot find node with id ",
4183                               connection.outside_id, " in the graph.");
4184     }
4185     // Updates the shape and data types of input/output connections.
4186     DataType dtype;
4187     PartialTensorShape partial_shape;
4188     if (connection.is_input_edge) {
4189       GetOutputProperties(graph_properties,
4190                           graph->FindNodeId(connection.outside_id),
4191                           connection.outside_port, &partial_shape, &dtype);
4192       connection.outside_shape = partial_shape;
4193     } else {
4194       GetInputProperties(graph_properties,
4195                          graph->FindNodeId(connection.outside_id),
4196                          connection.outside_port, &partial_shape, &dtype);
4197       connection.inside_shape = partial_shape;
4198     }
4199     connection.connection_type = dtype;
4200 
4201     // Add dummy input/output nodes to the segment graphdef.
4202     if (connection.is_input_edge) {
4203       const string node_name = StrCat(kInputPHName, connection.port_number);
4204       if (marker_nodes.count(node_name)) {
4205         VLOG(1) << "Reusing input " << node_name << " for the edge "
4206                 << connection.outside_node_name << ":"
4207                 << connection.outside_port << " -> "
4208                 << connection.inside_node_name << ":" << connection.inside_port;
4209         continue;
4210       }
4211       marker_nodes.insert(node_name);
4212       auto seg_node = segment_def->add_node();
4213       NodeDefBuilder builder(node_name, "Placeholder");
4214       auto status = builder.Attr("shape", partial_shape)
4215                         .Attr("dtype", dtype)
4216                         .Finalize(seg_node);
4217       VLOG(1) << "Constructing input " << node_name << " for the edge "
4218               << connection.outside_node_name << ":" << connection.outside_port
4219               << " -> " << connection.inside_node_name << ":"
4220               << connection.inside_port;
4221     } else {
4222       const string node_name = StrCat(kOutputPHName, connection.port_number);
4223       if (marker_nodes.count(node_name)) {
4224         VLOG(1) << "Reusing output " << node_name << " for the edge "
4225                 << connection.inside_node_name << ":" << connection.inside_port
4226                 << " -> " << connection.outside_node_name << ":"
4227                 << connection.outside_port;
4228         continue;
4229       }
4230       marker_nodes.insert(node_name);
4231       auto seg_node = segment_def->add_node();
4232       NodeDefBuilder builder(node_name, "Identity");
4233       auto status =
4234           builder
4235               .Input(connection.inside_node_name, connection.inside_port, dtype)
4236               .Finalize(seg_node);
4237       VLOG(1) << "Constructing output " << node_name << " for the edge "
4238               << connection.inside_node_name << ":" << connection.inside_port
4239               << " -> " << connection.outside_node_name << ":"
4240               << connection.outside_port;
4241     }
4242   }  // for each connection.
4243 
4244   std::unordered_map<int, int> old_to_new_id_map;
4245   // Copy internal nodes to new graphdef
4246   string local_scope = subgraph_nodes.front()->name();
4247   for (const Node* node : subgraph_nodes) {
4248     local_scope = GetCommonNameScope(local_scope, node->name());
4249     old_to_new_id_map[node->id()] = segment_def->node_size();
4250     auto snode = segment_def->add_node();
4251     *snode = node->def();
4252     VLOG(2) << "Copying " << snode->name() << " to subgraph";
4253   }
4254   // Update the inputs of the new input nodes to point to placeholder nodes.
4255   for (int i = 0; i < connections->size(); ++i) {
4256     auto& connection = connections->at(i);
4257     if (connection.is_control_edge() || !connection.is_input_edge) continue;
4258     auto snode =
4259         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
4260     const string placeholder_name =
4261         StrCat(kInputPHName, connection.port_number);
4262     VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
4263             << " from " << snode->input(connection.inside_port) << " to "
4264             << placeholder_name;
4265     snode->set_input(connection.inside_port, placeholder_name);
4266   }
4267   std::set<string> subgraph_node_names;
4268   for (const Node* node : subgraph_nodes) {
4269     subgraph_node_names.insert(node->name());
4270   }
4271 
4272   // Remove control inputs that are not inside the segment.
4273   for (int i = 0; i < segment_def->node_size(); ++i) {
4274     auto snode = segment_def->mutable_node(i);
4275     const int input_size = snode->input_size();
4276     int input_idx = 0;
4277     int actual_input_idx = 0;
4278     while (input_idx < input_size) {
4279       TensorId input = ParseTensorName(snode->input(input_idx));
4280       if (!subgraph_node_names.count(
4281               string(input.first.data(), input.first.size())) &&
4282           !IsEngineInput(input.first)) {
4283         if (input.second == Graph::kControlSlot) {
4284           VLOG(1) << "... removing control inputs " << input.first
4285                   << " from subgraph.";
4286           ++input_idx;
4287           continue;
4288         } else {
4289           return errors::InvalidArgument(
4290               "Found non control input outside the segment that is not an "
4291               "engine connection to ",
4292               snode->name(), ": ", input.first);
4293         }
4294       }
4295       if (actual_input_idx != input_idx) {
4296         snode->set_input(actual_input_idx, snode->input(input_idx));
4297       }
4298       ++input_idx;
4299       ++actual_input_idx;
4300     }
4301     for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
4302       snode->mutable_input()->RemoveLast();
4303     }
4304   }
4305   *scope_name = local_scope;
4306   return Status::OK();
4307 }
4308 
operator ()(const Edge * out_edge) const4309 bool OutputEdgeValidator::operator()(const Edge* out_edge) const {
4310   if (out_edge->IsControlEdge()) return true;
4311   if (out_edge->src()->type_string() == "Const") {
4312     VLOG(1) << "--> Need to remove output node " << out_edge->src()->name()
4313             << " which is a Const.";
4314     return false;
4315   }
4316   return true;
4317 }
4318 
4319 }  // namespace convert
4320 }  // namespace tensorrt
4321 }  // namespace tensorflow
4322 
4323 #endif  // GOOGLE_TENSORRT
4324 #endif  // GOOGLE_CUDA
4325