1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/utils.h"
16
17 #include <string>
18 #include <vector>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/c/common.h"
23
24 namespace tflite {
25 namespace delegates {
26 namespace {
27
// Checks that CreateNewTensorWithDifferentType adds exactly one tensor to the
// (mocked) context, gives it the requested type and kTfLiteArenaRw
// allocation, and resizes it to the shape of the original tensor.
TEST(UtilsTest, CreateNewTensorWithDifferentTypeTest) {
  std::vector<TfLiteTensor> tensors(2);
  // Data about the original tensor.
  // The same shape should be reflected in tensors[1] later.
  tensors[0].dims = TfLiteIntArrayCreate(2);
  tensors[0].dims->data[0] = 2;
  tensors[0].dims->data[1] = 3;
  tensors[0].type = kTfLiteFloat32;

  // To simulate a valid TFLite context. Value-initialize it so that every
  // field and callback not explicitly set below is null/zero rather than
  // indeterminate.
  TfLiteContext context = {};
  // No-op error sink, in case the util reports an error on a failure path.
  context.ReportError = [](TfLiteContext*, const char*, ...) {};
  context.AddTensors = [](struct TfLiteContext*, int tensors_to_add,
                          int* first_new_tensor_index) {
    // The util should be adding exactly one tensor to the graph.
    if (tensors_to_add != 1) {
      return kTfLiteError;
    }
    // This ensures that the 'new tensor' is the second tensor in the vector
    // above.
    *first_new_tensor_index = 1;
    return kTfLiteOk;
  };
  context.ResizeTensor = [](struct TfLiteContext*, TfLiteTensor* tensor,
                            TfLiteIntArray* new_size) {
    // Ensure dimensions are the same as the original tensor.
    if (new_size->size != 2 || new_size->data[0] != 2 ||
        new_size->data[1] != 3) {
      return kTfLiteError;
    }
    // The new tensor takes ownership of `new_size`; it is freed in cleanup.
    tensor->dims = new_size;
    return kTfLiteOk;
  };
  context.tensors = tensors.data();

  TfLiteTensor* new_tensor = nullptr;
  int new_tensor_index = -1;
  EXPECT_EQ(CreateNewTensorWithDifferentType(
                &context, /**original_tensor_index**/ 0,
                /**new_type**/ kTfLiteUInt8, &new_tensor, &new_tensor_index),
            kTfLiteOk);
  EXPECT_EQ(new_tensor_index, 1);
  // The checks below dereference new_tensor and its dims, so a null value
  // must stop the test instead of crashing it.
  ASSERT_NE(new_tensor, nullptr);
  ASSERT_NE(new_tensor->dims, nullptr);
  EXPECT_EQ(new_tensor->dims->size, 2);
  EXPECT_EQ(new_tensor->dims->data[0], 2);
  EXPECT_EQ(new_tensor->dims->data[1], 3);
  EXPECT_EQ(new_tensor->type, kTfLiteUInt8);
  EXPECT_EQ(new_tensor->allocation_type, kTfLiteArenaRw);

  // Cleanup.
  TfLiteIntArrayFree(tensors[0].dims);
  TfLiteIntArrayFree(tensors[1].dims);
}
75
76 // A mock TfLiteContext to be used for GraphPartitionHelperTest.
77 class MockTfLiteContext : public TfLiteContext {
78 public:
MockTfLiteContext()79 MockTfLiteContext() : TfLiteContext({0}) {
80 // Simply create a 10-node execution plan.
81 exec_plan_ = TfLiteIntArrayCreate(10);
82 for (int i = 0; i < 10; ++i) exec_plan_->data[i] = i;
83
84 // Create {1}, {0,3,7,8}, {2,4,9}, {5,6} 4 partitions.
85 TfLiteDelegateParams params1({nullptr});
86 params1.nodes_to_replace = TfLiteIntArrayCreate(1);
87 params1.nodes_to_replace->data[0] = 1;
88 delegate_params_.emplace_back(params1);
89
90 TfLiteDelegateParams params2({nullptr});
91 params2.nodes_to_replace = TfLiteIntArrayCreate(4);
92 params2.nodes_to_replace->data[0] = 0;
93 params2.nodes_to_replace->data[1] = 3;
94 params2.nodes_to_replace->data[2] = 7;
95 params2.nodes_to_replace->data[3] = 8;
96 delegate_params_.emplace_back(params2);
97
98 TfLiteDelegateParams params3({nullptr});
99 params3.nodes_to_replace = TfLiteIntArrayCreate(3);
100 params3.nodes_to_replace->data[0] = 2;
101 params3.nodes_to_replace->data[1] = 4;
102 params3.nodes_to_replace->data[2] = 9;
103 delegate_params_.emplace_back(params3);
104
105 TfLiteDelegateParams params4({nullptr});
106 params4.nodes_to_replace = TfLiteIntArrayCreate(2);
107 params4.nodes_to_replace->data[0] = 5;
108 params4.nodes_to_replace->data[1] = 6;
109 delegate_params_.emplace_back(params4);
110
111 // We need to mock the following 3 functions inside TfLiteContext object
112 // that are used by GraphPartitionHelper implementation.
113 this->GetExecutionPlan = MockGetExecutionPlan;
114 this->GetNodeAndRegistration = MockGetNodeAndRegistration;
115 this->PreviewDelegatePartitioning = MockPreviewDelegatePartitioning;
116 }
~MockTfLiteContext()117 ~MockTfLiteContext() {
118 TfLiteIntArrayFree(exec_plan_);
119 for (auto params : delegate_params_) {
120 TfLiteIntArrayFree(params.nodes_to_replace);
121 TfLiteIntArrayFree(params.input_tensors);
122 TfLiteIntArrayFree(params.output_tensors);
123 }
124 }
125
exec_plan() const126 TfLiteIntArray* exec_plan() const { return exec_plan_; }
node()127 TfLiteNode* node() { return &node_; }
registration()128 TfLiteRegistration* registration() { return ®istration_; }
delegate_params()129 TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
num_delegate_params()130 int num_delegate_params() { return delegate_params_.size(); }
131
132 private:
MockGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)133 static TfLiteStatus MockGetExecutionPlan(TfLiteContext* context,
134 TfLiteIntArray** execution_plan) {
135 MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
136 *execution_plan = mock->exec_plan();
137 return kTfLiteOk;
138 }
139
MockGetNodeAndRegistration(TfLiteContext * context,int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)140 static TfLiteStatus MockGetNodeAndRegistration(
141 TfLiteContext* context, int node_index, TfLiteNode** node,
142 TfLiteRegistration** registration) {
143 MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
144 *node = mock->node();
145 *registration = mock->registration();
146 return kTfLiteOk;
147 }
148
MockPreviewDelegatePartitioning(TfLiteContext * context,const TfLiteIntArray * nodes_to_replace,TfLiteDelegateParams ** partition_params_array,int * num_partitions)149 static TfLiteStatus MockPreviewDelegatePartitioning(
150 TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
151 TfLiteDelegateParams** partition_params_array, int* num_partitions) {
152 MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
153 *partition_params_array = mock->delegate_params();
154 *num_partitions = mock->num_delegate_params();
155 return kTfLiteOk;
156 }
157
158 // The execution plan of this mocked TfLiteContext object.
159 TfLiteIntArray* exec_plan_;
160
161 // For simplicity, the mocked graph has only type of node and one
162 // registration.
163 TfLiteNode node_;
164 TfLiteRegistration registration_;
165
166 // The TfLiteDelegateParams object that's manually populated inside the mocked
167 // TfLiteContext::PreviewDelegatePartitioning.
168 std::vector<TfLiteDelegateParams> delegate_params_;
169 };
170
IsNodeSupported(TfLiteContext * context,TfLiteNode * node,TfLiteRegistration * registration,std::string * unsupported_details)171 bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
172 TfLiteRegistration* registration,
173 std::string* unsupported_details) {
174 return true;
175 }
176
GetNodesToReplaceFromPartitions(const std::vector<TfLiteDelegateParams * > & partitions)177 std::vector<int> GetNodesToReplaceFromPartitions(
178 const std::vector<TfLiteDelegateParams*>& partitions) {
179 std::vector<int> nodes;
180 for (const auto p : partitions) {
181 nodes.insert(nodes.end(), p->nodes_to_replace->data,
182 p->nodes_to_replace->data + p->nodes_to_replace->size);
183 }
184 return nodes;
185 }
186
// Partitions the mocked 10-node graph using the always-true IsNodeSupported
// predicate, then checks that GetFirstNLargestPartitions returns partitions
// largest-first and drops those below the minimum node count.
TEST(GraphPartitionHelper, CheckPartitions) {
  // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
  MockTfLiteContext mocked_context;
  GraphPartitionHelper helper(&mocked_context, IsNodeSupported);
  // Every check below inspects the partitioning result, so stop here if
  // partitioning itself failed.
  ASSERT_EQ(kTfLiteOk, helper.Partition(nullptr));
  EXPECT_EQ(10, helper.num_total_nodes());
  EXPECT_EQ(4, helper.num_partitions());

  // The single largest partition is {0,3,7,8}.
  auto partitions = helper.GetFirstNLargestPartitions(1, 0);
  EXPECT_THAT(partitions, testing::SizeIs(1));
  auto nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));

  // Get the largest partition but requiring at least 5 nodes, so empty result.
  partitions = helper.GetFirstNLargestPartitions(1, 5);
  EXPECT_TRUE(partitions.empty());

  // Requiring at least 3 nodes keeps {0,3,7,8} and {2,4,9}, largest first.
  // ASSERT the count: the lines below index into the result.
  partitions = helper.GetFirstNLargestPartitions(10, 3);
  ASSERT_THAT(partitions, testing::SizeIs(2));
  EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
  EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
  nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
}
211
ErrorGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)212 TfLiteStatus ErrorGetExecutionPlan(TfLiteContext* context,
213 TfLiteIntArray** execution_plan) {
214 return kTfLiteError;
215 }
216
EmptyReportError(TfLiteContext * context,const char * format,...)217 void EmptyReportError(TfLiteContext* context, const char* format, ...) {}
218
// Partition() must surface the failure of the context's GetExecutionPlan.
TEST(GraphPartitionHelper, CheckPrepareErrors) {
  // Zero-initialized context whose only callbacks are a failing
  // GetExecutionPlan and a silent ReportError.
  TfLiteContext error_context = {};
  error_context.ReportError = EmptyReportError;
  error_context.GetExecutionPlan = ErrorGetExecutionPlan;

  GraphPartitionHelper helper(&error_context, IsNodeSupported);
  const TfLiteStatus status = helper.Partition(nullptr);
  EXPECT_EQ(kTfLiteError, status);
}
226
// Same scenario as CheckPartitions, but the helper is built from an explicit
// list of supported node indices instead of a predicate.
TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) {
  // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
  // So, we simply create a list of supported nodes as {0,1,2,...,8,9}
  MockTfLiteContext mocked_context;
  std::vector<int> supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  GraphPartitionHelper helper(&mocked_context, supported_nodes);
  // Every check below inspects the partitioning result, so stop here if
  // partitioning itself failed.
  ASSERT_EQ(kTfLiteOk, helper.Partition(nullptr));
  EXPECT_EQ(10, helper.num_total_nodes());
  EXPECT_EQ(4, helper.num_partitions());

  // The single largest partition is {0,3,7,8}.
  auto partitions = helper.GetFirstNLargestPartitions(1, 0);
  EXPECT_THAT(partitions, testing::SizeIs(1));
  auto nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));

  // Get the largest partition but requiring at least 5 nodes, so empty result.
  partitions = helper.GetFirstNLargestPartitions(1, 5);
  EXPECT_TRUE(partitions.empty());

  // Requiring at least 3 nodes keeps {0,3,7,8} and {2,4,9}, largest first.
  // ASSERT the count: the lines below index into the result.
  partitions = helper.GetFirstNLargestPartitions(10, 3);
  ASSERT_THAT(partitions, testing::SizeIs(2));
  EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
  EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
  nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
}
253
254 } // namespace
255 } // namespace delegates
256 } // namespace tflite
257