• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/lite/delegates/flex/training/training_delegate.h"

#include <cstdint>
#include <future>  // NOLINT(build/c++11)
#include <memory>
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h"
#include "tensorflow/lite/shared_library.h"
26 
27 namespace tflite {
28 namespace flex {
29 namespace testing {
30 
31 using ::testing::ElementsAre;
32 
33 class TrainingDelegateTest : public testing::FlexModelTest {
34  protected:
SetUp()35   void SetUp() override {
36     delegate_ = absl::make_unique<TrainingFlexDelegate>();
37     interpreter_ = absl::make_unique<Interpreter>(&error_reporter_);
38     interpreter_->SetCancellationFunction(delegate_.get(),
39                                           TrainingFlexDelegate::ShouldCancel);
40   }
41 
TearDown()42   void TearDown() override {
43     // The delegate needs to be destructed after the interpreter because the
44     // interpreter references data contained in the delegate.
45     interpreter_.reset();
46     delegate_.reset();
47   }
48 
49  public:
TrainingDelegateTest()50   TrainingDelegateTest() : delegate_(nullptr) {}
51 
ConfigureDelegate()52   void ConfigureDelegate() {
53     ASSERT_EQ(
54         interpreter_->ModifyGraphWithDelegate(delegate_->GetTfLiteDelegate()),
55         kTfLiteOk);
56   }
57 
Cancel()58   void Cancel() { delegate_->Cancel(); }
59 
60  private:
61   std::unique_ptr<TrainingFlexDelegate> delegate_;
62 };
63 
// Runs a single TF Add op through the training delegate and checks the
// elementwise sum.
TEST_F(TrainingDelegateTest, TestFullGraph) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});

  // Tensors 0 and 1 feed the Add op; tensor 2 receives its result.
  AddTfOp(testing::kAdd, {0, 1}, {2});

  // ConfigureDelegate() asserts internally; a fatal failure there only exits
  // the helper, so propagate it instead of running an unconfigured graph.
  ASSERT_NO_FATAL_FAILURE(ConfigureDelegate());

  SetShape(0, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);
}
82 
// Checks that an Add graph runs normally, and that after Cancel() the next
// Invoke() is rejected.
TEST_F(TrainingDelegateTest, TestCancellation1) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});

  // Tensors 0 and 1 feed the Add op; tensor 2 receives its result.
  AddTfOp(testing::kAdd, {0, 1}, {2});

  // ConfigureDelegate() asserts internally; a fatal failure there only exits
  // the helper, so propagate it instead of running an unconfigured graph.
  ASSERT_NO_FATAL_FAILURE(ConfigureDelegate());

  SetShape(0, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  // Before cancellation the graph must run and produce the expected sum.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);

  Cancel();
  // Op should be cancelled.
  ASSERT_FALSE(Invoke());
}
105 
// Checks cancellation using a TF op (LoopCond) that consults the
// CancellationManager status.
TEST_F(TrainingDelegateTest, TestCancellation2) {
  // Define the graph.
  AddTensors(2, {0}, {1}, kTfLiteBool, {1});

  // We need an op that checks the CancellationManager status.
  AddTfOp(testing::kLoopCond, {0}, {1});

  // Apply the delegate. ConfigureDelegate() asserts internally; a fatal
  // failure there only exits the helper, so propagate it to this test.
  ASSERT_NO_FATAL_FAILURE(ConfigureDelegate());

  // Define inputs.
  // NOTE(review): only the shape of tensor 0 is set; its value is never
  // written, so the op reads whatever the buffer holds. That seems fine for a
  // cancellation check, but confirm it does not trip memory sanitizers.
  SetShape(0, {1});

  ASSERT_TRUE(Invoke());

  Cancel();
  // Op should be cancelled.
  ASSERT_FALSE(Invoke());
}
125 
// Cancels from a second thread while the first thread is invoking.
//
// The first Invoke() races with Cancel(), so its result is deliberately not
// checked. The second Invoke() is issued only after Cancel() has returned, so
// it must be rejected. That ordering is enforced with a promise/future
// handshake rather than a fixed sleep, which kept the test slow and left it
// dependent on scheduler timing.
TEST_F(TrainingDelegateTest, TestCancellationTwoThreads) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});

  AddTfOp(testing::kAdd, {0, 1}, {2});

  // ConfigureDelegate() asserts internally; a fatal failure there only exits
  // the helper, so propagate it instead of running an unconfigured graph.
  ASSERT_NO_FATAL_FAILURE(ConfigureDelegate());

  SetShape(0, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  // Fulfilled by cancel_thread once Cancel() has returned.
  std::promise<void> cancel_done;
  std::future<void> cancel_done_future = cancel_done.get_future();

  // Written only by invoke_thread and read only after join(), which orders
  // the accesses.
  bool second_invoke_ok = true;

  std::thread invoke_thread(
      [this, &cancel_done_future, &second_invoke_ok]() {
        // May or may not observe the concurrent Cancel(); either outcome is
        // acceptable, so the result is intentionally discarded.
        static_cast<void>(this->Invoke());
        // Wait until cancellation has definitely been requested.
        cancel_done_future.wait();
        second_invoke_ok = this->Invoke();
      });

  std::thread cancel_thread([this, &cancel_done]() {
    this->Cancel();
    cancel_done.set_value();
  });

  invoke_thread.join();
  cancel_thread.join();

  // Asserted on the test thread so a failure is reported and aborts this test
  // rather than only the worker lambda. Op should be cancelled.
  ASSERT_FALSE(second_invoke_ok);
}
151 
152 // TODO(b/179048124): Add test case with ReduceDataset op.
153 // TODO(b/179048124): Add integration test with real models.
154 // TODO(b/179048124): Add proper test with multiple threads.
155 
156 }  // namespace testing
157 }  // namespace flex
158 }  // namespace tflite
159