1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/profiling/profile_summarizer.h"

#include <algorithm>
#include <functional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/profiling/buffered_profiler.h"
#include "tensorflow/lite/version.h"
30
31 namespace tflite {
32 namespace profiling {
33
34 namespace {
35
// Custom-op name under which the test kernel is registered. A constexpr array
// (rather than a mutable `const char*` global) so the name cannot be reassigned.
constexpr char kOpName[] = "SimpleOpEval";
37
SimpleOpEval(TfLiteContext * context,TfLiteNode * node)38 TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
39 const TfLiteTensor* input1;
40 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/0, &input1));
41 const TfLiteTensor* input2;
42 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/1, &input2));
43
44 TfLiteTensor* output;
45 TF_LITE_ENSURE_OK(context,
46 GetOutputSafe(context, node, /*index=*/0, &output));
47
48 int32_t* output_data = output->data.i32;
49 *output_data = *(input1->data.i32) + *(input2->data.i32);
50 return kTfLiteOk;
51 }
52
// Profiling-string callback for the test op. It ignores both arguments and
// always returns the fixed detail string "Profile". The
// InterpreterPlusProfilingDetails test looks for it in the summary as
// "SimpleOpEval/Profile".
const char* SimpleOpProfilingString(const TfLiteContext* context,
                                    const TfLiteNode* node) {
  // The detail string does not depend on the context or node.
  static_cast<void>(context);
  static_cast<void>(node);
  return "Profile";
}
57
RegisterSimpleOp()58 TfLiteRegistration* RegisterSimpleOp() {
59 static TfLiteRegistration registration = {
60 nullptr, nullptr, nullptr,
61 SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM,
62 "SimpleOpEval", 1};
63 return ®istration;
64 }
65
RegisterSimpleOpWithProfilingDetails()66 TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
67 static TfLiteRegistration registration = {nullptr,
68 nullptr,
69 nullptr,
70 SimpleOpEval,
71 SimpleOpProfilingString,
72 tflite::BuiltinOperator_CUSTOM,
73 kOpName,
74 1};
75 return ®istration;
76 }
77
78 class SimpleOpModel : public SingleOpModel {
79 public:
80 void Init(const std::function<TfLiteRegistration*()>& registration);
GetInterpreter()81 tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
SetInputs(int32_t x,int32_t y)82 void SetInputs(int32_t x, int32_t y) {
83 PopulateTensor(inputs_[0], {x});
84 PopulateTensor(inputs_[1], {y});
85 }
GetOutput()86 int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
87
88 private:
89 int inputs_[2];
90 int output_;
91 };
92
Init(const std::function<TfLiteRegistration * ()> & registration)93 void SimpleOpModel::Init(
94 const std::function<TfLiteRegistration*()>& registration) {
95 inputs_[0] = AddInput({TensorType_INT32, {1}});
96 inputs_[1] = AddInput({TensorType_INT32, {1}});
97 output_ = AddOutput({TensorType_INT32, {}});
98 SetCustomOp(kOpName, {}, registration);
99 BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
100 }
101
// A summarizer that has processed no profiles still produces a non-empty
// report.
TEST(ProfileSummarizerTest, Empty) {
  ProfileSummarizer summarizer;
  const std::string output = summarizer.GetOutputString();
  // empty() avoids the signed/unsigned comparison of size() against 0.
  EXPECT_FALSE(output.empty());
}
107
// Profiles one invocation of the test op (registered without profiling
// details). Checks that the summary names the op and contains no "Invoke"
// entry.
TEST(ProfileSummarizerTest, Interpreter) {
  BufferedProfiler profiler(1024);
  // Declared after `profiler`, so the model is destroyed first and the
  // interpreter never outlives the profiler it points to.
  SimpleOpModel m;
  m.Init(RegisterSimpleOp);
  tflite::Interpreter* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);

  profiler.StartProfiling();
  m.SetInputs(1, 2);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Capture the events once and summarize the same snapshot that was checked.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 2u);

  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval"));
  EXPECT_THAT(output,
              ::testing::Not(::testing::HasSubstr("Invoke")));  // NOLINT
}
129
// Same scenario as the Interpreter test, but with the profiling-string
// callback installed. The summary should show the op name and the callback's
// detail string as "SimpleOpEval/Profile".
TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
  BufferedProfiler profiler(1024);
  // Declared after `profiler`, so the model is destroyed first.
  SimpleOpModel m;
  m.Init(RegisterSimpleOpWithProfilingDetails);
  tflite::Interpreter* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);

  profiler.StartProfiling();
  m.SetInputs(1, 2);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Capture the events once and summarize the same snapshot that was checked.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 2u);

  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval/Profile"));
}
151
152 // A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
153 // The computation is: `cond ? a + b : a * b`.
154 class ProfileSummarizerIfOpTest : public subgraph_test_util::ControlFlowOpTest {
155 protected:
SetUp()156 void SetUp() override {
157 AddSubgraphs(2);
158 builder_->BuildAddSubgraph(interpreter_->subgraph(1));
159 builder_->BuildMulSubgraph(interpreter_->subgraph(2));
160 builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
161
162 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
163 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
164 interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
165 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
166
167 subgraph_test_util::FillIntTensor(
168 interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
169 subgraph_test_util::FillIntTensor(
170 interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
171 }
172 };
173
// With cond == true only the ADD branch (subgraph 1) runs. Profile events
// should therefore be attributed to subgraphs 0 and 1 only, with
// extra_event_metadata carrying the subgraph index.
TEST_F(ProfileSummarizerIfOpTest, TestIfTrue) {
  BufferedProfiler profiler(1024);
  interpreter_->SetProfiler(&profiler);

  interpreter_->typed_input_tensor<bool>(0)[0] = true;
  profiler.StartProfiling();
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  profiler.StopProfiling();
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  // {5, 7} + {1, 2} = {6, 9}
  subgraph_test_util::CheckIntTensor(output, {1, 2}, {6, 9});

  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 4u);

  // Counts recorded events whose extra_event_metadata equals `subgraph_index`.
  // The parameter uses the field's own type, so the comparison needs no
  // signed/unsigned conversion.
  using Metadata = decltype(ProfileEvent::extra_event_metadata);
  const auto count_events_of_subgraph = [&events](Metadata subgraph_index) {
    return std::count_if(events.begin(), events.end(),
                         [subgraph_index](const ProfileEvent* event) {
                           return event->extra_event_metadata == subgraph_index;
                         });
  };
  EXPECT_EQ(count_events_of_subgraph(0), 2);
  EXPECT_EQ(count_events_of_subgraph(1), 2);
  EXPECT_EQ(count_events_of_subgraph(2), 0);
}
200
// With cond == false only the MUL branch (subgraph 2) runs. Profile events
// should therefore be attributed to subgraphs 0 and 2 only, with
// extra_event_metadata carrying the subgraph index.
TEST_F(ProfileSummarizerIfOpTest, TestIfFalse) {
  BufferedProfiler profiler(1024);
  interpreter_->SetProfiler(&profiler);

  interpreter_->typed_input_tensor<bool>(0)[0] = false;
  profiler.StartProfiling();
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  profiler.StopProfiling();
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  // {5, 7} * {1, 2} = {5, 14}
  subgraph_test_util::CheckIntTensor(output, {1, 2}, {5, 14});

  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 4u);

  // Counts recorded events whose extra_event_metadata equals `subgraph_index`.
  // The parameter uses the field's own type, so the comparison needs no
  // signed/unsigned conversion.
  using Metadata = decltype(ProfileEvent::extra_event_metadata);
  const auto count_events_of_subgraph = [&events](Metadata subgraph_index) {
    return std::count_if(events.begin(), events.end(),
                         [subgraph_index](const ProfileEvent* event) {
                           return event->extra_event_metadata == subgraph_index;
                         });
  };
  EXPECT_EQ(count_events_of_subgraph(0), 2);
  EXPECT_EQ(count_events_of_subgraph(1), 0);
  EXPECT_EQ(count_events_of_subgraph(2), 2);
}
227
228 } // namespace
229 } // namespace profiling
230 } // namespace tflite
231