/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"

#include <array>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/linalg_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

#if IS_TRT_VERSION_GE(8, 0, 0, 0)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace ops = ::tensorflow::ops;
using ::tensorflow::testing::StatusIs;

// This anonymous namespace contains helper functions for instantiating small
// TF building blocks. These are used below to construct specific graph
// patterns which test end-to-end conversion of the TF graph to an
// explicit-precision-enabled TensorRT network.
namespace {

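// The operations optionally appended after the Conv2D op in the test graphs
// constructed below.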
enum class ConvEpilogueType {
  kNone,
  kReLU,
  kBatchNorm,
  kReLUBatchnorm,
  kBatchnormReLU
};

std::ostream& operator<<(std::ostream& os, ConvEpilogueType epilogue) {
  switch (epilogue) {
    case ConvEpilogueType::kNone:
      return os << "None";
    case ConvEpilogueType::kReLU:
      return os << "ReLU only";
    case ConvEpilogueType::kBatchNorm:
      return os << "BatchNorm Only";
    case ConvEpilogueType::kReLUBatchnorm:
      return os << "ReLU+Batchnorm";
    case ConvEpilogueType::kBatchnormReLU:
      return os << "BatchNorm+ReLU";
  }
}

std::string DebugString(ConvEpilogueType epilogue) {
  std::stringstream ss;
  ss << epilogue;
  return ss.str();
}

// Adds a 2D single-channel 3x3 (by default) input with the specified
// data_format. data_format must be "NHWC", "NCHW", or "NHW".
ops::Placeholder AddInput(Scope scope, int input_idx,
                          const std::string& data_format,
                          std::array<int, 3> size_chw = {1, 3, 3}) {
  PartialTensorShape input_shape;
  if (data_format == "NCHW") {
    input_shape =
        PartialTensorShape({1, size_chw[0], size_chw[1], size_chw[2]});
  } else if (data_format == "NHWC") {
    input_shape =
        PartialTensorShape({1, size_chw[1], size_chw[2], size_chw[0]});
  } else if (data_format == "NHW") {
    input_shape = PartialTensorShape({1, size_chw[1], size_chw[2]});
  } else {
    LOG(FATAL) << "Unknown input shape type " << data_format;
  }
  auto input_attrs = ops::Placeholder::Attrs().Shape(input_shape);
  return ops::Placeholder(scope.WithOpName(absl::StrCat("input_", input_idx)),
                          DT_FLOAT, input_attrs);
}

// Adds QDQ op with min = -1.0f, max = 1.0f.
Output AddQDQV2(Scope scope, Input input) {
  // Create scaling factors.
  auto input_min =
      ops::Const<float>(scope.WithOpName("in_min"), -1.0f, TensorShape{});
  auto input_max =
      ops::Const<float>(scope.WithOpName("in_max"), 1.0f, TensorShape{});
  return ops::QuantizeAndDequantizeV2(scope.WithOpName("qdq"), input, input_min,
                                      input_max);
}

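// Adds an Identity output op named "output_<idx>", optionally preceded by a
// QDQ op.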
Output AddOutput(Scope scope, Output input, int idx, bool add_qdq) {
  Output out = input;
  if (add_qdq) {
    out = AddQDQV2(scope, input);
  }
  return ops::Identity(scope.WithOpName(absl::StrCat("output_", idx)), out);
}

// Adds a Conv2D op with constant weights of shape
// [filter_size, in_channels, out_channels], an optional bias, and an optional
// epilogue (BatchNorm and/or ReLU). Puts QDQ between (weights, op). Puts QDQ
// between (input, op) when qdq_on_output=false. Otherwise, puts QDQ between
// (op, output).
Output AddConv2D(Scope scope, Input input, int in_channels, int out_channels,
                 std::array<int, 2> filter_size = {1, 1},
                 std::array<int, 2> stride = {1, 1},
                 const std::string& data_format = "NCHW", bool with_bias = true,
                 ConvEpilogueType epilogue = ConvEpilogueType::kBatchnormReLU,
                 bool qdq_on_output = false) {
  // Create non-quantized weights.
  auto weights_const = ops::Const(
      scope.WithOpName("weights"), 1.0f,
      TensorShape({filter_size[0], filter_size[1], in_channels, out_channels}));

  // Add QDQ to the input if we don't add QDQ to the output.
  auto conv_input =
      !qdq_on_output ? AddQDQV2(scope.WithOpName("qdq_input"), input) : input;

  std::array<int, 4> strides =
      data_format == "NCHW" ? std::array<int, 4>{1, 1, stride[0], stride[1]}
                            : std::array<int, 4>{1, stride[0], stride[1], 1};
  Output result = ops::Conv2D(
      scope.WithOpName("conv2d"), conv_input, AddQDQV2(scope, weights_const),
      /*strides=*/{strides[0], strides[1], strides[2], strides[3]},
      /*padding=*/"SAME", ops::Conv2D::Attrs().DataFormat(data_format));

  if (with_bias) {
    auto bias_const = ops::Const(scope.WithOpName("bias_weights"), 1.0f,
                                 TensorShape({
                                     out_channels,
                                 }));
    result = ops::BiasAdd(scope.WithOpName("bias"), result, bias_const,
                          ops::BiasAdd::Attrs().DataFormat(data_format));
  }

  auto add_bn = [scope, data_format](Input input,
                                     const int channels) -> Output {
    TensorShape constant_shape = TensorShape({channels});
    auto bn_scale =
        ops::Const(scope.WithOpName("bn_scale"), 1.0f, constant_shape);
    auto bn_offset =
        ops::Const(scope.WithOpName("bn_offset"), 1.0f, constant_shape);
    auto bn_mean =
        ops::Const(scope.WithOpName("bn_mean"), 0.1f, TensorShape({channels}));
    auto bn_var =
        ops::Const(scope.WithOpName("bn_var"), 1.0f, TensorShape({channels}));
    Input conv_bn_input = IS_TRT_VERSION_GE(8, 0, 1, 0)
                              ? input
                              : AddQDQV2(scope.WithOpName("qdq_input"), input);
    return ops::FusedBatchNormV3(
               scope.WithOpName("bn"), conv_bn_input, bn_scale, bn_offset,
               bn_mean, bn_var,
               ops::FusedBatchNormV3::Attrs().IsTraining(false).DataFormat(
                   data_format))
        .y;
  };

  switch (epilogue) {
    case ConvEpilogueType::kBatchNorm: {
      result = add_bn(result, out_channels);
      break;
    }
    case ConvEpilogueType::kReLU: {
      result = ops::Relu(scope.WithOpName("relu"), result);
      break;
    }
    case ConvEpilogueType::kReLUBatchnorm: {
      result = ops::Relu(scope.WithOpName("relu"), result);
      result = add_bn(result, out_channels);
      break;
    }
    case ConvEpilogueType::kBatchnormReLU: {
      result = add_bn(result, out_channels);
      result = ops::Relu(scope.WithOpName("relu"), result);
      break;
    }
    case ConvEpilogueType::kNone:
      break;
  }

  if (qdq_on_output) {
    result = AddQDQV2(scope.WithOpName("qdq_out"), result);
  }
  return result;
}

// Adds a batch matrix multiplication V2 operation, which commonly appears in
// fully connected layers. Puts QDQ between (input, op) as well as between
// (weights, op).
ops::BatchMatMulV2 AddMatMul(Scope scope, const std::string& name,
                             Input input) {
  // Add QDQ to input.
  auto input_qdq = AddQDQV2(scope, input);

  // Add 3x3 weights with QDQ.
  auto weights_const =
      ops::Const(scope.WithOpName(name + "_weights"),
                 {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
                 TensorShape({3, 3}));
  auto weights_qdq = AddQDQV2(scope.WithOpName("weights_qdq"), weights_const);
  return ops::BatchMatMulV2(scope.WithOpName(name), input_qdq, weights_qdq);
}

}  // namespace

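// Options controlling the graph pattern and conversion settings used by the
// parameterized tests below.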
struct QDQTestOptions {
  bool conv_has_bias{true};

  // TRT7 may have issues optimizing the redundant transpose operations that
  // TF-TRT introduces between QDQ and the op when the format is not "NCHW".
  // This allows testing both cases as well as the feasibility of a workaround.
  std::string data_format{"NCHW"};

  // Tests whether placing QDQ on outputs rather than inputs is handled
  // correctly.
  bool qdq_on_output{false};

  // Option for testing whether the TRT build succeeds without a final QDQ
  // before the output.
  bool final_qdq{true};

  // Which epilogue (BatchNorm and/or ReLU activation) to add to conv
  // operations.
  ConvEpilogueType conv_epilogue;

  // TF-TRT API options.
  TfTrtConversionParams conversion_params{};
};

std::ostream& operator<<(std::ostream& os, const QDQTestOptions& opts) {
  return os << absl::StrCat(
             "QDQTestOptions(conv_has_bias=",
             static_cast<int>(opts.conv_has_bias),
             ", qdq_on_output=", static_cast<int>(opts.qdq_on_output),
             ", data_format=", opts.data_format,
             ", conv_epilogue=", DebugString(opts.conv_epilogue),
             ", final_qdq=", static_cast<int>(opts.final_qdq), ")");
}

std::vector<QDQTestOptions> EnumerateQDQTestOptions() {
  std::vector<QDQTestOptions> result;
  for (const std::string data_format : {"NCHW", "NHWC"}) {
    for (auto use_bias : {true, false}) {
      for (auto qdq_on_output : {false, true}) {
        // TensorRT7 sometimes has trouble with small single-op (besides QDQ)
        // tests when no final QDQ precedes the output, so test both cases.
        for (auto final_qdq : {true, false}) {
          for (auto conv_epilogue :
               {ConvEpilogueType::kReLU, ConvEpilogueType::kNone,
                ConvEpilogueType::kBatchnormReLU}) {
            // Currently the batch norm converter only supports NCHW, so skip
            // batch norm epilogues in NHWC format.
            if (data_format == "NHWC" &&
                (conv_epilogue == ConvEpilogueType::kBatchnormReLU ||
                 conv_epilogue == ConvEpilogueType::kBatchNorm ||
                 conv_epilogue == ConvEpilogueType::kReLUBatchnorm)) {
              continue;
            }
            QDQTestOptions opts{};
            opts.conv_has_bias = use_bias;
            opts.data_format = data_format;
            opts.qdq_on_output = qdq_on_output;
            opts.final_qdq = final_qdq;
            opts.conv_epilogue = conv_epilogue;
            result.push_back(opts);
          }
        }
      }
    }
  }
  return result;
}

// This class is a test fixture for running graph conversion and evaluating
// numerical results.
class QDQExplicitTest : public ::testing::Test,
                        public ::testing::WithParamInterface<QDQTestOptions> {
 public:
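  // Returns the inferred shape of the single output of the node `name`.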
  static StatusOr<PartialTensorShape> GetShape(const std::string& name,
                                               const GraphShapeInfo& shapes) {
    TRT_ENSURE(shapes.find(name) != shapes.end());
    TRT_ENSURE(shapes.at(name).size() == 1);
    return shapes.at(name)[0].shape;
  }

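  // Wraps `graph_def` in a MetaGraphDef whose "serving_default" signature
  // describes the given inputs and outputs with their inferred shapes.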
  StatusOr<MetaGraphDef> GetModel(const GraphDef& graph_def,
                                  const std::vector<const NodeDef*>& inputs,
                                  const std::vector<const NodeDef*>& outputs,
                                  const GraphShapeInfo& shapes) {
    TRT_ENSURE(!inputs.empty());
    TRT_ENSURE(!outputs.empty());

    MetaGraphDef out;
    out.mutable_graph_def()->CopyFrom(graph_def);

    SignatureDef signature_def;
    auto& mutable_inputs = *signature_def.mutable_inputs();
    for (int i = 0; i < inputs.size(); i++) {
      std::string input_name = inputs[i]->name();
      auto& input = mutable_inputs[input_name];
      input.set_name(input_name);
      input.set_dtype(DT_FLOAT);
      TRT_ENSURE(shapes.find(input_name) != shapes.end());
      TRT_ENSURE(shapes.at(input_name).size() == 1);
      PartialTensorShape input_shape = shapes.at(input_name)[0].shape;
      input_shape.AsProto(input.mutable_tensor_shape());
    }

    auto& mutable_outputs = *signature_def.mutable_outputs();
    for (int i = 0; i < outputs.size(); i++) {
      std::string output_name = outputs[i]->name();
      auto& output = mutable_outputs[output_name];
      output.set_name(output_name);
      output.set_dtype(DT_FLOAT);
      TRT_ENSURE(shapes.find(output_name) != shapes.end());
      TRT_ENSURE(shapes.at(output_name).size() == 1);
      PartialTensorShape output_shape = shapes.at(output_name)[0].shape;
      output_shape.AsProto(output.mutable_tensor_shape());
    }

    (*out.mutable_signature_def())["serving_default"] = signature_def;
    return out;
  }

  // Confirms that the converted graph contains exactly one TRTEngineOp with a
  // non-empty serialized static engine.
  static Status CheckTrtNode(const GraphDef& converted_graph_def) {
    int n_trt_ops = 0;
    string op_name{"TRTEngineOp"};
    for (const auto& node : converted_graph_def.node()) {
      if (op_name == node.op()) {
        n_trt_ops++;
        const auto& attr = node.attr();
        TRT_ENSURE(attr.at("static_engine").b());
        VLOG(2) << "Found serialized segment with size "
                << attr.at("serialized_segment").s().size();
        TRT_ENSURE(!attr.at("serialized_segment").s().empty());
      }
    }
    TRT_ENSURE(n_trt_ops == 1);
    return Status::OK();
  }

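  // Converts the graph held by `scope` to a TF-TRT graph with a static INT8
  // engine (no calibration) and verifies the result with CheckTrtNode.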
  Status ConvertAndRun(Scope* scope) {
    std::vector<const NodeDef*> inputs;
    std::vector<const NodeDef*> outputs;

    GraphDef gdef;
    scope->ToGraphDef(&gdef);

    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_RETURN_IF_ERROR(scope->ToGraph(graph.get()));

    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(InferShapes(graph.get(), /*arg_shapes=*/{},
                                   /*fnlib_def=*/nullptr, &shape_info));

    for (const NodeDef& node : gdef.node()) {
      if (absl::StartsWith(node.name(), "input_")) {
        inputs.push_back(&node);
      } else if (absl::StartsWith(node.name(), "output_")) {
        outputs.push_back(&node);
      }
    }

    StatusOr<MetaGraphDef> meta_graph_def =
        GetModel(gdef, inputs, outputs, shape_info);
    TRT_ENSURE_OK(meta_graph_def);

    // Create a list of input tensors; they will be used to build the engines.
    std::vector<Tensor> input_tensors;
    std::vector<std::string> input_names;
    for (const auto& input : inputs) {
      input_names.push_back(input->name());

      StatusOr<PartialTensorShape> input_shape =
          GetShape(input->name(), shape_info);
      TRT_ENSURE_OK(input_shape);

      TensorShape shape;
      input_shape->AsTensorShape(&shape);
      Tensor tensor(DT_FLOAT, shape);
      test::FillIota(&tensor, 1.0f);
      input_tensors.push_back(tensor);
    }

    std::vector<std::string> output_names;
    for (const auto& output : outputs) {
      output_names.push_back(output->name());
    }

    TfTrtConversionParams conversion_params;
    conversion_params.allow_build_at_runtime = true;
    conversion_params.precision_mode = TrtPrecisionMode::INT8;
    conversion_params.use_calibration = false;
    conversion_params.convert_to_static_engine = true;
    TRT_ENSURE(input_names.size() == input_tensors.size());
    StatusOr<GraphDef> converted_gdef = tensorrt::ConvertAndBuild(
        meta_graph_def->graph_def(), input_names, output_names,
        {input_tensors}, conversion_params);
    TRT_ENSURE_OK(converted_gdef);
    return CheckTrtNode(*converted_gdef);
  }

 protected:
  TfTrtConversionParams params_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
};

class TestQDQSuite : public QDQExplicitTest {};

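// The macros below encode the expected conversion outcome for a given test
// parameterization and TensorRT version.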
#define EXPECT_QDQ_ON_OUTPUT_FAILURE(params, scope)                  \
  if ((params).qdq_on_output) {                                      \
    EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::INTERNAL)); \
    return;                                                          \
  }

#define EXPECT_NO_FINAL_QDQ_FAILURE(params, scope)                   \
  if (!(params).final_qdq) {                                         \
    EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::INTERNAL)); \
    return;                                                          \
  }

#define EXPECT_BUILD_OK(scope) TF_EXPECT_OK(ConvertAndRun(&(scope)))

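// On TensorRT 7, conversion is expected to fail unless QDQ nodes are placed on
// the op inputs and a final QDQ precedes the output.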
#define POLICY_TRT7(params, scope)               \
  if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {          \
    EXPECT_QDQ_ON_OUTPUT_FAILURE(params, scope); \
    EXPECT_NO_FINAL_QDQ_FAILURE(params, scope);  \
    EXPECT_BUILD_OK(scope);                      \
  }

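// On TensorRT 8, conversion is expected to succeed, except for batch norm
// epilogues in NHWC format, which the converter reports as unimplemented.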
#define POLICY_TRT8(params, scope)                                          \
  if (IS_TRT_VERSION_GE(8, 0, 0, 0)) {                                      \
    if (((params).conv_epilogue == ConvEpilogueType::kBatchNorm ||          \
         (params).conv_epilogue == ConvEpilogueType::kBatchnormReLU ||      \
         (params).conv_epilogue == ConvEpilogueType::kReLUBatchnorm) &&     \
        (params).data_format == "NHWC") {                                   \
      EXPECT_THAT(ConvertAndRun(&(scope)), StatusIs(error::UNIMPLEMENTED)); \
      return;                                                               \
    }                                                                       \
    EXPECT_BUILD_OK(scope);                                                 \
  }

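// Skips the test on TensorRT 7 when the given condition holds.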
#define SKIP_TRT7(x)                             \
  if (!IS_TRT_VERSION_GE(8, 0, 0, 0) && (x)) {   \
    GTEST_SKIP();                                \
  }

// Tests single convolution operation conversion.
TEST_P(TestQDQSuite, TestConv2DBasic) {
  SKIP_TRT7(GetParam().qdq_on_output);
  SKIP_TRT7(GetParam().data_format != "NCHW");
  SKIP_TRT7(!GetParam().final_qdq);

  Scope scope = Scope::NewRootScope();
  auto input = AddInput(scope, 0, GetParam().data_format, {3, 28, 28});

  Output out = input;
  const int num_conv = 1;
  std::array<int, 2> in_channels = {3, 16};
  std::array<int, 2> out_channels = {16, 32};
  for (int i = 0; i < num_conv; i++) {
    out = AddConv2D(scope.WithOpName(absl::StrCat("conv_", i)), out,
                    in_channels[i], out_channels[i], /*filter_size=*/{3, 3},
                    /*stride=*/{1, 1}, GetParam().data_format,
                    GetParam().conv_has_bias, GetParam().conv_epilogue,
                    GetParam().qdq_on_output);
  }
  out = AddOutput(scope, out, 0, GetParam().final_qdq);
  POLICY_TRT7(GetParam(), scope);
  POLICY_TRT8(GetParam(), scope);
}

// Tests single MatMul operation conversion.
TEST_P(TestQDQSuite, TestMatMulBasic) {
  // Some parameters don't apply here, so pick one combination and skip
  // otherwise.
  if (GetParam().data_format != "NCHW" || !GetParam().conv_has_bias ||
      GetParam().qdq_on_output ||
      GetParam().conv_epilogue != ConvEpilogueType::kReLU) {
    GTEST_SKIP();
  }
  Scope scope = Scope::NewRootScope();
  auto input = AddInput(scope, 0, "NHW");
  auto matmul_op = AddMatMul(scope, "matmul", input);
  auto out = AddOutput(scope, matmul_op, 0, GetParam().final_qdq);

  TF_EXPECT_OK(ConvertAndRun(&scope));
}

// A single input goes through two different Conv2D ops. The Conv2D outputs are
// added together, with QDQ on both branches of the Add, and the sum feeds a
// third Conv2D.
TEST_P(TestQDQSuite, AddBothBranchesQDQConvSingleInput) {
  SKIP_TRT7(!GetParam().final_qdq);
  SKIP_TRT7(GetParam().data_format != "NCHW");

  Scope scope = Scope::NewRootScope();
  auto input1 = AddInput(scope, 0, GetParam().data_format,
                         /*size_chw=*/{3, 28, 28});

  auto conv1 =
      AddConv2D(scope, input1, 3, 16, /*filter_size=*/{3, 3},
                /*stride=*/{1, 1}, GetParam().data_format,
                GetParam().conv_has_bias, GetParam().conv_epilogue,
                GetParam().qdq_on_output);

  auto conv2 =
      AddConv2D(scope, input1, 3, 16, /*filter_size=*/{3, 3},
                /*stride=*/{1, 1}, GetParam().data_format,
                GetParam().conv_has_bias, GetParam().conv_epilogue,
                GetParam().qdq_on_output);

  // In the case of "qdq on output", we don't need to add QDQ here.
  auto add =
      ops::Add(scope.WithOpName("add"),
               !GetParam().qdq_on_output ? AddQDQV2(scope, conv1) : conv1,
               !GetParam().qdq_on_output ? AddQDQV2(scope, conv2) : conv2);

  auto conv3 =
      AddConv2D(scope.WithOpName("conv3"), add, 16, 16, {1, 1}, {1, 1},
                GetParam().data_format, GetParam().conv_has_bias,
                GetParam().conv_epilogue, GetParam().qdq_on_output);

  auto out =
      AddOutput(scope.WithOpName("output"), conv3, 0, GetParam().final_qdq);

  POLICY_TRT7(GetParam(), scope);
  POLICY_TRT8(GetParam(), scope);
}

// Tests adding two separate inputs together, with QDQ on both branches of the
// Add.
TEST_P(TestQDQSuite, AddBothBranchesQDQMultipleInput) {
  // The TRT7 QDQ optimizer imposes a single-input restriction.
  SKIP_TRT7(true);

  Scope scope = Scope::NewRootScope();
  auto input1 = AddInput(scope, 0, GetParam().data_format);
  auto input2 = AddInput(scope, 1, GetParam().data_format);
  auto add =
      ops::Add(scope.WithOpName("add"),
               !GetParam().qdq_on_output ? AddQDQV2(scope, input1) : input1,
               !GetParam().qdq_on_output ? AddQDQV2(scope, input2) : input2);
  auto output = AddOutput(scope, add, 0, true);
  TF_EXPECT_OK(ConvertAndRun(&scope));
}

// Tests Conv-MaxPool combination.
TEST_P(TestQDQSuite, TestConvMaxpool) {
  SKIP_TRT7(!GetParam().final_qdq);
  SKIP_TRT7(GetParam().data_format != "NCHW");

  Scope scope = Scope::NewRootScope();
  auto input = AddInput(scope, 0, GetParam().data_format,
                        /*size_chw=*/{3, 28, 28});
  auto conv1 =
      AddConv2D(scope, input, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/{1, 1},
                GetParam().data_format, GetParam().conv_has_bias,
                GetParam().conv_epilogue, GetParam().qdq_on_output);
  ops::MaxPool maxpool =
      ops::MaxPool(scope.WithOpName("maxpool"),
                   AddQDQV2(scope.WithOpName("mp_qdq_in"), conv1),
                   {1, 1, 1, 1}, {1, 1, 1, 1}, "SAME",
                   ops::MaxPool::Attrs().DataFormat(GetParam().data_format));
  auto output =
      AddOutput(scope.WithOpName("output"), maxpool, 0, GetParam().final_qdq);
  POLICY_TRT7(GetParam(), scope);
  POLICY_TRT8(GetParam(), scope);
}

// Tests QDQ(Conv(QDQ(MaxPool(Conv(QDQ(x)))))).
TEST_P(TestQDQSuite, TestConvMaxpoolConv) {
  SKIP_TRT7(!GetParam().final_qdq);
  SKIP_TRT7(GetParam().data_format != "NCHW");

  Scope scope = Scope::NewRootScope();
  auto input = AddInput(scope, 0, GetParam().data_format,
                        /*size_chw=*/{3, 28, 28});
  auto conv1 =
      AddConv2D(scope, input, 3, 16, /*filter_size=*/{3, 3}, /*stride=*/{1, 1},
                GetParam().data_format, GetParam().conv_has_bias,
                GetParam().conv_epilogue, GetParam().qdq_on_output);
  ops::MaxPool maxpool =
      ops::MaxPool(scope.WithOpName("maxpool"),
                   AddQDQV2(scope.WithOpName("mp_qdq_in"), conv1),
                   {1, 1, 1, 1}, {1, 1, 1, 1}, "SAME",
                   ops::MaxPool::Attrs().DataFormat(GetParam().data_format));
  auto conv2 = AddConv2D(scope, maxpool, 16, 16, {3, 3}, {1, 1},
                         GetParam().data_format, GetParam().conv_has_bias,
                         GetParam().conv_epilogue, GetParam().qdq_on_output);
  auto output =
      AddOutput(scope.WithOpName("out"), conv2, 0, GetParam().final_qdq);
  POLICY_TRT7(GetParam(), scope);
  POLICY_TRT8(GetParam(), scope);
}

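// Instantiates the suite over all combinations produced by
// EnumerateQDQTestOptions().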
INSTANTIATE_TEST_SUITE_P(TestQDQSuiteInst, TestQDQSuite,
                         ::testing::ValuesIn(EnumerateQDQTestOptions()));

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // IS_TRT_VERSION_GE(8, 0, 0, 0)
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT