• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 
17 #include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/strings/str_format.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/cc/framework/ops.h"
24 #include "tensorflow/cc/framework/scope.h"
25 #include "tensorflow/cc/ops/array_ops.h"
26 #include "tensorflow/cc/ops/const_op.h"
27 #include "tensorflow/cc/ops/linalg_ops.h"
28 #include "tensorflow/cc/ops/math_ops.h"
29 #include "tensorflow/cc/ops/nn_ops.h"
30 #include "tensorflow/compiler/jit/shape_inference.h"
31 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
32 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
33 #include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"
34 #include "tensorflow/core/framework/tensor_testutil.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/platform/status_matchers.h"
38 #include "tensorflow/core/protobuf/error_codes.pb.h"
39 
40 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
41 
42 namespace tensorflow {
43 namespace tensorrt {
44 namespace convert {
45 
46 namespace ops = ::tensorflow::ops;
47 using ::tensorflow::testing::StatusIs;
48 
49 // This anonymous namespace contains helper functions for instantiating small
50 // TF building blocks. These are used below to construct specific graph
51 // patterns which test end-to-end conversion of the TF graph to an
52 // explicit-precision enabled TensorRT network.
53 namespace {
54 
// Selects which ops, if any, AddConv2D appends after the convolution (and the
// optional bias add).
enum class ConvEpilogueType {
  kNone,           // No epilogue.
  kReLU,           // ReLU only.
  kBatchNorm,      // Batch normalization only.
  kReLUBatchnorm,  // ReLU followed by batch normalization.
  kBatchnormReLU   // Batch normalization followed by ReLU.
};

// Streams a human-readable name for `epilogue`. Values that do not match any
// enumerator are printed as "Unknown(<integer value>)".
std::ostream& operator<<(std::ostream& os, ConvEpilogueType epilogue) {
  // Deliberately no `default:` label, so the compiler can still warn when a
  // new enumerator is added without a matching case.
  switch (epilogue) {
    case ConvEpilogueType::kNone:
      return os << "None";
    case ConvEpilogueType::kReLU:
      return os << "ReLU only";
    case ConvEpilogueType::kBatchNorm:
      return os << "BatchNorm Only";
    case ConvEpilogueType::kReLUBatchnorm:
      return os << "ReLU+Batchnorm";
    case ConvEpilogueType::kBatchnormReLU:
      return os << "BatchNorm+ReLU";
  }
  // Only reached for out-of-range values. Without this return, control would
  // flow off the end of a value-returning function, which is undefined
  // behavior.
  return os << "Unknown(" << static_cast<int>(epilogue) << ")";
}

// Returns the text that operator<< produces for `epilogue`.
std::string DebugString(ConvEpilogueType epilogue) {
  std::ostringstream ss;
  ss << epilogue;
  return ss.str();
}
83 
84 // Adds a 2D 3x3, single channel input with specified data_format. data_format
85 // must be NHWC,NCHW or NHW.
AddInput(Scope scope,int input_idx,const std::string data_format,std::array<int,3> size_chw={1, 3, 3})86 ops::Placeholder AddInput(Scope scope, int input_idx,
87                           const std::string data_format,
88                           std::array<int, 3> size_chw = {1, 3, 3}) {
89   PartialTensorShape input_shape;
90   if (data_format == "NCHW") {
91     input_shape =
92         PartialTensorShape({1, size_chw[0], size_chw[1], size_chw[2]});
93   } else if (data_format == "NHWC") {
94     input_shape =
95         PartialTensorShape({1, size_chw[1], size_chw[2], size_chw[0]});
96   } else if (data_format == "NHW") {
97     input_shape = PartialTensorShape({1, size_chw[1], size_chw[2]});
98   } else {
99     LOG(FATAL) << "Unknown input shape type " << data_format;
100   }
101   auto input_attrs = ops::Placeholder::Attrs().Shape(input_shape);
102   return ops::Placeholder(scope.WithOpName(absl::StrCat("input_", input_idx)),
103                           DT_FLOAT, input_attrs);
104 }
105 
106 // Adds QDQ op with min = -1.0f, max = 1.0f.
AddQDQV2(Scope scope,Input input)107 Output AddQDQV2(Scope scope, Input input) {
108   // Create scaling factors.
109   auto input_min =
110       ops::Const<float>(scope.WithOpName("in_min"), -1.0f, TensorShape{});
111   auto input_max =
112       ops::Const<float>(scope.WithOpName("in_max"), 1.0f, TensorShape{});
113   return ops::QuantizeAndDequantizeV2(scope.WithOpName("qdq"), input, input_min,
114                                       input_max);
115 }
116 
AddOutput(Scope scope,Output input,int idx,bool add_qdq)117 Output AddOutput(Scope scope, Output input, int idx, bool add_qdq) {
118   Output out = input;
119   if (add_qdq) {
120     out = AddQDQV2(scope, input);
121   }
122   return ops::Identity(scope.WithOpName(StrCat("output_", idx)), out);
123 }
124 
125 // Adds a 3x3x1x1 Conv2D op and optional bias weights, followed by ReLU
126 // activation. Puts QDQ between (weights, op). Puts QDQ between (input, op)
127 // when qdq_on_output=false. Otherwise, puts QDQ between (op, output).
AddConv2D(Scope scope,Input input,int in_channels,int out_channels,std::array<int,2> filter_size={1, 1},std::array<int,2> stride={1, 1},const std::string & data_format="NCHW",bool with_bias=true,ConvEpilogueType epilogue=ConvEpilogueType::kBatchnormReLU,bool qdq_on_output=false)128 Output AddConv2D(Scope scope, Input input, int in_channels, int out_channels,
129                  std::array<int, 2> filter_size = {1, 1},
130                  std::array<int, 2> stride = {1, 1},
131                  const std::string& data_format = "NCHW", bool with_bias = true,
132                  ConvEpilogueType epilogue = ConvEpilogueType::kBatchnormReLU,
133                  bool qdq_on_output = false) {
134   // Create 3x3 non-quantized weights weights.
135   auto weights_const = ops::Const(
136       scope.WithOpName("weights"), 1.0f,
137       TensorShape({filter_size[0], filter_size[1], in_channels, out_channels}));
138 
139   // Add QDQ to input if we don't add QDQ to output.
140   auto conv_input =
141       !qdq_on_output ? AddQDQV2(scope.WithOpName("qdq_input"), input) : input;
142 
143   std::array<int, 4> strides =
144       data_format == "NCHW" ? std::array<int, 4>{1, 1, stride[0], stride[1]}
145                             : std::array<int, 4>{1, stride[0], stride[1], 1};
146   Output result = ops::Conv2D(
147       scope.WithOpName("conv2d"), conv_input, AddQDQV2(scope, weights_const),
148       /*strides=*/{1, 1, 1, 1},
149       /*padding=*/"SAME", ops::Conv2D::Attrs().DataFormat(data_format));
150 
151   if (with_bias) {
152     auto bias_const = ops::Const(scope.WithOpName("bias_weights"), 1.0f,
153                                  TensorShape({
154                                      out_channels,
155                                  }));
156     result = ops::BiasAdd(scope.WithOpName("bias"), result, bias_const,
157                           ops::BiasAdd::Attrs().DataFormat(data_format));
158   }
159 
160   auto add_bn = [scope, data_format](Input input,
__anon4f49c2cd0202(Input input, const int channels) 161                                      const int channels) -> Output {
162     TensorShape constant_shape = TensorShape({channels});
163     auto bn_scale =
164         ops::Const(scope.WithOpName("bn_scale"), 1.0f, constant_shape);
165     auto bn_offset =
166         ops::Const(scope.WithOpName("bn_offset"), 1.0f, constant_shape);
167     auto bn_mean =
168         ops::Const(scope.WithOpName("bn_mean"), 0.1f, TensorShape({channels}));
169     auto bn_var =
170         ops::Const(scope.WithOpName("bn_var"), 1.0f, TensorShape({channels}));
171     Input conv_bn_input = IS_TRT_VERSION_GE(8, 0, 1, 0)
172                               ? input
173                               : AddQDQV2(scope.WithOpName("qdq_input"), input);
174     return ops::FusedBatchNormV3(
175                scope.WithOpName("bn"), conv_bn_input, bn_scale, bn_offset,
176                bn_mean, bn_var,
177                ops::FusedBatchNormV3::Attrs().IsTraining(false).DataFormat(
178                    data_format))
179         .y;
180   };
181 
182   switch (epilogue) {
183     case ConvEpilogueType::kBatchNorm: {
184       result = add_bn(result, out_channels);
185       break;
186     }
187     case ConvEpilogueType::kReLU: {
188       result = ops::Relu(scope.WithOpName("relu"), result);
189       break;
190     }
191     case ConvEpilogueType::kReLUBatchnorm: {
192       result = ops::Relu(scope.WithOpName("relu"), result);
193       result = add_bn(result, out_channels);
194       break;
195     }
196     case ConvEpilogueType::kBatchnormReLU: {
197       result = add_bn(result, out_channels);
198       result = ops::Relu(scope.WithOpName("relu"), result);
199       break;
200     }
201     case ConvEpilogueType::kNone:
202       break;
203   }
204 
205   if (qdq_on_output) {
206     result = AddQDQV2(scope.WithOpName("qdq_out"), result);
207   }
208   return result;
209 }
210 
211 // Adds a batch matrix multiplication V2 operation, which commonly appears in
212 // fully connected layers. Puts QDQ between (input, op) as well as between
213 // (weights, op).
AddMatMul(Scope scope,const std::string & name,Input input)214 ops::BatchMatMulV2 AddMatMul(Scope scope, const std::string& name,
215                              Input input) {
216   // Add QDQ to input.
217   auto input_qdq = AddQDQV2(scope, input);
218 
219   // Add 3x3 weights with QDQ.
220   auto weights_const =
221       ops::Const(scope.WithOpName(name + "_weights"),
222                  {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
223                  TensorShape({3, 3}));
224   auto weights_qdq = AddQDQV2(scope.WithOpName("weights_qdq"), weights_const);
225   return ops::BatchMatMulV2(scope.WithOpName(name), input_qdq, weights_qdq);
226 }
227 }  // namespace
228 
229 struct QDQTestOptions {
230   bool conv_has_bias{true};
231 
232   // TRT7 may have issues with optimizing redundant transpose operations between
233   // QDQ and Op introduced by TF-TRT when format is not "NCHW". This allows to
234   // test both cases as well as WAR feasibility.
235   std::string data_format{"NCHW"};
236 
237   // Tests whether placing QDQ on outputs rather than inputs is handled
238   // correctly.
239   bool qdq_on_output{false};
240 
241   // Option for testing whether TRT build succeeds without a final QDQ before
242   // the output.
243   bool final_qdq{true};
244 
245   // Whether to add activations (relu) to conv operations
246   ConvEpilogueType conv_epilogue;
247 
248   // TF-TRT API Options
249   TfTrtConversionParams conversion_params{};
250 };
251 
operator <<(std::ostream & os,const QDQTestOptions opts)252 std::ostream& operator<<(std::ostream& os, const QDQTestOptions opts) {
253   return os << absl::StrCat(
254              "QDQTestOptions(conv_has_bias=",
255              static_cast<int>(opts.conv_has_bias),
256              ", qdq_on_output=", static_cast<int>(opts.qdq_on_output),
257              ", data_format=", opts.data_format,
258              ", conv_epilogue=", DebugString(opts.conv_epilogue),
259              ", final_qdq=", opts.final_qdq, ")");
260 }
261 
EnumerateQDQTestOptions()262 std::vector<QDQTestOptions> EnumerateQDQTestOptions() {
263   std::vector<QDQTestOptions> result;
264   for (const auto* data_format : {"NCHW", "NHWC"}) {
265     for (auto use_bias : {true, false}) {
266       for (auto qdq_on_output : {false, true}) {
267         // For now, always append a QDQ before output. For small single-op tests
268         // (besides QDQ), TensorRT7 sometimes has trouble.
269         for (auto final_qdq : {true, false}) {
270           for (auto conv_epilogue :
271                {ConvEpilogueType::kReLU, ConvEpilogueType::kNone,
272                 ConvEpilogueType::kBatchnormReLU}) {
273             // Currently batch norm converter only supports NHWC.
274             if (data_format == "NHWC" &&
275                 (conv_epilogue == ConvEpilogueType::kBatchnormReLU ||
276                  conv_epilogue == ConvEpilogueType::kBatchNorm ||
277                  conv_epilogue == ConvEpilogueType::kBatchnormReLU)) {
278               continue;
279             }
280             QDQTestOptions opts{};
281             opts.conv_has_bias = use_bias;
282             opts.data_format = data_format;
283             opts.qdq_on_output = qdq_on_output;
284             opts.final_qdq = final_qdq;
285             opts.conv_epilogue = conv_epilogue;
286             result.push_back(opts);
287           }
288         }
289       }
290     }
291   }
292   return result;
293 }
294 
295 // This class is a test fixture for running graph conversion and evaluating
296 // numerical results.
297 class QDQExplicitTest : public ::testing::Test,
298                         public ::testing::WithParamInterface<QDQTestOptions> {
299  public:
GetShape(const std::string & name,const GraphShapeInfo & shapes)300   static StatusOr<PartialTensorShape> GetShape(const std::string& name,
301                                                const GraphShapeInfo& shapes) {
302     TRT_ENSURE(shapes.find(name) != shapes.end());
303     TRT_ENSURE(shapes.at(name).size() == 1);
304     return shapes.at(name)[0].shape;
305   }
306 
GetModel(const GraphDef & graph_def,const std::vector<const NodeDef * > & inputs,const std::vector<const NodeDef * > & outputs,const GraphShapeInfo & shapes)307   StatusOr<MetaGraphDef> GetModel(const GraphDef& graph_def,
308                                   const std::vector<const NodeDef*>& inputs,
309                                   const std::vector<const NodeDef*>& outputs,
310                                   const GraphShapeInfo& shapes) {
311     TRT_ENSURE(!inputs.empty());
312     TRT_ENSURE(!outputs.empty());
313 
314     MetaGraphDef out;
315     out.mutable_graph_def()->CopyFrom(graph_def);
316 
317     SignatureDef signature_def;
318     auto& mutable_inputs = *signature_def.mutable_inputs();
319     for (int i = 0; i < inputs.size(); i++) {
320       std::string input_name = inputs[i]->name();
321       auto& input = mutable_inputs[input_name];
322       input.set_name(input_name);
323       input.set_dtype(DT_FLOAT);
324       TRT_ENSURE(shapes.find(input_name) != shapes.end());
325       TRT_ENSURE(shapes.at(input_name).size() == 1);
326       PartialTensorShape input_shape = shapes.at(input_name)[0].shape;
327       input_shape.AsProto(input.mutable_tensor_shape());
328     }
329 
330     auto& mutable_outputs = *signature_def.mutable_outputs();
331     for (int i = 0; i < outputs.size(); i++) {
332       std::string output_name = outputs[i]->name();
333       auto& output = mutable_outputs[output_name];
334       output.set_name(output_name);
335       output.set_dtype(DT_FLOAT);
336       TRT_ENSURE(shapes.find(output_name) != shapes.end());
337       TRT_ENSURE(shapes.at(output_name).size() == 1);
338       PartialTensorShape output_shape = shapes.at(output_name)[0].shape;
339       output_shape.AsProto(output.mutable_tensor_shape());
340     }
341 
342     (*out.mutable_signature_def())["serving_default"] = signature_def;
343     return out;
344   }
345 
346   // Confirms that we have a TRT node with the correct attributes.
CheckTrtNode(const GraphDef & converted_graph_def)347   static Status CheckTrtNode(const GraphDef& converted_graph_def) {
348     int n_trt_ops = 0;
349     string op_name{"TRTEngineOp"};
350     for (const auto& node : converted_graph_def.node()) {
351       if (op_name == node.op()) {
352         n_trt_ops++;
353         const auto& attr = node.attr();
354         TRT_ENSURE(attr.at("static_engine").b());
355         VLOG(2) << "Found serialized segment with size "
356                 << attr.at("serialized_segment").s().size();
357         TRT_ENSURE(!attr.at("serialized_segment").s().empty());
358       }
359     }
360     TRT_ENSURE(n_trt_ops == 1);
361     return Status::OK();
362   }
363 
ConvertAndRun(Scope * scope)364   Status ConvertAndRun(Scope* scope) {
365     std::vector<const NodeDef*> inputs;
366     std::vector<const NodeDef*> outputs;
367 
368     GraphDef gdef;
369     scope->ToGraphDef(&gdef);
370 
371     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
372     TF_RETURN_IF_ERROR(scope->ToGraph(graph.get()));
373 
374     GraphShapeInfo shape_info;
375     TF_RETURN_IF_ERROR(InferShapes(graph.get(), /*arg_shapes=*/{},
376                                    /*fnlib_def=*/nullptr, &shape_info));
377 
378     for (const NodeDef& node : gdef.node()) {
379       if (absl::StartsWith(node.name(), "input_")) {
380         inputs.push_back(&node);
381       } else if (absl::StartsWith(node.name(), "output_")) {
382         outputs.push_back(&node);
383       }
384     }
385 
386     StatusOr<MetaGraphDef> meta_graph_def =
387         GetModel(gdef, inputs, outputs, shape_info);
388     TRT_ENSURE_OK(meta_graph_def);
389 
390     // Create a list of input tensors, they will be used to build the engines.
391     std::vector<Tensor> input_tensors;
392     std::vector<std::string> input_names;
393     for (const auto& input : inputs) {
394       input_names.push_back(input->name());
395 
396       StatusOr<PartialTensorShape> input_shape =
397           GetShape(input->name(), shape_info);
398       TRT_ENSURE_OK(input_shape);
399 
400       TensorShape shape;
401       input_shape->AsTensorShape(&shape);
402       Tensor tensor(DT_FLOAT, shape);
403       test::FillIota(&tensor, 1.0f);
404       input_tensors.push_back(tensor);
405     }
406 
407     std::vector<std::string> output_names;
408     for (const auto& output : outputs) {
409       output_names.push_back(output->name());
410     }
411 
412     TfTrtConversionParams conversion_params;
413     conversion_params.allow_build_at_runtime = true;
414     conversion_params.precision_mode = TrtPrecisionMode::INT8;
415     conversion_params.use_calibration = false;
416     conversion_params.convert_to_static_engine = true;
417     TRT_ENSURE(input_names.size() == input_tensors.size());
418     StatusOr<GraphDef> converted_gdef = tensorrt::ConvertAndBuild(
419         meta_graph_def->graph_def(), input_names, output_names, {input_tensors},
420         conversion_params);
421     TRT_ENSURE_OK(converted_gdef);
422     return CheckTrtNode(*converted_gdef);
423   }
424 
425  protected:
426   TfTrtConversionParams params_;
427   TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
428 };
429 
430 class TestQDQSuite : public QDQExplicitTest {};
431 
// Helper macros that encode which parameter combinations are expected to
// build. Every multi-statement macro is wrapped in do { ... } while (0) so
// that it expands to exactly one statement and remains safe inside an unbraced
// if/else. Several of them execute `return` (directly or through GTEST_SKIP),
// so they are meant to be used inside TEST_P bodies of the fixture above.

// If QDQ was placed on outputs, expects conversion to fail with INTERNAL and
// returns from the enclosing test.
#define EXPECT_QDQ_ON_OUTPUT_FAILURE(params, scope)                    \
  do {                                                                 \
    if ((params).qdq_on_output) {                                      \
      EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::INTERNAL)); \
      return;                                                          \
    }                                                                  \
  } while (0)

// If there is no final QDQ, expects conversion to fail with INTERNAL and
// returns from the enclosing test.
#define EXPECT_NO_FINAL_QDQ_FAILURE(params, scope)                     \
  do {                                                                 \
    if (!(params).final_qdq) {                                         \
      EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::INTERNAL)); \
      return;                                                          \
    }                                                                  \
  } while (0)

