1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/kernels/subgraph_test_util.h"
17
18 #include <stdint.h>
19
20 #include <memory>
21 #include <vector>
22
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/interpreter.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 #include "tensorflow/lite/testing/util.h"
27
28 namespace tflite {
29
30 namespace subgraph_test_util {
31
32 namespace {
33
34 class SubgraphBuilderTest : public ::testing::Test {
35 public:
SubgraphBuilderTest()36 SubgraphBuilderTest()
37 : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
38
~SubgraphBuilderTest()39 ~SubgraphBuilderTest() override {
40 interpreter_.reset();
41 builder_.reset();
42 }
43
44 protected:
TestAccumulateLoopBody(int input1,int input2,int output1,int output2)45 void TestAccumulateLoopBody(int input1, int input2, int output1,
46 int output2) {
47 interpreter_.reset(new Interpreter);
48 builder_->BuildAccumulateLoopBodySubgraph(
49 &interpreter_->primary_subgraph());
50
51 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
52 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
53 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54
55 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1});
56 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2});
57
58 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
59 TfLiteTensor* output_tensor1 =
60 interpreter_->tensor(interpreter_->outputs()[0]);
61 CheckIntTensor(output_tensor1, {1}, {output1});
62 TfLiteTensor* output_tensor2 =
63 interpreter_->tensor(interpreter_->outputs()[1]);
64 CheckIntTensor(output_tensor2, {1}, {output2});
65 }
66
67 std::unique_ptr<Interpreter> interpreter_;
68 std::unique_ptr<SubgraphBuilder> builder_;
69 };
70
// Checks the ADD subgraph: a {2} tensor plus a {1, 2} tensor broadcasts to a
// {1, 2} result, i.e. [5, 7] + [[1, 2]] == [[6, 9]].
TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) {
  builder_->BuildAddSubgraph(&interpreter_->primary_subgraph());

  const int lhs_index = interpreter_->inputs()[0];
  const int rhs_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(lhs_index, {2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(rhs_index, {1, 2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(lhs_index), {5, 7});
  FillIntTensor(interpreter_->tensor(rhs_index), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output, {1, 2}, {6, 9});
}
85
// Checks the MUL subgraph: a {2} tensor times a {1, 2} tensor broadcasts to a
// {1, 2} result, i.e. [5, 7] * [[1, 2]] == [[5, 14]].
TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) {
  builder_->BuildMulSubgraph(&interpreter_->primary_subgraph());

  const int lhs_index = interpreter_->inputs()[0];
  const int rhs_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(lhs_index, {2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(rhs_index, {1, 2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(lhs_index), {5, 7});
  FillIntTensor(interpreter_->tensor(rhs_index), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output, {1, 2}, {5, 14});
}
100
// Checks the PAD subgraph: [5, 7] padded with amounts [[1, 2]] (one element
// before, two after) yields [0, 5, 7, 0, 0].
TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) {
  builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());

  const int data_index = interpreter_->inputs()[0];
  const int paddings_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(data_index, {2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(paddings_index, {1, 2}),
            kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(data_index), {5, 7});
  FillIntTensor(interpreter_->tensor(paddings_index), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
}
115
// Runs the same PAD subgraph as TestBuildPadSubgraph and additionally checks
// that the output tensor is dynamic.
// NOTE(review): presumably the output is dynamic because the paddings arrive
// as a runtime input rather than a constant — confirm against the builder.
TEST_F(SubgraphBuilderTest, TestBuildDynamicPadSubgraph) {
  builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());

  const int data_index = interpreter_->inputs()[0];
  const int paddings_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(data_index, {2}), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(paddings_index, {1, 2}),
            kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(data_index), {5, 7});
  FillIntTensor(interpreter_->tensor(paddings_index), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  EXPECT_TRUE(IsDynamicTensor(output));
  CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
}
131
// Checks the LESS_EQUAL cond subgraph built with the constant 3:
// [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false] (with broadcasting).
TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) {
  builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3);

  const int compared_index = interpreter_->inputs()[0];
  const int other_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(compared_index, {5}), kTfLiteOk);
  // NOTE(review): the second input is resized but never filled; presumably
  // the cond subgraph does not read it — confirm.
  ASSERT_EQ(interpreter_->ResizeInputTensor(other_index, {10, 10}),
            kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(compared_index), {1, 2, 3, 4, 5});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckBoolTensor(output, {5}, {true, true, true, false, false});
}
147
// Runs one iteration of the accumulate loop body for several inputs. The
// expected values satisfy output1 == input1 + 1 and
// output2 == input2 + output1.
TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) {
  // TestAccumulateLoopBody reports problems through ASSERT_*, which only
  // returns from the helper. ASSERT_NO_FATAL_FAILURE propagates such a failure
  // so the test stops at the first broken case.
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(1, 1, 2, 3));
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(2, 3, 3, 6));
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(3, 6, 4, 10));
}
153
// Checks one iteration of the pad loop body configured with padding {1, 2}.
// The counter input goes from 1 to 2, and [0, 5, 7, 0, 0] gains one element
// before and two after, giving [0, 0, 5, 7, 0, 0, 0, 0].
TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) {
  builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2});

  const int counter_index = interpreter_->inputs()[0];
  const int data_index = interpreter_->inputs()[1];
  // ResizeInputTensor returns a status; don't let a failed resize go unseen.
  ASSERT_EQ(interpreter_->ResizeInputTensor(counter_index, {1}), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(data_index, {5}), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(counter_index), {1});
  FillIntTensor(interpreter_->tensor(data_index), {0, 5, 7, 0, 0});

  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output1, {1}, {2});
  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
  CheckIntTensor(output2, {8}, {0, 0, 5, 7, 0, 0, 0, 0});
}
171
172 } // namespace
173 } // namespace subgraph_test_util
174 } // namespace tflite
175