• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/utils.h"
16 
17 #include <string>
18 #include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/c/common.h"
23 
24 namespace tflite {
25 namespace delegates {
26 namespace {
27 
TEST(UtilsTest, CreateNewTensorWithDifferentTypeTest) {
  // tensors[0] is the original tensor; tensors[1] is the slot the mocked
  // AddTensors below hands out as the 'new tensor'.
  std::vector<TfLiteTensor> tensors(2);
  // Data about original tensor: a 2x3 float32 tensor.
  // The same shape should be reflected in tensors[1] later.
  tensors[0].dims = TfLiteIntArrayCreate(2);
  tensors[0].dims->data[0] = 2;
  tensors[0].dims->data[1] = 3;
  tensors[0].type = kTfLiteFloat32;

  // To simulate a valid TFLite Context. Value-initialize it so every field and
  // callback not explicitly set below is null/zero instead of indeterminate.
  TfLiteContext context = {};
  // Installed so that any error reporting through the context is a harmless
  // no-op rather than a call through a null function pointer.
  context.ReportError = [](struct TfLiteContext*, const char*, ...) {};
  context.AddTensors = [](struct TfLiteContext*, int tensors_to_add,
                          int* first_new_tensor_index) {
    // The util should be adding exactly one tensor to the graph.
    if (tensors_to_add != 1) {
      return kTfLiteError;
    }
    // This ensures that the 'new tensor' is the second tensor in the vector
    // above.
    *first_new_tensor_index = 1;
    return kTfLiteOk;
  };
  context.ResizeTensor = [](struct TfLiteContext*, TfLiteTensor* tensor,
                            TfLiteIntArray* new_size) {
    // Ensure dimensions are the same as the original tensor.
    // ResizeTensor takes ownership of `new_size`, so it must be released on
    // the error path as well, otherwise it leaks.
    if (new_size->size != 2 || new_size->data[0] != 2 ||
        new_size->data[1] != 3) {
      TfLiteIntArrayFree(new_size);
      return kTfLiteError;
    }
    tensor->dims = new_size;
    return kTfLiteOk;
  };
  context.tensors = tensors.data();

  TfLiteTensor* new_tensor = nullptr;
  int new_tensor_index = -1;
  EXPECT_EQ(CreateNewTensorWithDifferentType(
                &context, /**original_tensor_index**/ 0,
                /**new_type**/ kTfLiteUInt8, &new_tensor, &new_tensor_index),
            kTfLiteOk);
  EXPECT_EQ(new_tensor_index, 1);
  // Fatal check: `new_tensor` is dereferenced below.
  ASSERT_NE(new_tensor, nullptr);
  EXPECT_NE(new_tensor->dims, nullptr);
  EXPECT_EQ(new_tensor->type, kTfLiteUInt8);
  EXPECT_EQ(new_tensor->allocation_type, kTfLiteArenaRw);

  // Cleanup. tensors[1].dims is the array adopted by the mocked ResizeTensor
  // (or null if it was never set).
  TfLiteIntArrayFree(tensors[0].dims);
  TfLiteIntArrayFree(tensors[1].dims);
}
75 
76 // A mock TfLiteContext to be used for GraphPartitionHelperTest.
77 class MockTfLiteContext : public TfLiteContext {
78  public:
MockTfLiteContext()79   MockTfLiteContext() : TfLiteContext({0}) {
80     // Simply create a 10-node execution plan.
81     exec_plan_ = TfLiteIntArrayCreate(10);
82     for (int i = 0; i < 10; ++i) exec_plan_->data[i] = i;
83 
84     // Create {1}, {0,3,7,8}, {2,4,9}, {5,6} 4 partitions.
85     TfLiteDelegateParams params1({nullptr});
86     params1.nodes_to_replace = TfLiteIntArrayCreate(1);
87     params1.nodes_to_replace->data[0] = 1;
88     delegate_params_.emplace_back(params1);
89 
90     TfLiteDelegateParams params2({nullptr});
91     params2.nodes_to_replace = TfLiteIntArrayCreate(4);
92     params2.nodes_to_replace->data[0] = 0;
93     params2.nodes_to_replace->data[1] = 3;
94     params2.nodes_to_replace->data[2] = 7;
95     params2.nodes_to_replace->data[3] = 8;
96     delegate_params_.emplace_back(params2);
97 
98     TfLiteDelegateParams params3({nullptr});
99     params3.nodes_to_replace = TfLiteIntArrayCreate(3);
100     params3.nodes_to_replace->data[0] = 2;
101     params3.nodes_to_replace->data[1] = 4;
102     params3.nodes_to_replace->data[2] = 9;
103     delegate_params_.emplace_back(params3);
104 
105     TfLiteDelegateParams params4({nullptr});
106     params4.nodes_to_replace = TfLiteIntArrayCreate(2);
107     params4.nodes_to_replace->data[0] = 5;
108     params4.nodes_to_replace->data[1] = 6;
109     delegate_params_.emplace_back(params4);
110 
111     // We need to mock the following 3 functions inside TfLiteContext object
112     // that are used by GraphPartitionHelper implementation.
113     this->GetExecutionPlan = MockGetExecutionPlan;
114     this->GetNodeAndRegistration = MockGetNodeAndRegistration;
115     this->PreviewDelegatePartitioning = MockPreviewDelegatePartitioning;
116   }
~MockTfLiteContext()117   ~MockTfLiteContext() {
118     TfLiteIntArrayFree(exec_plan_);
119     for (auto params : delegate_params_) {
120       TfLiteIntArrayFree(params.nodes_to_replace);
121       TfLiteIntArrayFree(params.input_tensors);
122       TfLiteIntArrayFree(params.output_tensors);
123     }
124   }
125 
exec_plan() const126   TfLiteIntArray* exec_plan() const { return exec_plan_; }
node()127   TfLiteNode* node() { return &node_; }
registration()128   TfLiteRegistration* registration() { return &registration_; }
delegate_params()129   TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
num_delegate_params()130   int num_delegate_params() { return delegate_params_.size(); }
131 
132  private:
MockGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)133   static TfLiteStatus MockGetExecutionPlan(TfLiteContext* context,
134                                            TfLiteIntArray** execution_plan) {
135     MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
136     *execution_plan = mock->exec_plan();
137     return kTfLiteOk;
138   }
139 
MockGetNodeAndRegistration(TfLiteContext * context,int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)140   static TfLiteStatus MockGetNodeAndRegistration(
141       TfLiteContext* context, int node_index, TfLiteNode** node,
142       TfLiteRegistration** registration) {
143     MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
144     *node = mock->node();
145     *registration = mock->registration();
146     return kTfLiteOk;
147   }
148 
MockPreviewDelegatePartitioning(TfLiteContext * context,const TfLiteIntArray * nodes_to_replace,TfLiteDelegateParams ** partition_params_array,int * num_partitions)149   static TfLiteStatus MockPreviewDelegatePartitioning(
150       TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
151       TfLiteDelegateParams** partition_params_array, int* num_partitions) {
152     MockTfLiteContext* mock = reinterpret_cast<MockTfLiteContext*>(context);
153     *partition_params_array = mock->delegate_params();
154     *num_partitions = mock->num_delegate_params();
155     return kTfLiteOk;
156   }
157 
158   // The execution plan of this mocked TfLiteContext object.
159   TfLiteIntArray* exec_plan_;
160 
161   // For simplicity, the mocked graph has only type of node and one
162   // registration.
163   TfLiteNode node_;
164   TfLiteRegistration registration_;
165 
166   // The TfLiteDelegateParams object that's manually populated inside the mocked
167   // TfLiteContext::PreviewDelegatePartitioning.
168   std::vector<TfLiteDelegateParams> delegate_params_;
169 };
170 
IsNodeSupported(TfLiteContext * context,TfLiteNode * node,TfLiteRegistration * registration,std::string * unsupported_details)171 bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
172                      TfLiteRegistration* registration,
173                      std::string* unsupported_details) {
174   return true;
175 }
176 
GetNodesToReplaceFromPartitions(const std::vector<TfLiteDelegateParams * > & partitions)177 std::vector<int> GetNodesToReplaceFromPartitions(
178     const std::vector<TfLiteDelegateParams*>& partitions) {
179   std::vector<int> nodes;
180   for (const auto p : partitions) {
181     nodes.insert(nodes.end(), p->nodes_to_replace->data,
182                  p->nodes_to_replace->data + p->nodes_to_replace->size);
183   }
184   return nodes;
185 }
186 
TEST(GraphPartitionHelper, CheckPartitions) {
  // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
  MockTfLiteContext mocked_context;
  GraphPartitionHelper helper(&mocked_context, IsNodeSupported);
  ASSERT_EQ(kTfLiteOk, helper.Partition(nullptr));
  EXPECT_EQ(10, helper.num_total_nodes());
  EXPECT_EQ(4, helper.num_partitions());

  // The single largest partition is {0,3,7,8}.
  auto partitions = helper.GetFirstNLargestPartitions(1, 0);
  ASSERT_EQ(1u, partitions.size());
  auto nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));

  // Get the largest partition but requiring at least 5 nodes, so empty result.
  partitions = helper.GetFirstNLargestPartitions(1, 5);
  EXPECT_TRUE(partitions.empty());

  // Only {0,3,7,8} and {2,4,9} have at least 3 nodes; expect them largest
  // first. ASSERT the size since both elements are indexed below.
  partitions = helper.GetFirstNLargestPartitions(10, 3);
  ASSERT_EQ(2u, partitions.size());
  EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
  EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
  nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
}
211 
ErrorGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)212 TfLiteStatus ErrorGetExecutionPlan(TfLiteContext* context,
213                                    TfLiteIntArray** execution_plan) {
214   return kTfLiteError;
215 }
216 
EmptyReportError(TfLiteContext * context,const char * format,...)217 void EmptyReportError(TfLiteContext* context, const char* format, ...) {}
218 
TEST(GraphPartitionHelper, CheckPrepareErrors) {
  // A zeroed context whose GetExecutionPlan always fails; error messages are
  // silently dropped. Partition() is expected to propagate the failure.
  TfLiteContext failing_context = {};
  failing_context.ReportError = EmptyReportError;
  failing_context.GetExecutionPlan = ErrorGetExecutionPlan;

  GraphPartitionHelper helper(&failing_context, IsNodeSupported);
  const TfLiteStatus status = helper.Partition(nullptr);
  EXPECT_EQ(kTfLiteError, status);
}
226 
TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) {
  // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
  // So, we simply create a list of supported nodes as {0,1,2,...,8,9}
  MockTfLiteContext mocked_context;
  std::vector<int> supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  GraphPartitionHelper helper(&mocked_context, supported_nodes);
  ASSERT_EQ(kTfLiteOk, helper.Partition(nullptr));
  EXPECT_EQ(10, helper.num_total_nodes());
  EXPECT_EQ(4, helper.num_partitions());

  // The single largest partition is {0,3,7,8}.
  auto partitions = helper.GetFirstNLargestPartitions(1, 0);
  ASSERT_EQ(1u, partitions.size());
  auto nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));

  // Get the largest partition but requiring at least 5 nodes, so empty result.
  partitions = helper.GetFirstNLargestPartitions(1, 5);
  EXPECT_TRUE(partitions.empty());

  // Only {0,3,7,8} and {2,4,9} have at least 3 nodes; expect them largest
  // first. ASSERT the size since both elements are indexed below.
  partitions = helper.GetFirstNLargestPartitions(10, 3);
  ASSERT_EQ(2u, partitions.size());
  EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
  EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
  nodes = GetNodesToReplaceFromPartitions(partitions);
  EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
}
253 
254 }  // namespace
255 }  // namespace delegates
256 }  // namespace tflite
257