/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <functional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/profiling/profile_summarizer.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/version.h"
28
29 namespace tflite {
30 namespace profiling {
31
32 namespace {
33
// Name of the custom op under test: passed to SetCustomOp() when the model is
// built and used as the registration's custom name. Declared as a constexpr
// array (not a mutable `const char*`) so the name cannot be re-pointed and
// requires no runtime initialization.
constexpr char kOpName[] = "SimpleOpEval";
35
36 #ifdef TFLITE_PROFILING_ENABLED
SimpleOpEval(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
38 const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
39 const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
40
41 TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
42
43 int32_t* output_data = output->data.i32;
44 *output_data = *(input1->data.i32) + *(input2->data.i32);
45 return kTfLiteOk;
46 }
47
// Profiling-detail callback for the test op. Ignores its arguments and always
// yields the fixed detail string "Profile".
const char* SimpleOpProfilingString(const TfLiteContext* /*context*/,
                                    const TfLiteNode* /*node*/) {
  static constexpr char kProfileDetail[] = "Profile";
  return kProfileDetail;
}
52
RegisterSimpleOp()53 TfLiteRegistration* RegisterSimpleOp() {
54 static TfLiteRegistration registration = {
55 nullptr, nullptr, nullptr,
56 SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM,
57 "SimpleOpEval", 1};
58 return ®istration;
59 }
60
RegisterSimpleOpWithProfilingDetails()61 TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
62 static TfLiteRegistration registration = {nullptr,
63 nullptr,
64 nullptr,
65 SimpleOpEval,
66 SimpleOpProfilingString,
67 tflite::BuiltinOperator_CUSTOM,
68 kOpName,
69 1};
70 return ®istration;
71 }
72 #endif
73
74 class SimpleOpModel : public SingleOpModel {
75 public:
76 void Init(const std::function<TfLiteRegistration*()>& registration);
GetInterpreter()77 tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
SetInputs(int32_t x,int32_t y)78 void SetInputs(int32_t x, int32_t y) {
79 PopulateTensor(inputs_[0], {x});
80 PopulateTensor(inputs_[1], {y});
81 }
GetOutput()82 int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
83
84 private:
85 int inputs_[2];
86 int output_;
87 };
88
Init(const std::function<TfLiteRegistration * ()> & registration)89 void SimpleOpModel::Init(
90 const std::function<TfLiteRegistration*()>& registration) {
91 inputs_[0] = AddInput({TensorType_INT32, {1}});
92 inputs_[1] = AddInput({TensorType_INT32, {1}});
93 output_ = AddOutput({TensorType_INT32, {}});
94 SetCustomOp(kOpName, {}, registration);
95 BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
96 }
97
// A summarizer that has processed no profiles must still produce a non-empty
// report string.
TEST(ProfileSummarizerTest, Empty) {
  ProfileSummarizer summarizer;
  const std::string output = summarizer.GetOutputString();
  // empty() avoids the size_t-vs-int comparison that
  // EXPECT_GT(output.size(), 0) performs (a -Wsign-compare warning).
  EXPECT_FALSE(output.empty());
}
103
104 #ifdef TFLITE_PROFILING_ENABLED
// Runs SimpleOpEval (registered without a profiling-string callback) under a
// Profiler and checks that the resulting summary mentions the op.
TEST(ProfileSummarizerTest, Interpreter) {
  Profiler profiler;
  SimpleOpModel m;
  m.Init(RegisterSimpleOp);
  auto* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  m.Invoke();
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the recorded events once and summarize that same snapshot, so the
  // size check and the summary see identical data.
  const auto events = profiler.GetProfileEvents();
  // This one-op graph is expected to record a single event. The unsigned
  // literal avoids a signed/unsigned comparison against size().
  EXPECT_EQ(1u, events.size());

  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const auto output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  // HasSubstr prints the full summary on failure.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval"));
}
125
// Runs SimpleOpEval registered with SimpleOpProfilingString and checks that
// the summary shows the op name joined to the profiling detail by ':'.
TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
  Profiler profiler;
  SimpleOpModel m;
  m.Init(RegisterSimpleOpWithProfilingDetails);
  auto* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  m.Invoke();
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the recorded events once and summarize that same snapshot, so the
  // size check and the summary see identical data.
  const auto events = profiler.GetProfileEvents();
  // This one-op graph is expected to record a single event. The unsigned
  // literal avoids a signed/unsigned comparison against size().
  EXPECT_EQ(1u, events.size());

  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const auto output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  // HasSubstr prints the full summary on failure.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval:Profile"));
}
147
148 #endif
149
150 } // namespace
151 } // namespace profiling
152 } // namespace tflite
153
main(int argc,char ** argv)154 int main(int argc, char** argv) {
155 ::tflite::LogToStderr();
156 ::testing::InitGoogleTest(&argc, argv);
157 return RUN_ALL_TESTS();
158 }
159