• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/subgraph_test_util.h"
17 
18 #include <stdint.h>
19 
20 #include <memory>
21 #include <vector>
22 
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/interpreter.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 #include "tensorflow/lite/testing/util.h"
27 
28 namespace tflite {
29 
30 namespace subgraph_test_util {
31 
32 namespace {
33 
34 class SubgraphBuilderTest : public ::testing::Test {
35  public:
SubgraphBuilderTest()36   SubgraphBuilderTest()
37       : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
38 
~SubgraphBuilderTest()39   ~SubgraphBuilderTest() override {
40     interpreter_.reset();
41     builder_.reset();
42   }
43 
44  protected:
TestAccumulateLoopBody(int input1,int input2,int output1,int output2)45   void TestAccumulateLoopBody(int input1, int input2, int output1,
46                               int output2) {
47     interpreter_.reset(new Interpreter);
48     builder_->BuildAccumulateLoopBodySubgraph(
49         &interpreter_->primary_subgraph());
50 
51     interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
52     interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
53     ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54 
55     FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1});
56     FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2});
57 
58     ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
59     TfLiteTensor* output_tensor1 =
60         interpreter_->tensor(interpreter_->outputs()[0]);
61     CheckIntTensor(output_tensor1, {1}, {output1});
62     TfLiteTensor* output_tensor2 =
63         interpreter_->tensor(interpreter_->outputs()[1]);
64     CheckIntTensor(output_tensor2, {1}, {output2});
65   }
66 
67   std::unique_ptr<Interpreter> interpreter_;
68   std::unique_ptr<SubgraphBuilder> builder_;
69 };
70 
// Elementwise add of two int inputs: a {2} tensor plus a {1, 2} tensor
// broadcasts to a {1, 2} result.
TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) {
  builder_->BuildAddSubgraph(&interpreter_->primary_subgraph());

  const int lhs = interpreter_->inputs()[0];
  const int rhs = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(lhs, {2});
  interpreter_->ResizeInputTensor(rhs, {1, 2});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(lhs), {5, 7});
  FillIntTensor(interpreter_->tensor(rhs), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // [5, 7] + [[1, 2]] == [[6, 9]]
  CheckIntTensor(interpreter_->tensor(interpreter_->outputs()[0]), {1, 2},
                 {6, 9});
}
85 
// Elementwise multiply of two int inputs: a {2} tensor times a {1, 2} tensor
// broadcasts to a {1, 2} result.
TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) {
  builder_->BuildMulSubgraph(&interpreter_->primary_subgraph());

  const int lhs = interpreter_->inputs()[0];
  const int rhs = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(lhs, {2});
  interpreter_->ResizeInputTensor(rhs, {1, 2});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(lhs), {5, 7});
  FillIntTensor(interpreter_->tensor(rhs), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // [5, 7] * [[1, 2]] == [[5, 14]]
  CheckIntTensor(interpreter_->tensor(interpreter_->outputs()[0]), {1, 2},
                 {5, 14});
}
100 
// Pad subgraph: input 0 is the data, input 1 the {1, 2} paddings
// (elements before, elements after). Padding [5, 7] by (1, 2) yields
// [0, 5, 7, 0, 0].
TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) {
  builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());

  const int data = interpreter_->inputs()[0];
  const int paddings = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(data, {2});
  interpreter_->ResizeInputTensor(paddings, {1, 2});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(data), {5, 7});
  FillIntTensor(interpreter_->tensor(paddings), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  CheckIntTensor(interpreter_->tensor(interpreter_->outputs()[0]), {5},
                 {0, 5, 7, 0, 0});
}
115 
// Same Pad subgraph and data as TestBuildPadSubgraph, but additionally checks
// that the output tensor is dynamic. The paddings arrive through a runtime
// input (filled only after AllocateTensors), so the output shape is expected
// to be resolved during Invoke rather than at allocation time.
TEST_F(SubgraphBuilderTest, TestBuildDynamicPadSubgraph) {
  builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());

  const int data = interpreter_->inputs()[0];
  const int paddings = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(data, {2});
  interpreter_->ResizeInputTensor(paddings, {1, 2});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(data), {5, 7});
  FillIntTensor(interpreter_->tensor(paddings), {1, 2});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* padded = interpreter_->tensor(interpreter_->outputs()[0]);
  EXPECT_TRUE(IsDynamicTensor(padded));
  CheckIntTensor(padded, {5}, {0, 5, 7, 0, 0});
}
131 
// Condition subgraph comparing input 0 against the constant passed to the
// builder (3 here); the scalar constant is broadcast over the {5} input.
TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) {
  builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3);

  const int compared = interpreter_->inputs()[0];
  const int unused = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(compared, {5});
  // NOTE(review): input 1 is resized but never filled; the expected output
  // below depends only on input 0 — confirm the cond subgraph ignores it.
  interpreter_->ResizeInputTensor(unused, {10, 10});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(compared), {1, 2, 3, 4, 5});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false]
  CheckBoolTensor(interpreter_->tensor(interpreter_->outputs()[0]), {5},
                  {true, true, true, false, false});
}
147 
// Runs the accumulate loop body for three consecutive steps. The expected
// values satisfy output1 == input1 + 1 and output2 == input2 + output1, and
// each step's outputs are the next step's inputs.
TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) {
  // TestAccumulateLoopBody uses ASSERT_* internally. A fatal failure there only
  // returns from the helper, so without ASSERT_NO_FATAL_FAILURE the test would
  // silently continue with the next step. This wrapper also names the
  // failing call in the report.
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(1, 1, 2, 3));
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(2, 3, 3, 6));
  ASSERT_NO_FATAL_FAILURE(TestAccumulateLoopBody(3, 6, 4, 10));
}
153 
// Pad loop body built with paddings {1, 2}: output 0 is input 0 plus one, and
// output 1 is input 1 with one zero prepended and two zeros appended
// (5 elements -> 8 elements).
TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) {
  builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2});

  const int counter = interpreter_->inputs()[0];
  const int values = interpreter_->inputs()[1];

  interpreter_->ResizeInputTensor(counter, {1});
  interpreter_->ResizeInputTensor(values, {5});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(counter), {1});
  FillIntTensor(interpreter_->tensor(values), {0, 5, 7, 0, 0});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  CheckIntTensor(interpreter_->tensor(interpreter_->outputs()[0]), {1}, {2});
  CheckIntTensor(interpreter_->tensor(interpreter_->outputs()[1]), {8},
                 {0, 0, 5, 7, 0, 0, 0, 0});
}
171 
172 }  // namespace
173 }  // namespace subgraph_test_util
174 }  // namespace tflite
175