• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>  // For std::unique_ptr.
17 #include <thread>  // NOLINT(build/c++11)
18 #include <utility>
19 #include <vector>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/delegates/xnnpack/conv_2d_tester.h"
24 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/kernels/register.h"
27 #include "tensorflow/lite/model_builder.h"
28 
29 namespace tflite {
30 namespace xnnpack {
31 
// A weights cache created with an explicit initial size, shared with one
// delegate and hard-finalized, must still allow the delegated graph to run.
TEST(XNNPACK_WEIGHTS_CACHE, WithSize) {
  std::vector<char> model_buffer = Conv2DTester().CreateTfLiteModel();
  const Model* model = GetModel(model_buffer.data());
  ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter));
  ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors());

  // 4 MiB initial reservation for the cache (same value as 4194304).
  constexpr size_t kCacheSize = 4 * 1024 * 1024;
  std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
                  decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
      cache(TfLiteXNNPackDelegateWeightsCacheCreateWithSize(kCacheSize),
            TfLiteXNNPackDelegateWeightsCacheDelete);

  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.weights_cache = cache.get();

  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&options),
                       TfLiteXNNPackDelegateDelete);

  ASSERT_EQ(kTfLiteOk,
            interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()));

  // Hard finalization: no further delegates may be built on this cache,
  // but the already-delegated interpreter must invoke successfully.
  ASSERT_TRUE(TfLiteXNNPackDelegateWeightsCacheFinalizeHard(cache.get()));

  ASSERT_EQ(kTfLiteOk, interpreter->Invoke());
}
62 
// Invoking a graph whose delegate uses a weights cache that was never
// finalized must fail.
TEST(XNNPACK_WEIGHTS_CACHE, InvokeBeforeFinalization) {
  std::vector<char> model_buffer = Conv2DTester().CreateTfLiteModel();
  const Model* model = GetModel(model_buffer.data());
  ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter));
  ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors());

  std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
                  decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
      cache(TfLiteXNNPackDelegateWeightsCacheCreate(),
            TfLiteXNNPackDelegateWeightsCacheDelete);

  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.weights_cache = cache.get();

  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&options),
                       TfLiteXNNPackDelegateDelete);

  ASSERT_EQ(kTfLiteOk,
            interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()));

  // The cache was never finalized (neither hard nor soft), so Invoke()
  // is expected to report an error rather than kTfLiteOk.
  ASSERT_NE(kTfLiteOk, interpreter->Invoke());
}
90 
// After hard finalization the cache is frozen: the first interpreter keeps
// working, but building a second delegate against the same cache must fail.
TEST(XNNPACK_WEIGHTS_CACHE, HardFinalization) {
  std::vector<char> model_buffer = Conv2DTester().CreateTfLiteModel();
  const Model* model = GetModel(model_buffer.data());
  ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  std::unique_ptr<Interpreter> first_interpreter;
  ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&first_interpreter));
  ASSERT_EQ(kTfLiteOk, first_interpreter->AllocateTensors());

  std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
                  decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
      cache(TfLiteXNNPackDelegateWeightsCacheCreate(),
            TfLiteXNNPackDelegateWeightsCacheDelete);

  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.weights_cache = cache.get();

  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      first_delegate(TfLiteXNNPackDelegateCreate(&options),
                     TfLiteXNNPackDelegateDelete);
  ASSERT_EQ(kTfLiteOk,
            first_interpreter->ModifyGraphWithDelegate(first_delegate.get()));
  ASSERT_TRUE(TfLiteXNNPackDelegateWeightsCacheFinalizeHard(cache.get()));

  ASSERT_EQ(kTfLiteOk, first_interpreter->Invoke());

  // We cannot create new instances using the same weights cache after hard
  // finalization: delegating the second interpreter must be rejected.
  std::unique_ptr<Interpreter> second_interpreter;
  ASSERT_EQ(kTfLiteOk,
            InterpreterBuilder(model, resolver)(&second_interpreter));
  ASSERT_EQ(kTfLiteOk, second_interpreter->AllocateTensors());
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      second_delegate(TfLiteXNNPackDelegateCreate(&options),
                      TfLiteXNNPackDelegateDelete);
  ASSERT_NE(kTfLiteOk,
            second_interpreter->ModifyGraphWithDelegate(second_delegate.get()));
}
128 
// After soft finalization the cache stays usable: a second interpreter can
// still be delegated against the same cache and invoked.
TEST(XNNPACK_WEIGHTS_CACHE, SoftFinalization) {
  std::vector<char> model_buffer = Conv2DTester().CreateTfLiteModel();
  const Model* model = GetModel(model_buffer.data());
  ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
                  decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
      cache(TfLiteXNNPackDelegateWeightsCacheCreate(),
            TfLiteXNNPackDelegateWeightsCacheDelete);

  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.weights_cache = cache.get();

  std::unique_ptr<Interpreter> first_interpreter;
  ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&first_interpreter));
  ASSERT_EQ(kTfLiteOk, first_interpreter->AllocateTensors());
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      first_delegate(TfLiteXNNPackDelegateCreate(&options),
                     TfLiteXNNPackDelegateDelete);
  ASSERT_EQ(kTfLiteOk,
            first_interpreter->ModifyGraphWithDelegate(first_delegate.get()));

  ASSERT_TRUE(TfLiteXNNPackDelegateWeightsCacheFinalizeSoft(cache.get()));

  ASSERT_EQ(kTfLiteOk, first_interpreter->Invoke());

  // Build a second interpreter; it should work after soft finalization.
  std::unique_ptr<Interpreter> second_interpreter;
  ASSERT_EQ(kTfLiteOk,
            InterpreterBuilder(model, resolver)(&second_interpreter));
  ASSERT_EQ(kTfLiteOk, second_interpreter->AllocateTensors());
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      second_delegate(TfLiteXNNPackDelegateCreate(&options),
                      TfLiteXNNPackDelegateDelete);
  ASSERT_EQ(kTfLiteOk,
            second_interpreter->ModifyGraphWithDelegate(second_delegate.get()));
  ASSERT_EQ(kTfLiteOk, second_interpreter->Invoke());
}
166 
// Stateless fixture for value-parameterized tests; the size_t parameter is
// the number of worker threads to spawn (see SoftFinalizationMultithreaded).
class WeightsCacheTest : public testing::TestWithParam<size_t> {};
169 
TEST_P(WeightsCacheTest,SoftFinalizationMultithreaded)170 TEST_P(WeightsCacheTest, SoftFinalizationMultithreaded) {
171   std::vector<char> buffer = Conv2DTester().CreateTfLiteModel();
172   const Model* model = GetModel(buffer.data());
173   ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
174 
175   std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
176                   decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
177       weights_cache(TfLiteXNNPackDelegateWeightsCacheCreate(),
178                     TfLiteXNNPackDelegateWeightsCacheDelete);
179 
180   TfLiteXNNPackDelegateOptions delegate_options =
181       TfLiteXNNPackDelegateOptionsDefault();
182   delegate_options.weights_cache = weights_cache.get();
183 
184   // Create the first interpreter and finalize it.
185   std::unique_ptr<Interpreter> initial_interpreter;
186   ASSERT_EQ(kTfLiteOk,
187             InterpreterBuilder(model, resolver)(&initial_interpreter));
188   ASSERT_EQ(kTfLiteOk, initial_interpreter->AllocateTensors());
189   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
190       initial_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
191                        TfLiteXNNPackDelegateDelete);
192   ASSERT_EQ(kTfLiteOk, initial_interpreter->ModifyGraphWithDelegate(
193                            initial_delegate.get()));
194 
195   ASSERT_TRUE(
196       TfLiteXNNPackDelegateWeightsCacheFinalizeSoft(weights_cache.get()));
197 
198   ASSERT_EQ(kTfLiteOk, initial_interpreter->Invoke());
199 
200   // Create multiple interpreters afterwards.
201   const size_t num_threads = GetParam();
202   if (num_threads > std::thread::hardware_concurrency()) {
203     GTEST_SKIP();
204   }
205 
206   std::vector<std::thread> threads;
207   threads.reserve(num_threads);
208   for (size_t i = 0; i < num_threads; i++) {
209     threads.emplace_back(std::thread([&] {
210       std::unique_ptr<Interpreter> interpreter;
211       ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter));
212       ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors());
213 
214       std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
215           delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
216                    TfLiteXNNPackDelegateDelete);
217 
218       ASSERT_EQ(kTfLiteOk,
219                 interpreter->ModifyGraphWithDelegate(delegate.get()));
220       ASSERT_EQ(kTfLiteOk, interpreter->Invoke());
221     }));
222   }
223 
224   for (int i = 0; i < num_threads; i++) {
225     threads[i].join();
226   }
227 }
228 
// Instantiate the parameterized suite with 2 and 4 worker threads.
INSTANTIATE_TEST_SUITE_P(WeightsCacheTest, WeightsCacheTest,
                         testing::Values(2, 4),
                         testing::PrintToStringParamName());
232 
233 }  // namespace xnnpack
234 }  // namespace tflite
235