// Expects conversion of the graph in `scope` to succeed.
#define EXPECT_BUILD_OK(scope) TF_EXPECT_OK(ConvertAndRun(&(scope)))

// Expectations for TensorRT versions older than 8.0.0.0. This file is only
// compiled when IS_TRT_VERSION_GE(8, 0, 0, 0) holds (see the enclosing #if),
// so the condition below is always false here.
#define POLICY_TRT7(params, scope)                 \
  do {                                             \
    if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {          \
      EXPECT_QDQ_ON_OUTPUT_FAILURE(params, scope); \
      EXPECT_NO_FINAL_QDQ_FAILURE(params, scope);  \
      EXPECT_BUILD_OK(scope);                      \
    }                                              \
  } while (0)

// Expectations for TensorRT 8.0.0.0 and newer: a batch norm epilogue combined
// with NHWC must be rejected as UNIMPLEMENTED (the enclosing test then
// returns); every other combination must build.
#define POLICY_TRT8(params, scope)                                            \
  do {                                                                        \
    if (IS_TRT_VERSION_GE(8, 0, 0, 0)) {                                      \
      if (((params).conv_epilogue == ConvEpilogueType::kBatchNorm ||          \
           (params).conv_epilogue == ConvEpilogueType::kBatchnormReLU ||      \
           (params).conv_epilogue == ConvEpilogueType::kReLUBatchnorm) &&     \
          (params).data_format == "NHWC") {                                   \
        EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::UNIMPLEMENTED)); \
        return;                                                               \
      }                                                                       \
      EXPECT_BUILD_OK(scope);                                                 \
    }                                                                         \
  } while (0)

