/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <cmath>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"

// Check if the types are equal. Cast to int first so that failure log
// messages print something meaningful.
#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace {

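// Converts a TensorFlow DataType to the corresponding TensorRT type. Only
// FP32, FP16 and INT8 have TensorRT equivalents; anything else is rejected
// with InvalidArgument.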
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
                                       nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case tensorflow::DataType::DT_FLOAT:
      *trt_dtype = nvinfer1::DataType::kFLOAT;
      break;
    case tensorflow::DataType::DT_INT8:
      *trt_dtype = nvinfer1::DataType::kINT8;
      break;
    case tensorflow::DataType::DT_HALF:
      *trt_dtype = nvinfer1::DataType::kHALF;
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported data type");
  }
  return tensorflow::Status::OK();
}

inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
  nvinfer1::Dims dims;
  dims.nbDims = tensor.dims();
  for (int i = 0; i < dims.nbDims; i++) {
    dims.d[i] = tensor.dim_size(i);
  }
  return dims;
}

// Returns the total number of elements in shape.
inline int64_t GetShapeSize(nvinfer1::Dims shape) {
  int64_t count = 1;
  for (int d = 0; d < shape.nbDims; ++d) {
    count *= shape.d[d];
  }
  return count;
}

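// Computes TensorFlow-style SAME padding per spatial dimension. The total
// padding for dimension i is
//   p = ((input - 1) / stride) * stride + kernel - input   (clamped at 0),
// with the extra pixel on the right, as in TensorFlow. E.g. input = 7,
// stride = 2, kernel = 3 gives p = 2, i.e. {left = 1, right = 1}.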
static std::vector<std::pair<int, int>> CreateSamePadding(
    const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
    const std::vector<int64_t>& input_dims) {
  std::vector<std::pair<int, int>> padding(input_dims.size());
  CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?

  for (size_t i = 0; i < input_dims.size(); ++i) {
    // Formula to calculate the padding
    int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
            input_dims[i];
    p = (p > 0) ? p : 0;

    // Right precedence padding, like in TensorFlow
    int left = p / 2;
    int right = p - left;

    VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
            << ", params: " << input_dims[i] << ", " << stride.d[i]
            << ", kernel: " << kernel.d[i];
    padding[i] = {left, right};
  }
  return padding;
}

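// A non-owning view of constant weight data together with its TF dtype and
// TRT shape. The backing buffer must outlive this object; temporary buffers
// are kept alive by Converter::get_temp_weights below.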
class TRT_ShapedWeights {
 public:
  TRT_ShapedWeights(tensorflow::DataType type, const void* values,
                    nvinfer1::Dims shape)
      : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
    // Note: this->shape.type[] is not used
  }

  explicit TRT_ShapedWeights(tensorflow::DataType type)
      : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}

  TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
      : shape_(rhs.shape_),
        type_(rhs.type_),
        values_(rhs.values_),
        empty_weight_flag_(rhs.empty_weight_flag_) {}

  int64_t count() const {
    int64_t c = 1;
    for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    return c;
  }

  nvinfer1::Weights GetWeightsForTRT() const {
    nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(type_, &trt_type));
    if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};

    // Note: this->shape.type[] is not used
    return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
  }

  const void* GetValues() const { return values_; }

  void SetValues(const void* values) { values_ = values; }

  size_t size_bytes() const {
    int type_size = tensorflow::DataTypeSize(this->type_);
    return this->count() * type_size;
  }

  // Default converter
  operator nvinfer1::Weights() const { return GetWeightsForTRT(); }

  nvinfer1::Dims shape_;
  tensorflow::DataType type_;

 private:
  const void* values_;
  bool empty_weight_flag_;
};

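// Tagged union for the two kinds of values flowing through conversion:
// a TensorRT ITensor produced by a layer, or constant weights taken from
// the TF graph.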
class TRT_TensorOrWeights {
 public:
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
      : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
      : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
      : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
  ~TRT_TensorOrWeights() {}

  bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
  bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }

  nvinfer1::ITensor* tensor() {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  const nvinfer1::ITensor* tensor() const {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  TRT_ShapedWeights& weights() {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  const TRT_ShapedWeights& weights() const {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  nvinfer1::Dims shape() const {
    if (is_tensor()) {
      return tensor()->getDimensions();
    } else {
      return weights().shape_;
    }
  }

 private:
  nvinfer1::ITensor* tensor_;
  TRT_ShapedWeights weights_;
  enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
};

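// Thin wrapper around the attributes of a NodeDef with typed accessors,
// e.g. attrs.get<std::vector<int>>("strides").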
class TFAttrs {
 public:
  explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    for (const auto& attr : tf_node.attr()) {
      attrs_.insert({attr.first, &attr.second});
    }
  }
  bool count(string key) const { return attrs_.count(key); }
  tensorflow::AttrValue const* at(string key) const {
    if (!attrs_.count(key)) {
      LOG(FATAL) << "Attribute not found: " << key;
    }
    return attrs_.at(key);
  }
  template <typename T>
  T get(string key) const;
  template <typename T>
  T get(string key, const T& default_value) const {
    return attrs_.count(key) ? this->get<T>(key) : default_value;
  }

 private:
  typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
  AttrMap attrs_;
};

template <>
string TFAttrs::get<string>(string key) const {
  return this->at(key)->s();
}

template <>
std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
  auto attr = this->at(key)->list().i();
  return std::vector<int>(attr.begin(), attr.end());
}

template <>
nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
  auto values = this->get<std::vector<int>>(key);
  nvinfer1::Dims dims;
  dims.nbDims = values.size();
  std::copy(values.begin(), values.end(), dims.d);
  // Note: No dimension type information is included
  return dims;
}

template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
  return trt_dtype;
}

template <>
tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
  return this->at(key)->type();
}

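// Generic strided copy over a 4D array; istrides/ostrides encode the source
// and destination layouts, so a single routine handles any permutation.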
template <typename T>
void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
              nvinfer1::DimsNCHW istrides, T* odata,
              nvinfer1::DimsNCHW ostrides) {
  for (int n = 0; n < shape.n(); ++n) {
    for (int c = 0; c < shape.c(); ++c) {
      for (int h = 0; h < shape.h(); ++h) {
        for (int w = 0; w < shape.w(); ++w) {
          odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
                w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
                                          h * istrides.h() + w * istrides.w()];
        }
      }
    }
  }
}

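// Reorders TF Conv2D filter weights from RSCK (filter height, filter width,
// input channels, output channels) to the KCRS layout TensorRT expects.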
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
                       TRT_ShapedWeights* oweights) {
  CHECK_EQ(iweights.type_, oweights->type_);
  CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
  int r = iweights.shape_.d[0];
  int s = iweights.shape_.d[1];
  int c = iweights.shape_.d[2];
  int k = iweights.shape_.d[3];
  oweights->shape_.d[0] = k;
  oweights->shape_.d[1] = c;
  oweights->shape_.d[2] = r;
  oweights->shape_.d[3] = s;
  nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
  nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT:
      Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
               istrides,
               static_cast<float*>(const_cast<void*>(oweights->GetValues())),
               ostrides);
      break;
    default:
      LOG(FATAL) << "Unsupported type for weight reorder: "
                 << tensorflow::DataTypeString(iweights.type_);
  }
}

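// TensorRT objects are released through their destroy() method rather than
// operator delete, hence this custom deleter for smart pointers.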
struct InferDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

template <typename T>
inline std::shared_ptr<T> infer_object(T* obj) {
  return std::shared_ptr<T>(obj, InferDeleter());
}

// Forward declaration; OpConverter callbacks receive the Converter that is
// driving the conversion.
class Converter;

using OpConverter =
    std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
                                     std::vector<TRT_TensorOrWeights> const&,
                                     std::vector<TRT_TensorOrWeights>*)>;

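// Drives the conversion of one subgraph: keeps the symbol table of tensors
// and weights produced so far, owns temporary weight buffers, and dispatches
// each NodeDef to the OpConverter registered for its op.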
class Converter {
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
  std::unordered_map<string, OpConverter> op_registry_;
  nvinfer1::INetworkDefinition* trt_network_;
  std::list<std::vector<uint8_t>> temp_bufs_;

  void register_op_converters();

  std::vector<TRT_TensorOrWeights> get_inputs(
      const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs;
    for (const auto& input_name : node_def.input()) {
      VLOG(2) << "Retrieve input: " << input_name;
      inputs.push_back(trt_tensors_.at(input_name));
    }
    return inputs;
  }

 public:
  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
      : trt_network_(trt_network) {
    this->register_op_converters();
  }

  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                     nvinfer1::Dims shape) {
    TRT_ShapedWeights weights(type, nullptr, shape);
    // TODO(jie): check weights size_bytes. 0 means type error
    temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    weights.SetValues(temp_bufs_.back().data());
    return weights;
  }

  TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    return this->get_temp_weights(weights.type_, weights.shape_);
  }

  tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    string op = node_def.op();
    if (!op_registry_.count(op)) {
      return tensorflow::errors::Unimplemented(
          "No converter registered for op: " + op);
    }
    OpConverter op_converter = op_registry_.at(op);
    std::vector<TRT_TensorOrWeights> outputs;
    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    for (size_t i = 0; i < outputs.size(); ++i) {
      TRT_TensorOrWeights output = outputs.at(i);
      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
      string output_name = node_def.name();
      if (i != 0) output_name = output_name + ":" + std::to_string(i);
      if (output.is_tensor()) {
        output.tensor()->setName(output_name.c_str());
      }
      VLOG(2) << "Write out tensor: " << output_name;
      if (!trt_tensors_.insert({output_name, output}).second) {
        return tensorflow::errors::AlreadyExists(
            "Output tensor already exists for op: " + op);
      }
    }
    return tensorflow::Status::OK();
  }

  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  TRT_TensorOrWeights get_tensor(string name) {
    if (!trt_tensors_.count(name)) {
      return TRT_TensorOrWeights(nullptr);
    }
    return trt_tensors_.at(name);
  }

  bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
  }

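  // Inserts an IShuffleLayer that permutes input_tensor. 'order' is given in
  // TF convention and includes the batch dimension (e.g. {0, 3, 1, 2} for
  // NHWC -> NCHW); the batch dimension is implicit in TRT and must stay put.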
  nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
                                     std::vector<int> order) {
    auto dims = input_tensor->getDimensions();

    // TODO(jie): change the return to status and properly exit
    if (order.size() - 1 != size_t(dims.nbDims))
      LOG(ERROR) << "Dimension does not match, fail gracefully";

    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    nvinfer1::Permutation permutation;
    for (int32_t i = 0; i < dims.nbDims; ++i) {
      permutation.order[i] = order[i + 1] - 1;
    }
    layer->setFirstTranspose(permutation);

    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = dims.nbDims;
    for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
      reshape_dims.d[i] = 0;
      reshape_dims.type[i] = dims.type[i];
    }
    layer->setReshapeDimensions(reshape_dims);
    return layer->getOutput(0);
  }
};

// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
// there.
// ****************************************************************************
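// Produces elementwise lambdas (unary, binary, and scalar-broadcast variants)
// that the constant-folding helpers below apply to weight buffers.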
struct LambdaFactory {
  enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
  OP_CATEGORY op;

  template <typename T>
  std::function<T(T)> unary() {
    switch (op) {
      case OP_CATEGORY::RSQRT: {
        VLOG(2) << "RSQRT GETS DONE";
        return [](T t) -> T { return 1.0 / std::sqrt(t); };
      }
      case OP_CATEGORY::NEG:
        return [](T t) -> T { return -t; };
      default:
        VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
        return nullptr;
    }
  }

  template <typename T>
  std::function<T(T, T)> binary() {
    switch (op) {
      case OP_CATEGORY::ADD:
        return [](T l, T r) -> T { return l + r; };
      case OP_CATEGORY::SUB:
        return [](T l, T r) -> T { return l - r; };
      case OP_CATEGORY::MUL:
        return [](T l, T r) -> T { return l * r; };
      default:
        LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [](T l, T r) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_r(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l + val;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l - val;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l * val;
        };
      default:
        LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_l(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val + l;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val - l;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val * l;
        };
      default:
        LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }
};

tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
                                TRT_ShapedWeights* oweights,
                                LambdaFactory unary_op) {
  CHECK_EQ(iweights.type_, oweights->type_);
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp = static_cast<float const*>(iweights.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
      std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights.type_));
  }
  return tensorflow::Status::OK();
}

tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
                                 const TRT_ShapedWeights& iweights_r,
                                 TRT_ShapedWeights* oweights,
                                 LambdaFactory binary_op) {
  // Assume iweights_l.type == iweights_r.type
  CHECK_EQ(iweights_l.type_, oweights->type_);
  CHECK_EQ(iweights_r.type_, oweights->type_);
  VLOG(2) << "SANITY CHECK!";

  switch (iweights_l.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp_l = static_cast<const float*>(iweights_l.GetValues());
      auto inp_r = static_cast<const float*>(iweights_r.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));

      if (iweights_l.count() != iweights_r.count()) {
        // We only support broadcasting of rank-zero (scalar) operands.
        if (iweights_l.count() == 1) {
          VLOG(2) << "Broadcasting left scalar: " << (*inp_l);
          std::transform(inp_r, inp_r + iweights_r.count(), oup,
                         binary_op.broadcast_l<float>(*inp_l));
        } else if (iweights_r.count() == 1) {
          VLOG(2) << "Broadcasting right scalar: " << (*inp_r);
          std::transform(inp_l, inp_l + iweights_l.count(), oup,
                         binary_op.broadcast_r<float>(*inp_r));
        } else {
          return tensorflow::errors::Unimplemented(
              "Binary op with non-rank-zero broadcast not supported");
        }
      } else {
        std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
                       binary_op.binary<float>());
      }
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights_l.type_));
  }

  return tensorflow::Status::OK();
}

tensorflow::Status ConstantFoldUnary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input = inputs.at(0).weights();

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);

  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later
  // Check type consistency
  CHECK_EQ(weights_input.type_,
           TFAttrs(node_def).get<tensorflow::DataType>("T"));

  // Maybe I should do a switch
  LambdaFactory unary_op;
  if (node_def.op() == "Rsqrt") {
    // Compute rsqrt
    unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    // Pass the output
    if (ret == tensorflow::Status::OK()) {
      outputs->push_back(TRT_TensorOrWeights(weights_output));
    }
    return ret;
  } else {
    return tensorflow::errors::Unimplemented("Unary op not supported: " +
                                             node_def.op());
  }
}

// TODO(jie,ben) broadcast is needed yet not implemented
// Let's get the simple stuff working first. Maybe we should fall back to the
// TF approach for constant folding.
tensorflow::Status ConstantFoldBinary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
  TRT_ShapedWeights weights_input_r = inputs.at(1).weights();

  // Check type consistency
  CHECK_EQ(weights_input_l.type_, weights_input_r.type_);

  if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
    return tensorflow::errors::Unimplemented(
        "Binary op implicit broadcast not supported: " + node_def.op());

  // TODO(jie): constant fold should really fall back to TF.
  int nb_dims = weights_input_l.shape_.nbDims;
  nvinfer1::Dims output_shape;
  output_shape.nbDims = nb_dims;
  VLOG(2) << "nb_dims: " << nb_dims
          << ", the other: " << weights_input_r.shape_.nbDims;
  for (int i = 0; i < nb_dims; i++) {
    if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
      output_shape.d[i] = weights_input_l.shape_.d[i];
    } else if (weights_input_l.shape_.d[i] == 1 ||
               weights_input_r.shape_.d[i] == 1) {
      output_shape.d[i] =
          std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
    } else {
      return tensorflow::errors::Unimplemented(
          "Binary op with incompatible shape, at " + node_def.op());
    }
    VLOG(2) << "left: " << weights_input_l.shape_.d[i]
            << ", right: " << weights_input_r.shape_.d[i]
            << ", output: " << output_shape.d[i];
  }

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);

  // Maybe I should do a switch
  LambdaFactory binary_op;
  if (node_def.op() == "Sub") {
    binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
  } else if (node_def.op() == "Mul") {
    binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
  } else if (node_def.op() == "Add") {
    binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }
  auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
                           binary_op);

  // Pass the output
  if (ret == tensorflow::Status::OK()) {
    outputs->push_back(TRT_TensorOrWeights(weights_output));
  }

  return ret;
}

// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel-wise for the time being.
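// The weight operand is lowered onto a TensorRT IScaleLayer, whose mode is
// kUNIFORM (one scalar), kCHANNEL (one value per channel) or kELEMENTWISE
// (one value per element); the shape checks below select the mode.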
tensorflow::Status BinaryTensorOpWeight(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
    std::vector<TRT_TensorOrWeights>* outputs) {
  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later

  // Check type consistency
  auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
  CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
  nvinfer1::DataType ttype;
  TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
  CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error message

  // Check scale mode
  auto dims_w = weights.shape_;
  auto dims_t = tensor->getDimensions();

  // Default to element-wise
  auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;

  if (weights.count() == 1) {
    VLOG(2) << "UNIFORM";
    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
  } else {
    // No broadcasting on batch dimension;
    assert(dims_w.d[0] == 1);

    // Broadcasting on channel dimension only allowed in kUNIFORM
    assert(dims_w.d[1] == dims_t.d[0]);
    assert(dims_w.nbDims == dims_t.nbDims);

    // Default is element-wise; fall back to channel-wise when a trailing
    // dimension does not match the tensor.
    for (int i = 2; i < dims_w.nbDims; i++) {
      if (dims_w.d[i] != dims_t.d[i - 1]) {
        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
        break;
      }
    }
    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
      for (int i = 2; i < dims_w.nbDims; i++) {
        if (dims_w.d[i] != 1)
          return tensorflow::errors::InvalidArgument(
              "Weight shape not compatible, at " + node_def.name());
      }
    }
  }

  // Prepare weights
  TRT_ShapedWeights shift_weights(weights.type_);
  TRT_ShapedWeights scale_weights(weights.type_);
  TRT_ShapedWeights power_weights(weights.type_);

  // Maybe I should do a switch
  if (node_def.op() == "Sub") {
    TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
    LambdaFactory unary_op;
    unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
    TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    shift_weights = neg_weights;
  } else if (node_def.op() == "Mul") {
    scale_weights = weights;
  } else if (node_def.op() == "Add") {
    shift_weights = weights;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }

  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
      scale_weights, power_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status BinaryTensorOpTensor(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    std::vector<TRT_TensorOrWeights>* outputs) {
  static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
      {"Add", nvinfer1::ElementWiseOperation::kSUM},
      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
      // {"max", nvinfer1::ElementWiseOperation::kMAX},
      // {"min", nvinfer1::ElementWiseOperation::kMIN},
      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
      {"Div", nvinfer1::ElementWiseOperation::kDIV},
  };

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");

  // Check type consistency
  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
  auto op_pair = ops.find(node_def.op());
  if (op_pair == ops.end())
    return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
                                             " not supported at: " +
                                             node_def.name());

  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
      *const_cast<nvinfer1::ITensor*>(tensor_l),
      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPlaceholder(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  VLOG(2) << "Placeholder should have been replaced already";
  return tensorflow::errors::Unimplemented("Cannot convert Placeholder op");
  // This is dead code: placeholders are supposed to have been replaced with
  // network inputs before conversion reaches this point.
  TFAttrs attrs(node_def);
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");

  dims.nbDims--;
  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];

  nvinfer1::ITensor* output =
      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
  if (!output) {
    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
  }
  outputs->push_back(TRT_TensorOrWeights(output));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertConv2D(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 const std::vector<TRT_TensorOrWeights>& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
  ReorderRSCKToKCRS(weights_rsck, &weights);
  TRT_ShapedWeights biases(weights.type_);
  int noutput = weights.shape_.d[0];
  nvinfer1::DimsHW kernel_size;
  kernel_size.h() = weights.shape_.d[2];
  kernel_size.w() = weights.shape_.d[3];
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    h_index = 1;
    w_index = 2;
    // TODO(jie): transpose it
  }

  // TODO(jie): stride. (NHWC/NCHW)
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is NCHW tensor with no batch dimension.
    // 1 -> h
    // 2 -> w
    padding = CreateSamePadding(
        stride, kernel_size,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else {
    padding = {{0, 0}, {0, 0}};
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
            << ", " << padding[1].first << ", " << padding[1].second;

    auto dim_before = tensor->getDimensions();
    VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
            << ", " << dim_before.d[2] << ", " << dim_before.d[3];
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
    auto dim_after = tensor->getDimensions();
    VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
            << ", " << dim_after.d[2] << ", " << dim_after.d[3];
  }

  nvinfer1::IConvolutionLayer* layer =
      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
                                    noutput, kernel_size, weights, biases);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  auto dim_after = output_tensor->getDimensions();
  VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
          << dim_after.d[2] << ", " << dim_after.d[3];

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPool(Converter& ctx,
                               const tensorflow::NodeDef& node_def,
                               std::vector<TRT_TensorOrWeights> const& inputs,
                               std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    h_index = 1;
    w_index = 2;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::PoolingType type;
  // TODO(jie): support other pooling type
  if (node_def.op() == "MaxPool")
    type = nvinfer1::PoolingType::kMAX;
  else
    return tensorflow::errors::Unimplemented("Only supports Max pool");

  // TODO(jie): NCHW
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tf_kernel = attrs.get<std::vector<int>>("ksize");
  nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is NCHW tensor with no batch dimension.
    // 1 -> h
    // 2 -> w
    padding = CreateSamePadding(
        stride, ksize,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else if (attrs.get<string>("padding") == "VALID") {
    // No padding for valid padding here
    VLOG(2) << "No padding added for VALID padding in pool "
            << node_def.name();
    padding = {{0, 0}, {0, 0}};
  } else {
    return tensorflow::errors::Unimplemented(
        "Current MaxPool cannot support padding other than SAME");
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Padding!!!: " << padding[0].first << ", " << padding[0].second
            << ", " << padding[1].first << ", " << padding[1].second;
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
  }

  nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
      *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertActivation(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertScale(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::Unimplemented(
        "Only supports tensor op weight for now, at " + node_def.name());
  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();

  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights = inputs.at(1).weights();
  TRT_ShapedWeights empty_weights(weights.type_);

  TFAttrs attrs(node_def);

  // Transpose NHWC
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    // TODO(jie): transpose it
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
      weights, empty_weights, empty_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertConst(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  const auto& weights_tensor = node_def.attr().at("value").tensor();

  // Get trt type & shape
  TFAttrs attrs(node_def);
  const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");

  // Create shaped weights as output
  tensorflow::Tensor tensor;
  if (!tensor.FromProto(weights_tensor))
    return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
                                        node_def.name());

  TRT_ShapedWeights weights(dtype);
  if (!weights_tensor.float_val().empty()) {
    VLOG(2) << "SCALAR!!!" << node_def.name();
    nvinfer1::Dims scalar_shape;
    if (tensor.dims() > 0) {
      VLOG(2) << "Dimensions: " << tensor.dims();
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  GetTensorShape(tensor));
    } else {
      VLOG(2) << "Dimensions: " << tensor.dims();
      scalar_shape.nbDims = 1;
      scalar_shape.d[0] = 1;
      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
        scalar_shape.d[i] = 0;
        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
      }
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  scalar_shape);
    }
  } else if (!weights_tensor.tensor_content().empty()) {
    VLOG(2) << "TENSOR!!!" << node_def.name();
    const auto& content = weights_tensor.tensor_content();

    weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
    if (content.size() > 0) {
      const int dtype_size = tensorflow::DataTypeSize(dtype);
      CHECK_EQ(0, content.size() % dtype_size)
          << "Tensor content size (" << content.size()
          << ") is not a multiple of " << dtype_size;
      port::CopyToArray(
          content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
    }
  } else {
    return tensorflow::errors::Unimplemented(
        "Not supported constant type, at " + node_def.name());
  }
  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(weights));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertIdentity(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  outputs->push_back(inputs.at(0));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBinary(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2)
    return tensorflow::errors::FailedPrecondition(
        "Binary ops require two inputs, at " + node_def.name());

  if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
    return ConstantFoldBinary(ctx, node_def, inputs, outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).weights(), outputs);

  if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
                                inputs.at(0).weights(), outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
    return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).tensor(), outputs);

  return tensorflow::errors::Unknown("Binary op input error, at " +
                                     node_def.name());
}

tensorflow::Status ConvertUnary(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 1)
    return tensorflow::errors::FailedPrecondition(
        "Unary ops require single tensor input, at " + node_def.name());

  if (inputs.at(0).is_weights())
    return ConstantFoldUnary(ctx, node_def, inputs, outputs);
  else if (inputs.at(0).is_tensor())
    return tensorflow::errors::Unimplemented(
        "Unary op for tensor not supported, at " + node_def.name());

  return tensorflow::errors::Unknown("Unary op input error, at " +
                                     node_def.name());
}

tensorflow::Status ConvertReduce(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights index_list = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // TODO(jie): handle data type.
  // Index type here is done through TF type, so I can leverage their
  // EnumToDataType for my cast
  auto index_type = attrs.get<tensorflow::DataType>("Tidx");

  // Only expect to handle INT32 as attributes for now
  if (index_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
  auto index_list_data =
      static_cast<int*>(const_cast<void*>(index_list.GetValues()));

  // Hack warning: have to fall back to pool layer since reduce is not in
  // public TRT yet.
  if (nb_dims != 4)
    return tensorflow::errors::InvalidArgument(
        "TRT only supports reduce on 4 dimensional tensors, at " +
        node_def.name());
  if (index_list.count() > 2)
    return tensorflow::errors::InvalidArgument(
        "TRT cannot support reduce on more than 2 dimensions, at " +
        node_def.name());

  std::set<int> idx_set;
  // We cannot operate on the channel dimension; the permutation flag is used
  // to transpose the tensor so the reduced axis lands on a spatial dimension.
  int permuted_index = -1;
  for (int i = 0; i < index_list.count(); i++) {
    if (index_list_data[i] == 0)
      return tensorflow::errors::InvalidArgument(
          "TRT cannot reduce at batch dimension 0, at " + node_def.name());
    if (index_list_data[i] == 1) permuted_index = 1;
    idx_set.emplace(index_list_data[i]);
  }

  std::vector<int> permutation_order(nb_dims);
  nvinfer1::DimsHW pool_kernel;
  if (permuted_index == 1) {
    for (int i = 2; i < nb_dims; i++) {
      if (idx_set.count(i)) {
        permuted_index = i;
        break;
      }
    }
    for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;

    permutation_order[permuted_index] = 1;
    permutation_order[1] = permuted_index;

    // Apply permutation before extracting dimension for pool_kernel
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 permutation_order);
  }

  // Apply permutation before extracting dimension for pool_kernel
  pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
  pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;

  nvinfer1::ITensor* output_tensor;

  if (node_def.op() == "Mean") {
    nvinfer1::IPoolingLayer* layer =
        ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
                                  nvinfer1::PoolingType::kAVERAGE, pool_kernel);
    output_tensor = layer->getOutput(0);
  } else {
    return tensorflow::errors::Unimplemented(
        "Op not supported " + node_def.op() + " , at " + node_def.name());
  }
  if (permuted_index != -1) {
    // Transpose back to the original dimension order.
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPad(Converter& ctx,
                              const tensorflow::NodeDef& node_def,
                              std::vector<TRT_TensorOrWeights> const& inputs,
                              std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights pads = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // Padding type here is done through TF type
  // so I can leverage their EnumToDataType for my cast
  auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
  // TODO(jie): handle data type conversion for TRT?

  if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
    return tensorflow::errors::InvalidArgument(
        "Pad only supports explicit padding on 4 dimensional tensor, at " +
        node_def.name());

  // Only expect to handle INT32 as attributes for now
  if (padding_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented(
        "Tpaddings supports only DT_INT32");
  auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));

  std::vector<int32_t> pad_index;
  for (int i = 0; i < nb_dims; i++) {
    if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
      pad_index.push_back(i);
  }

  // No padding at all, we should exit
  if (pad_index.size() == 0) {
    outputs->push_back(inputs.at(0));
    return tensorflow::Status::OK();
  }

  // Only supports padding on at most 2 axes (GIE-2579)
  if (pad_index.size() > 2)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on > 2 axes");

  // Padding on batch dimension is not supported
  if (pad_index[0] == 0)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on batch dimension");

  // Not doing the legit thing here: ignoring padding on dimensions 1 and 3;
  // TODO(jie): implement pad as uff parser
  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
    return tensorflow::errors::Unimplemented(
        "Padding layer does not support padding on dimension 1 and 3 yet");

  bool legit_pad = true;
  nvinfer1::DimsHW pre_padding(0, 0);
  nvinfer1::DimsHW post_padding(0, 0);

  std::vector<int32_t> permuted_pad_index(pad_index);
  if (pad_index[0] == 1) {
    legit_pad = false;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 2, 1});
    permuted_pad_index[0] = 3;
  }

  for (size_t i = 0; i < pad_index.size(); i++) {
    int index = pad_index[i];
    if (permuted_pad_index[i] == 2) {
      pre_padding.h() = pad_data[index * 2];
      post_padding.h() = pad_data[index * 2 + 1];
    } else if (permuted_pad_index[i] == 3) {
      pre_padding.w() = pad_data[index * 2];
      post_padding.w() = pad_data[index * 2 + 1];
    }
  }

  nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
      *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (!legit_pad)
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});

  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

void Converter::register_op_converters() {
  // vgg_16 slim implementation
  op_registry_["Placeholder"] = ConvertPlaceholder;
  op_registry_["Conv2D"] = ConvertConv2D;
  op_registry_["Relu"] = ConvertActivation;
  op_registry_["MaxPool"] = ConvertPool;
  // This could be really handled as ConvertBinary
  op_registry_["BiasAdd"] = ConvertScale;
  op_registry_["Const"] = ConvertConst;
  // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
  // TODO(ben,jie): this is a temp hack.
  op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
  // op_registry_["AvgPool"] = ConvertPool;

  // resnet_50_v1 slim implementation
  op_registry_["Add"] = ConvertBinary;
  op_registry_["Mul"] = ConvertBinary;
  op_registry_["Sub"] = ConvertBinary;
  op_registry_["Rsqrt"] = ConvertUnary;
  op_registry_["Mean"] = ConvertReduce;
  op_registry_["Pad"] = ConvertPad;
  // TODO(ben,jie): Add more ops
}

}  // namespace

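// End-to-end conversion of one candidate subgraph: topologically sort the
// subgraph nodes, translate them into a TensorRT network through Converter,
// mark the subgraph boundary tensors as network inputs/outputs, build and
// serialize the engine, and wrap the serialized plan in a TRTEngineOp
// NodeDef.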
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>& input_inds,
    const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
    size_t max_workspace_size_bytes,
    const tensorflow::grappler::GraphProperties& graph_properties,
    tensorflow::NodeDef* trt_node) {
  // Visit nodes in reverse topological order and construct the TRT network.

  // Toposort
  std::vector<tensorflow::Node*> order_vec;
  tensorflow::GetPostOrder(graph, &order_vec);
  // Select just the subgraph
  std::list<tensorflow::Node*> order;
  for (tensorflow::Node* node : order_vec) {
    if (subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the
      // network layer by layer
      order.push_front(node);
    }
  }
  // Topological order is needed to build TRT network

  tensorflow::tensorrt::Logger trt_logger;

  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
  if (!trt_builder) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT builder object");
  }

  auto trt_network = infer_object(trt_builder->createNetwork());
  if (!trt_network) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT network object");
  }

  // Build the network
  Converter converter(trt_network.get());

  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (std::pair<int, int> const& input : input_inds) {
    int node_id = input.first;
    int output_idx = input.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    auto node_name = node->name();
    input_names.push_back(node_name);  // Insert original node name without port
    // TODO(jie): alternative :)
    if (!graph_properties.HasOutputProperties(node_name))
      return tensorflow::errors::Internal("Failed to find input node: " +
                                          node_name);

    auto op_info_vec = graph_properties.GetOutputProperties(node_name);
    if (static_cast<int>(op_info_vec.size()) < output_idx)
      return tensorflow::errors::Internal(
          "Accessing output index of: " + std::to_string(output_idx) +
          ", at node: " + node_name + " with output entry from shape_map: " +
          std::to_string(op_info_vec.size()));

    auto op_info = op_info_vec.at(output_idx);

    tensorflow::DataType tf_dtype = op_info.dtype();
    input_dtypes.push_back(tf_dtype);

    nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

    VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
            << ", at node: " << node_name
            << " with output entry from shape_map: "
            << std::to_string(op_info_vec.size());

    // TODO(ben,jie): update TRT input format/dimension
    nvinfer1::DimsCHW input_dim_pseudo_chw;
    for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;

    for (int i = 1; i < op_info.shape().dim_size(); i++) {
      VLOG(2) << "dimension: " << i
              << " , size: " << op_info.shape().dim(i).size();
      input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
    }

    // TODO(ben,jie): proper way to restore input tensor name?
    auto input_tensor_name = node_name;
    if (output_idx != 0)
      input_tensor_name = node_name + ":" + std::to_string(output_idx);

    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
        input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);

    if (!input_tensor)
      return tensorflow::errors::InvalidArgument(
          "Failed to create Input layer");
    VLOG(2) << "Input tensor name: " << input_tensor_name;

    if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
      return tensorflow::errors::AlreadyExists(
          "Input tensor already exists for op: " + input_tensor_name);
  }

  VLOG(2) << "Finished sorting";

  for (const tensorflow::Node* node : order) {
    const tensorflow::NodeDef& node_def = node->def();
    VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
    TF_RETURN_IF_ERROR(converter.convert_node(node_def));
  }

  VLOG(2) << "Finished conversion";

  // Gather output metadata
  std::vector<string> output_names;
  std::vector<tensorflow::DataType> output_dtypes;
  for (std::pair<int, int> const& output : output_inds) {
    int node_id = output.first;
    int output_idx = output.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    string op_name = node->name();
    string tensor_name = op_name;
    if (output_idx != 0)
      tensor_name = tensor_name + ":" + std::to_string(output_idx);
    VLOG(2) << "Output tensor name: " << tensor_name;
    output_names.push_back(tensor_name);
    auto tensor_or_weights = converter.get_tensor(tensor_name);
    if (!tensor_or_weights.is_tensor()) {
      return tensorflow::errors::InvalidArgument(
          "Output node is weights not tensor");
    }
    nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
    if (!tensor) {
      return tensorflow::errors::NotFound("Output tensor not found: " +
                                          tensor_name);
    }
    converter.network()->markOutput(*tensor);
    tensorflow::DataType tf_dtype = node->output_type(output_idx);
    output_dtypes.push_back(tf_dtype);
    nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
    TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
    tensor->setType(trt_dtype);
  }

  VLOG(2) << "Finished output";
  // TODO(jie): static_id is not thread safe.
  static int static_id = 0;

  // Build the engine
  trt_builder->setMaxBatchSize(max_batch_size);
  trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
  VLOG(0) << "Starting build engine " << static_id;
  // TODO(ben,jie): half2 and int8 mode support
  string engine_plan_string;
  {
    auto trt_engine =
        infer_object(trt_builder->buildCudaEngine(*converter.network()));
    VLOG(0) << "Built network";
    auto engine_plan = infer_object(trt_engine->serialize());
    VLOG(0) << "Serialized engine";
    const char* engine_plan_data =
        static_cast<const char*>(engine_plan->data());
    engine_plan_string =
        string(engine_plan_data, engine_plan_data + engine_plan->size());
  }

  VLOG(0) << "Finished engine";

  // Build the TRT op
  // TODO(sami,ben,jie): proper naming!
  tensorflow::NodeDefBuilder op_builder(
      tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
  for (size_t i = 0; i < input_names.size(); ++i) {
    int output_idx = input_inds.at(i).second;
    // We wired up the input here already; it is redundant to do it again in
    // ConvertSubGraphToTensorRT (convert_graph.cc).
    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
        input_names.at(i), output_idx, input_dtypes.at(i));
    income_edges.push_back(incoming_edge);
  }
  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
      income_edges);
  op_builder.Input(input_list);

  VLOG(0) << "Finished op preparation";

  auto status = op_builder.Attr("serialized_engine", engine_plan_string)
                    .Attr("input_nodes", input_names)
                    .Attr("output_nodes", output_names)
                    .Attr("OutT", output_dtypes)
                    .Finalize(trt_node);

  VLOG(0) << status.ToString() << " finished op building";

  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA