• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include <memory>
19 #include <vector>
20 
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/subgraph_test_util.h"
25 
26 namespace tflite {
27 
28 using subgraph_test_util::ControlFlowOpTest;
29 
30 namespace {
31 
32 class CallOnceTest : public ControlFlowOpTest {
33  protected:
SetUp()34   void SetUp() override {
35     AddSubgraphs(2);
36     builder_->BuildCallOnceAndReadVariableSubgraph(
37         &interpreter_->primary_subgraph());
38     builder_->BuildAssignRandomValueToVariableSubgraph(
39         interpreter_->subgraph(1));
40     builder_->BuildCallOnceAndReadVariablePlusOneSubgraph(
41         interpreter_->subgraph(2));
42 
43     ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
44     ASSERT_EQ(interpreter_->subgraph(2)->AllocateTensors(), kTfLiteOk);
45   }
46 };
47 
// A single Invoke() of the primary subgraph must go through the CALL_ONCE op
// and expose the variable's value as a one-element int32 output.
TEST_F(CallOnceTest, TestSimple) {
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  const int result_index = interpreter_->outputs()[0];
  TfLiteTensor* result = interpreter_->tensor(result_index);

  // Shape [1], int32, exactly one element.
  ASSERT_EQ(result->dims->size, 1);
  ASSERT_EQ(result->dims->data[0], 1);
  ASSERT_EQ(result->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(result), 1);

  // The initialization subgraph is expected to have written a positive value
  // into the variable; observing one here shows it ran.
  EXPECT_GT(result->data.i32[0], 0);
}
61 
// Repeated Invoke() calls on the primary subgraph must not re-run the
// initialization subgraph. Every later invocation has to report exactly the
// value captured on the first call.
TEST_F(CallOnceTest, TestInvokeMultipleTimes) {
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  ASSERT_EQ(output->dims->size, 1);
  ASSERT_EQ(output->dims->data[0], 1);
  ASSERT_EQ(output->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(output), 1);

  // The initialization subgraph is expected to have assigned a positive value
  // to the variable on this first invocation.
  const int value = output->data.i32[0];
  EXPECT_GT(value, 0);

  for (int i = 0; i < 3; ++i) {
    // If the initialization subgraph ran again it would assign a fresh random
    // value, so the output must remain equal to the first observed value.
    ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk) << "iteration " << i;

    // Reassign (do not redeclare) so the outer `output` is not shadowed. The
    // tensor is looked up again after each Invoke().
    output = interpreter_->tensor(interpreter_->outputs()[0]);
    ASSERT_EQ(output->dims->size, 1);
    ASSERT_EQ(output->dims->data[0], 1);
    ASSERT_EQ(output->type, kTfLiteInt32);
    ASSERT_EQ(NumElements(output), 1);
    ASSERT_EQ(output->data.i32[0], value) << "iteration " << i;
  }
}
88 
// The "run once" guarantee must hold across entry points. Once the primary
// subgraph has triggered initialization, invoking subgraph 2 (which also
// contains a CALL_ONCE op) must not assign a new value. Subgraph 2 should
// therefore report the primary subgraph's value plus one.
TEST_F(CallOnceTest, TestInvokeOnceAcrossMultipleEntryPoints) {
  // First entry point: the primary subgraph.
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* primary_out = interpreter_->tensor(interpreter_->outputs()[0]);
  ASSERT_EQ(primary_out->dims->size, 1);
  ASSERT_EQ(primary_out->dims->data[0], 1);
  ASSERT_EQ(primary_out->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(primary_out), 1);

  // A positive value means the initialization subgraph assigned the variable.
  const int initial_value = primary_out->data.i32[0];
  EXPECT_GT(initial_value, 0);

  // Second entry point: subgraph 2, invoked directly.
  auto* const second_entry = interpreter_->subgraph(2);
  ASSERT_EQ(second_entry->Invoke(), kTfLiteOk);

  TfLiteTensor* second_out = second_entry->tensor(second_entry->outputs()[0]);
  ASSERT_EQ(second_out->dims->size, 1);
  ASSERT_EQ(second_out->dims->data[0], 1);
  ASSERT_EQ(second_out->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(second_out), 1);
  ASSERT_EQ(second_out->data.i32[0], initial_value + 1);
}
114 
115 }  // namespace
116 }  // namespace tflite
117