1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17
18 #include "tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h"
19
20 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
21 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
22 #include "third_party/tensorrt/NvInfer.h"
23
24 namespace tensorflow {
25 namespace tensorrt {
26 namespace convert {
27
28 using ::testing::AllOf;
29 using ::testing::AnyOf;
30 using ::testing::Eq;
31 using ::testing::Not;
32
// Exercises the DimsAreArray matcher on nvinfer1::Dims values: an exact match,
// value-initialized dims, a differing element, and a differing element count.
TEST(TrtDimsMatcher, ParameterizedMatchers) {
  const nvinfer1::Dims4 four_dims(1, 2, 3, 4);
  const nvinfer1::Dims no_dims{};
  const std::vector<int> no_extents;

  // Identical extents are expected to match.
  EXPECT_THAT(four_dims, DimsAreArray({1, 2, 3, 4}));

  // Value-initialized dims: rejected by a non-empty list, accepted by an empty
  // one.
  EXPECT_THAT(no_dims, Not(DimsAreArray({1, 2})));
  EXPECT_THAT(no_dims, DimsAreArray(no_extents));

  // Same number of extents, but the last value differs.
  EXPECT_THAT(four_dims, Not(DimsAreArray({1, 2, 3, 5})));

  // Fewer extents than the dims under test.
  EXPECT_THAT(four_dims, Not(DimsAreArray({1, 2, 5})));
}
44
// Exercises gMock's Eq matcher on nvinfer1::Dims values: equal dims,
// value-initialized dims, a differing element, and a differing element count.
TEST(TrtDimsMatcher, EqualityMatcher) {
  const nvinfer1::Dims4 reference(1, 2, 3, 4);
  const nvinfer1::Dims no_dims{};

  // Dims built from the same extents compare equal.
  EXPECT_THAT(reference, Eq(nvinfer1::Dims4(1, 2, 3, 4)));

  // Two empty Dims compare equal.
  EXPECT_THAT(no_dims, Eq(nvinfer1::Dims()));

  // An empty Dims must differ from a DimsHW, since their sizes differ.
  EXPECT_THAT(no_dims, Not(Eq(nvinfer1::DimsHW())));

  // Same number of extents, but the last value differs.
  EXPECT_THAT(reference, Not(Eq(nvinfer1::Dims4(1, 2, 3, 3))));

  // Differing number of extents.
  EXPECT_THAT(reference, Not(Eq(nvinfer1::Dims2(1, 2))));
}
57
// Exercises the LayerNamesAreArray and LayerNamesNonEmpty matchers on an
// nvinfer1::INetworkDefinition in three states: no layers, one explicitly
// named fully-connected layer, and that layer plus a second one whose name is
// left at whatever TensorRT assigns by default.
TEST(INetworkDefinitionMatchers, CorrectlyMatch) {
  Logger& logger = *Logger::GetLogger();
  TrtUniquePtrType<nvinfer1::IBuilder> builder(
      nvinfer1::createInferBuilder(logger));
  // Abort the test rather than dereference a null builder below.
  ASSERT_NE(builder.get(), nullptr);
  TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
      builder->createNetworkV2(0L));
  // Likewise, every later step dereferences the network.
  ASSERT_NE(network.get(), nullptr);

  // Empty network checks.
  EXPECT_THAT(network.get(), AllOf(Not(LayerNamesAreArray({"some layer"})),
                                   LayerNamesNonEmpty()));

  // Add the input and FC layers.
  nvinfer1::Weights weights;
  weights.type = nvinfer1::DataType::kFLOAT;
  // Zero-initialize so the weights never point at indeterminate data.
  std::array<float, 1> vals{};
  weights.values = vals.data();
  weights.count = 1;
  auto input = network->addInput("input-tensor", nvinfer1::DataType::kFLOAT,
                                 nvinfer1::Dims3{1, 1, 1});
  ASSERT_NE(input, nullptr);

  const char* fc_layer_name = "my-fc-layer";
  auto layer = network->addFullyConnected(*input, 1, weights, weights);
  ASSERT_NE(layer, nullptr);
  layer->setName(fc_layer_name);

  // Check layer names.
  EXPECT_THAT(network.get(),
              AllOf(LayerNamesNonEmpty(), LayerNamesAreArray({fc_layer_name})));

  // Add layer with default name and check layer name.
  layer = network->addFullyConnected(*input, 1, weights, weights);
  // The expectation below is only meaningful if the second layer really exists.
  ASSERT_NE(layer, nullptr);
  EXPECT_THAT(network.get(), AllOf(LayerNamesNonEmpty(),
                                   Not(LayerNamesAreArray({fc_layer_name}))));
}
93
94 } // namespace convert
95
96 } // namespace tensorrt
97 } // namespace tensorflow
98
99 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
100