1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17
18 #include <memory>
19 #include <vector>
20
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/subgraph_test_util.h"
25
26 namespace tflite {
27
28 using subgraph_test_util::ControlFlowOpTest;
29
30 namespace {
31
32 class CallOnceTest : public ControlFlowOpTest {
33 protected:
SetUp()34 void SetUp() override {
35 AddSubgraphs(2);
36 builder_->BuildCallOnceAndReadVariableSubgraph(
37 &interpreter_->primary_subgraph());
38 builder_->BuildAssignRandomValueToVariableSubgraph(
39 interpreter_->subgraph(1));
40 builder_->BuildCallOnceAndReadVariablePlusOneSubgraph(
41 interpreter_->subgraph(2));
42
43 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
44 ASSERT_EQ(interpreter_->subgraph(2)->AllocateTensors(), kTfLiteOk);
45 }
46 };
47
// Invokes the primary subgraph once and checks that its output is a single
// int32 element holding a strictly positive value.
TEST_F(CallOnceTest, TestSimple) {
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // Fail cleanly, rather than indexing an empty vector or dereferencing a
  // null pointer, if the graph exposes no usable output tensor.
  ASSERT_FALSE(interpreter_->outputs().empty());
  const TfLiteTensor* output =
      interpreter_->tensor(interpreter_->outputs()[0]);
  ASSERT_NE(output, nullptr);

  // Expected output: rank 1, one element, int32.
  ASSERT_EQ(output->dims->size, 1);
  ASSERT_EQ(output->dims->data[0], 1);
  ASSERT_EQ(output->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(output), 1);

  // The value of the variable must be strictly positive; it is expected to be
  // assigned by the initialization subgraph.
  EXPECT_GT(output->data.i32[0], 0);
}
61
// Invokes the primary subgraph repeatedly and checks that the value read from
// the variable stays the same after the first invocation, i.e. the
// initialization subgraph does not assign a new random value again.
TEST_F(CallOnceTest, TestInvokeMultipleTimes) {
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // Fail cleanly, rather than indexing an empty vector or dereferencing a
  // null pointer, if the graph exposes no usable output tensor.
  ASSERT_FALSE(interpreter_->outputs().empty());
  const TfLiteTensor* first_output =
      interpreter_->tensor(interpreter_->outputs()[0]);
  ASSERT_NE(first_output, nullptr);
  ASSERT_EQ(first_output->dims->size, 1);
  ASSERT_EQ(first_output->dims->data[0], 1);
  ASSERT_EQ(first_output->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(first_output), 1);

  // The value of the variable must be strictly positive; it is expected to be
  // assigned by the initialization subgraph.
  const int value = first_output->data.i32[0];
  EXPECT_GT(value, 0);

  for (int i = 0; i < 3; ++i) {
    // Make sure that no more random value assignment happens in the
    // initialization subgraph on subsequent invocations.
    ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

    // The tensor is looked up again after every invocation. A distinct name
    // is used so that the first invocation's pointer is not shadowed.
    ASSERT_FALSE(interpreter_->outputs().empty());
    const TfLiteTensor* repeated_output =
        interpreter_->tensor(interpreter_->outputs()[0]);
    ASSERT_NE(repeated_output, nullptr);
    ASSERT_EQ(repeated_output->dims->size, 1);
    ASSERT_EQ(repeated_output->dims->data[0], 1);
    ASSERT_EQ(repeated_output->type, kTfLiteInt32);
    ASSERT_EQ(NumElements(repeated_output), 1);
    ASSERT_EQ(repeated_output->data.i32[0], value);
  }
}
88
// Invokes the primary subgraph and then subgraph 2, both of which contain the
// CallOnce op, and checks that the second entry point observes the value
// produced during the first invocation (plus one) rather than a freshly
// assigned random value.
TEST_F(CallOnceTest, TestInvokeOnceAcrossMultipleEntryPoints) {
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // Fail cleanly, rather than indexing an empty vector or dereferencing a
  // null pointer, if the graph exposes no usable output tensor.
  ASSERT_FALSE(interpreter_->outputs().empty());
  const TfLiteTensor* primary_output =
      interpreter_->tensor(interpreter_->outputs()[0]);
  ASSERT_NE(primary_output, nullptr);
  ASSERT_EQ(primary_output->dims->size, 1);
  ASSERT_EQ(primary_output->dims->data[0], 1);
  ASSERT_EQ(primary_output->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(primary_output), 1);

  // The value of the variable must be strictly positive; it is expected to be
  // assigned by the initialization subgraph.
  const int value = primary_output->data.i32[0];
  EXPECT_GT(value, 0);

  // Make sure that no more random value assignment happens in the
  // initialization subgraph while invoking the other subgraph, which also has
  // the CallOnce op.
  auto* const second_entry = interpreter_->subgraph(2);
  ASSERT_EQ(second_entry->Invoke(), kTfLiteOk);

  ASSERT_FALSE(second_entry->outputs().empty());
  const TfLiteTensor* second_output =
      second_entry->tensor(second_entry->outputs()[0]);
  ASSERT_NE(second_output, nullptr);
  ASSERT_EQ(second_output->dims->size, 1);
  ASSERT_EQ(second_output->dims->data[0], 1);
  ASSERT_EQ(second_output->type, kTfLiteInt32);
  ASSERT_EQ(NumElements(second_output), 1);
  ASSERT_EQ(second_output->data.i32[0], value + 1);
}
114
115 } // namespace
116 } // namespace tflite
117