// Skips the enclosing test when running against a TensorRT version older than
// 8.0.0.0 and `x` is true.
#define SKIP_TRT7(x)                             \
  do {                                           \
    if (!IS_TRT_VERSION_GE(8, 0, 0, 0) && (x)) { \
      GTEST_SKIP();                              \
    }                                            \
  } while (0)
468 
469 // Tests single convolution operation conversion.
TEST_P(TestQDQSuite,TestConv2DBasic)470 TEST_P(TestQDQSuite, TestConv2DBasic) {
471   SKIP_TRT7(GetParam().qdq_on_output);
472   SKIP_TRT7(GetParam().data_format != "NCHW");
473   SKIP_TRT7(!GetParam().final_qdq);
474 
475   Scope scope = Scope::NewRootScope();
476   auto input = AddInput(scope, 0, GetParam().data_format, {3, 28, 28});
477 
478   Output out = input;
479   const int num_conv = 1;
480   std::array<int, 2> in_channels = {3, 16};
481   std::array<int, 2> out_channels = {16, 32};
482   for (int i = 0; i < num_conv; i++) {
483     out = AddConv2D(scope.WithOpName(absl::StrCat("conv_", i)), out,
484                     in_channels[i], out_channels[i], /*filter_size=*/{3, 3},
485                     /*stride=*/{1, 1}, GetParam().data_format,
486                     GetParam().conv_has_bias, GetParam().conv_epilogue,
487                     GetParam().qdq_on_output);
488   }
489   out = AddOutput(scope, out, 0, GetParam().final_qdq);
490   POLICY_TRT7(GetParam(), scope);
491   POLICY_TRT8(GetParam(), scope);
492 }
493 
494 // Tests single convolution operation conversion.
TEST_P(TestQDQSuite,TestMatMulBasic)495 TEST_P(TestQDQSuite, TestMatMulBasic) {
496   // Some param's don't apply, so pick one combination and skip otherwise.
497   if (GetParam().data_format != "NCHW" || !GetParam().conv_has_bias ||
498       GetParam().qdq_on_output ||
499       GetParam().conv_epilogue != ConvEpilogueType::kReLU) {
500     GTEST_SKIP();
501   }
502   Scope scope = Scope::NewRootScope();
503   auto input = AddInput(scope, 0, "NHW");
504   auto matmul_op = AddMatMul(scope, "matmul", input);
505   auto out = AddOutput(scope, matmul_op, 0, GetParam().final_qdq);
506 
507   TF_EXPECT_OK(ConvertAndRun(&scope));
508 }
509 
510 // A single input goes through two different Conv2D. Outputs of Conv2D are
511 // added together, with QQQ on both branches of ADD.
TEST_P(TestQDQSuite,AddBothBranchesQDQConvSingleInput)512 TEST_P(TestQDQSuite, AddBothBranchesQDQConvSingleInput) {
513   SKIP_TRT7(!GetParam().final_qdq);
514   SKIP_TRT7(GetParam().data_format != "NCHW");
515 
516   Scope scope = Scope::NewRootScope();
517   auto input1 = AddInput(scope, 0, GetParam().data_format,
518                          /*size_chw=*/{3, 28, 28});
519 
520   auto conv1 =
521       AddConv2D(scope, input1, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/{1, 1},
522                 GetParam().data_format, GetParam().conv_has_bias,
523                 GetParam().conv_epilogue, GetParam().qdq_on_output);
524 
525   auto conv2 =
526       AddConv2D(scope, input1, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/
527                 {1, 1}, GetParam().data_format, GetParam().conv_has_bias,
528                 GetParam().conv_epilogue, GetParam().qdq_on_output);
529 
530   // In the case of "qdq on output", we don't need to add QDQ.
531   auto add =
532       ops::Add(scope.WithOpName("add"),
533                !GetParam().qdq_on_output ? AddQDQV2(scope, conv1) : conv1,
534                !GetParam().qdq_on_output ? AddQDQV2(scope, conv2) : conv2);
535 
536   auto conv3 =
537       AddConv2D(scope.WithOpName("conv3"), conv2, 16, 16, {1, 1}, {1, 1},
538                 GetParam().data_format, GetParam().conv_has_bias,
539                 GetParam().conv_epilogue, GetParam().qdq_on_output);
540 
541   auto out =
542       AddOutput(scope.WithOpName("output"), conv3, 0, GetParam().final_qdq);
543 
544   POLICY_TRT7(GetParam(), scope);
545   POLICY_TRT8(GetParam(), scope);
546 }
547 
548 // Tests adding a single tensor to itself, with QQQ on both branches of ADD.
TEST_P(TestQDQSuite,AddBothBranchesQDQMultipleInput)549 TEST_P(TestQDQSuite, AddBothBranchesQDQMultipleInput) {
550   // TRT7 QDQ optimizer makes single-input restriction.
551   SKIP_TRT7(true);
552 
553   Scope scope = Scope::NewRootScope();
554   auto input1 = AddInput(scope, 0, GetParam().data_format);
555   auto input2 = AddInput(scope, 1, GetParam().data_format);
556   auto add =
557       ops::Add(scope.WithOpName("add"),
558                !GetParam().qdq_on_output ? AddQDQV2(scope, input1) : input1,
559                !GetParam().qdq_on_output ? AddQDQV2(scope, input2) : input2);
560   auto output = AddOutput(scope, add, 0, true);
561   TF_EXPECT_OK(ConvertAndRun(&scope));
562 }
563 
564 // Tests Conv-MaxPool combination
TEST_P(TestQDQSuite,TestConvMaxpool)565 TEST_P(TestQDQSuite, TestConvMaxpool) {
566   SKIP_TRT7(!GetParam().final_qdq);
567   SKIP_TRT7(GetParam().data_format != "NCHW");
568 
569   Scope scope = Scope::NewRootScope();
570   auto input = AddInput(scope, 0, GetParam().data_format,
571                         /*size_chw=*/{3, 28, 28});
572   auto conv1 =
573       AddConv2D(scope, input, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/{1, 1},
574                 GetParam().data_format, GetParam().conv_has_bias,
575                 GetParam().conv_epilogue, GetParam().qdq_on_output);
576   ops::MaxPool maxpool =
577       ops::MaxPool(scope.WithOpName("maxpool"),
578                    AddQDQV2(scope.WithOpName("mp_qdq_in"), conv1), {1, 1, 1, 1},
579                    {1, 1, 1, 1}, "SAME",
580                    ops::MaxPool::Attrs().DataFormat(GetParam().data_format));
581   auto output =
582       AddOutput(scope.WithOpName("output"), maxpool, 0, GetParam().final_qdq);
583   POLICY_TRT7(GetParam(), scope);
584   POLICY_TRT8(GetParam(), scope);
585 }
586 
587 // Tests QDQ(Conv(QDQ(MaxPool(Conv(QDQ(x))))))
TEST_P(TestQDQSuite,TestConvMaxpoolConv)588 TEST_P(TestQDQSuite, TestConvMaxpoolConv) {
589   SKIP_TRT7(!GetParam().final_qdq);
590   SKIP_TRT7(GetParam().data_format != "NCHW");
591 
592   Scope scope = Scope::NewRootScope();
593   auto input = AddInput(scope, 0, GetParam().data_format,
594                         /*size_chw=*/{3, 28, 28});
595   auto conv1 =
596       AddConv2D(scope, input, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/{1, 1},
597                 GetParam().data_format, GetParam().conv_has_bias,
598                 GetParam().conv_epilogue, GetParam().qdq_on_output);
599   ops::MaxPool maxpool =
600       ops::MaxPool(scope.WithOpName("maxpool"),
601                    AddQDQV2(scope.WithOpName("mp_qdq_in"), conv1), {1, 1, 1, 1},
602                    {1, 1, 1, 1}, "SAME",
603                    ops::MaxPool::Attrs().DataFormat(GetParam().data_format));
604   auto conv2 = AddConv2D(scope, maxpool, 16, 16, {3, 3}, {1, 1},
605                          GetParam().data_format, GetParam().conv_has_bias,
606                          GetParam().conv_epilogue, GetParam().qdq_on_output);
607   auto output =
608       AddOutput(scope.WithOpName("out"), conv2, 0, GetParam().final_qdq);
609   POLICY_TRT7(GetParam(), scope);
610   POLICY_TRT8(GetParam(), scope);
611 }
612 
613 INSTANTIATE_TEST_SUITE_P(TestQDQSuiteInst, TestQDQSuite,
614                          ::testing::ValuesIn(EnumerateQDQTestOptions()));
615 
616 }  // namespace convert
617 }  // namespace tensorrt
618 }  // namespace tensorflow
619 
620 #endif  // IS_TRT_VERSION_GE(8, 0, 0, 0)
621 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
622