• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 
17 #include "tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h"
18 
19 #include <memory>
20 
21 #include <gtest/gtest.h>
22 #include "third_party/tensorrt/NvInfer.h"
23 
24 namespace tensorflow {
25 namespace tensorrt {
26 namespace convert {
27 
// A TensorRT 7.1 version ({7, 1, 3, 4}) must report that no algorithm
// selector is needed.
TEST(TestAlgorithmSelector, TensorRT7_1) {
  AlgorithmSelectorImpl selector_trt71({7, 1, 3, 4});
  ASSERT_FALSE(selector_trt71.IsAlgorithmSelectorRequired());
}
33 
// TensorRT 7.2 ({7, 2, 0, 0}) requires the algorithm selector. The selector
// must ban every tactic in the 7.2 Turing ban list and reject shuffle
// algorithms only for FLOAT data in the kCHW32 format.
TEST(TestAlgorithmSelector, TensorRT7_2) {
  AlgorithmSelectorImpl selector_trt72({7, 2, 0, 0});
  ASSERT_TRUE(selector_trt72.IsAlgorithmSelectorRequired());

  // Every entry of the Turing ban list has to be reported as banned.
  for (auto tactic_id : AlgorithmSelectorImpl::GetBannedTRT72TuringTactics()) {
    EXPECT_TRUE(selector_trt72.IsBannedTactic(tactic_id));
  }

  // The single rejected shuffle combination: FLOAT data in kCHW32.
  EXPECT_FALSE(selector_trt72.AllowShuffleAlgorithm(
      0, nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat::kCHW32));

  // Neighbouring combinations stay allowed: other data types in kCHW32,
  // and FLOAT in a different vectorized format.
  EXPECT_TRUE(selector_trt72.AllowShuffleAlgorithm(
      0, nvinfer1::DataType::kHALF, nvinfer1::TensorFormat::kCHW32));
  EXPECT_TRUE(selector_trt72.AllowShuffleAlgorithm(
      0, nvinfer1::DataType::kINT32, nvinfer1::TensorFormat::kCHW32));
  EXPECT_TRUE(selector_trt72.AllowShuffleAlgorithm(
      0, nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat::kCHW16));
}
58 
// TensorRT 8.0 ({8, 0, 1, 6}) still requires the algorithm selector, but with
// different restrictions than 7.2. The 7.2 Turing tactic ban no longer
// applies, and the rejected shuffle combination is INT8 data in kLINEAR.
TEST(TestAlgorithmSelector, TensorRT8_0) {
  // Block-scope aliases keep the shuffle checks readable.
  using TrtDataType = nvinfer1::DataType;
  using TrtFormat = nvinfer1::TensorFormat;

  AlgorithmSelectorImpl selector_trt80({8, 0, 1, 6});
  ASSERT_TRUE(selector_trt80.IsAlgorithmSelectorRequired());

  // Nothing on the 7.2 Turing ban list may be banned for 8.0.
  for (auto tactic_id : AlgorithmSelectorImpl::GetBannedTRT72TuringTactics()) {
    EXPECT_FALSE(selector_trt80.IsBannedTactic(tactic_id));
  }

  // INT8 in kLINEAR is the rejected shuffle combination...
  EXPECT_FALSE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kINT8,
                                                    TrtFormat::kLINEAR));

  // ...other data types in kLINEAR remain allowed...
  EXPECT_TRUE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kHALF,
                                                   TrtFormat::kLINEAR));
  EXPECT_TRUE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kINT32,
                                                   TrtFormat::kLINEAR));
  EXPECT_TRUE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kFLOAT,
                                                   TrtFormat::kLINEAR));

  // ...as does INT8 in the vectorized kCHW16 / kCHW32 formats.
  EXPECT_TRUE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kINT8,
                                                   TrtFormat::kCHW16));
  EXPECT_TRUE(selector_trt80.AllowShuffleAlgorithm(0, TrtDataType::kINT8,
                                                   TrtFormat::kCHW32));
}
86 
// A TensorRT 8.2 version ({8, 2, 0, 0}) must report that no algorithm
// selector is needed.
TEST(TestAlgorithmSelector, TensorRT8_2) {
  // Verify that the algorithm selector for TRT 8.2 is not required.
  AlgorithmSelectorImpl sel82({8, 2, 0, 0});
  ASSERT_FALSE(sel82.IsAlgorithmSelectorRequired());
}
92 
93 }  // namespace convert
94 }  // namespace tensorrt
95 }  // namespace tensorflow
96 
97 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
98