• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <gtest/gtest.h>
17 #include "tensorflow/lite/core/macros.h"
18 #include "tensorflow/lite/interpreter.h"
19 #include "tensorflow/lite/interpreter_builder.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/register.h"
22 #include "tensorflow/lite/model_builder.h"
23 #include "tensorflow/lite/tools/logging.h"
24 
25 namespace tflite {
26 
TestMemoryThreshold(const std::string & model_path,size_t threshold_in_kb)27 void TestMemoryThreshold(const std::string& model_path,
28                          size_t threshold_in_kb) {
29   // The Im2Col optimization is only applied on mobile platforms, so only
30   // validate on such platforms.
31   if (!IsMobilePlatform()) {
32     return;
33   }
34 
35   // The model has a conv op will require a huge temporary tensor if
36   // im2col is performed and it's possible to cause OOM on devices. To prevent
37   // this from happening, a size cap (i.e. kMaxIm2colBufferSizeMobile) of
38   // to-be-allocated im2col data is used to determine whether to disable
39   // im2col. This test will check the memory footprint before/after
40   // interpreter Invoke to ensure the size cap is correctly enforced on mobile
41   // platforms.
42   auto model = FlatBufferModel::BuildFromFile(model_path.c_str());
43   ASSERT_TRUE(model);
44   std::unique_ptr<Interpreter> interpreter;
45 
46   // Note that we explicitly set 1 thread here to avoid extra memory footprint
47   // caused by multithreading, which will make the memory usage threshold
48   // check later more reliable.
49   ASSERT_EQ(InterpreterBuilder(*model, ops::builtin::BuiltinOpResolver())(
50                 &interpreter, /*num_threads*/ 1),
51             kTfLiteOk);
52   ASSERT_TRUE(interpreter);
53   ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
54 
55   // Memory required for all tensors should be smaller than the  threshold.
56   int64_t accumulate_tensor_memory = 0;
57   for (int i = 0; i < interpreter->tensors_size(); ++i) {
58     accumulate_tensor_memory += interpreter->tensor(i)->bytes;
59   }
60   EXPECT_LE(accumulate_tensor_memory, threshold_in_kb * 1024);
61 }
62 
// Checks that the conv op's im2col size cap is enforced: performing im2col
// on this model would need a temporary tensor of roughly 3.5GB, so total
// tensor memory must stay under the 3GB threshold below.
TEST(ConvMemUsage, HugeIm2ColData) {
  constexpr size_t kThresholdInKb = 3 * 1024 * 1024;  // 3GB, expressed in KB.
  TestMemoryThreshold("tensorflow/lite/testdata/conv_huge_im2col.bin",
                      kThresholdInKb);
}
70 
// Checks that the Conv3D op's im2col size cap is enforced: performing im2col
// on this model would need a temporary tensor of roughly 1.3GB (about 450MB
// without it), so total tensor memory must stay under the 1GB threshold
// below.
TEST(Conv3DMemUsage, HugeIm2ColData) {
  constexpr size_t kThresholdInKb = 1 * 1024 * 1024;  // 1GB, expressed in KB.
  TestMemoryThreshold("tensorflow/lite/testdata/conv3d_huge_im2col.bin",
                      kThresholdInKb);
}
78 
79 }  // namespace tflite
80