• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <vector>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/context.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/model.h"
25 #include "tensorflow/lite/profiling/profile_summarizer.h"
26 #include "tensorflow/lite/testing/util.h"
27 #include "tensorflow/lite/version.h"
28 
29 namespace tflite {
30 namespace profiling {
31 
32 namespace {
33 
// Name under which the test's custom op is registered and looked up.
// A constexpr array (rather than a mutable `const char*`) so the name itself
// cannot be reassigned; it still decays to `const char*` wherever needed.
constexpr char kOpName[] = "SimpleOpEval";
35 
36 #ifdef TFLITE_PROFILING_ENABLED
SimpleOpEval(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
38   const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
39   const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
40 
41   TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
42 
43   int32_t* output_data = output->data.i32;
44   *output_data = *(input1->data.i32) + *(input2->data.i32);
45   return kTfLiteOk;
46 }
47 
// Profiling-string hook for the test op. The arguments are ignored; the hook
// always yields the fixed tag "Profile", which the profiling-details test
// expects to see appended to the op name in the summary.
const char* SimpleOpProfilingString(const TfLiteContext* context,
                                    const TfLiteNode* node) {
  static constexpr char kProfileTag[] = "Profile";
  return kProfileTag;
}
52 
RegisterSimpleOp()53 TfLiteRegistration* RegisterSimpleOp() {
54   static TfLiteRegistration registration = {
55       nullptr,        nullptr, nullptr,
56       SimpleOpEval,   nullptr, tflite::BuiltinOperator_CUSTOM,
57       "SimpleOpEval", 1};
58   return &registration;
59 }
60 
RegisterSimpleOpWithProfilingDetails()61 TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
62   static TfLiteRegistration registration = {nullptr,
63                                             nullptr,
64                                             nullptr,
65                                             SimpleOpEval,
66                                             SimpleOpProfilingString,
67                                             tflite::BuiltinOperator_CUSTOM,
68                                             kOpName,
69                                             1};
70   return &registration;
71 }
72 #endif
73 
74 class SimpleOpModel : public SingleOpModel {
75  public:
76   void Init(const std::function<TfLiteRegistration*()>& registration);
GetInterpreter()77   tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
SetInputs(int32_t x,int32_t y)78   void SetInputs(int32_t x, int32_t y) {
79     PopulateTensor(inputs_[0], {x});
80     PopulateTensor(inputs_[1], {y});
81   }
GetOutput()82   int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
83 
84  private:
85   int inputs_[2];
86   int output_;
87 };
88 
Init(const std::function<TfLiteRegistration * ()> & registration)89 void SimpleOpModel::Init(
90     const std::function<TfLiteRegistration*()>& registration) {
91   inputs_[0] = AddInput({TensorType_INT32, {1}});
92   inputs_[1] = AddInput({TensorType_INT32, {1}});
93   output_ = AddOutput({TensorType_INT32, {}});
94   SetCustomOp(kOpName, {}, registration);
95   BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
96 }
97 
// A summarizer that has processed no profiles is still expected to produce
// a non-empty report string.
TEST(ProfileSummarizerTest, Empty) {
  ProfileSummarizer summarizer;
  const std::string output = summarizer.GetOutputString();
  // empty() avoids the signed/unsigned comparison of size() against 0.
  EXPECT_FALSE(output.empty());
}
103 
104 #ifdef TFLITE_PROFILING_ENABLED
// Runs the test op once under a profiler and checks that the summarizer's
// report mentions the op's name.
TEST(ProfileSummarizerTest, Interpreter) {
  Profiler profiler;  // Declared first so it outlives the model's interpreter.
  SimpleOpModel m;
  m.Init(RegisterSimpleOp);
  tflite::Interpreter* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  m.Invoke();
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the recorded events once, so the size check and the summary are
  // computed from the same data.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 1u);
  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  // HasSubstr prints the full output on failure.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval"));
}
125 
// Same flow as the Interpreter test, but the op has a profiling-string hook.
// The report is expected to contain the op name followed by ":" and the
// hook's string.
TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
  Profiler profiler;  // Declared first so it outlives the model's interpreter.
  SimpleOpModel m;
  m.Init(RegisterSimpleOpWithProfilingDetails);
  tflite::Interpreter* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  m.Invoke();
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the recorded events once, so the size check and the summary are
  // computed from the same data.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 1u);
  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  // HasSubstr prints the full output on failure.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval:Profile"));
}
147 
148 #endif
149 
150 }  // namespace
151 }  // namespace profiling
152 }  // namespace tflite
153 
main(int argc,char ** argv)154 int main(int argc, char** argv) {
155   ::tflite::LogToStderr();
156   ::testing::InitGoogleTest(&argc, argv);
157   return RUN_ALL_TESTS();
158 }
159