• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/lite/profiling/profile_summarizer.h"

#include <algorithm>
#include <functional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/profiling/buffered_profiler.h"
#include "tensorflow/lite/version.h"
30 
31 namespace tflite {
32 namespace profiling {
33 
34 namespace {
35 
// Name under which the test custom op is registered (see SimpleOpModel::Init).
// A constexpr array, unlike a mutable `const char*`, cannot be reseated.
constexpr char kOpName[] = "SimpleOpEval";
37 
SimpleOpEval(TfLiteContext * context,TfLiteNode * node)38 TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
39   const TfLiteTensor* input1;
40   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/0, &input1));
41   const TfLiteTensor* input2;
42   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/1, &input2));
43 
44   TfLiteTensor* output;
45   TF_LITE_ENSURE_OK(context,
46                     GetOutputSafe(context, node, /*index=*/0, &output));
47 
48   int32_t* output_data = output->data.i32;
49   *output_data = *(input1->data.i32) + *(input2->data.i32);
50   return kTfLiteOk;
51 }
52 
// Profiling-string callback for the test op. Both arguments are ignored; the
// callback always yields the constant string "Profile".
const char* SimpleOpProfilingString(const TfLiteContext* /*context*/,
                                    const TfLiteNode* /*node*/) {
  static constexpr char kProfileDetail[] = "Profile";
  return kProfileDetail;
}
57 
RegisterSimpleOp()58 TfLiteRegistration* RegisterSimpleOp() {
59   static TfLiteRegistration registration = {
60       nullptr,        nullptr, nullptr,
61       SimpleOpEval,   nullptr, tflite::BuiltinOperator_CUSTOM,
62       "SimpleOpEval", 1};
63   return &registration;
64 }
65 
RegisterSimpleOpWithProfilingDetails()66 TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
67   static TfLiteRegistration registration = {nullptr,
68                                             nullptr,
69                                             nullptr,
70                                             SimpleOpEval,
71                                             SimpleOpProfilingString,
72                                             tflite::BuiltinOperator_CUSTOM,
73                                             kOpName,
74                                             1};
75   return &registration;
76 }
77 
78 class SimpleOpModel : public SingleOpModel {
79  public:
80   void Init(const std::function<TfLiteRegistration*()>& registration);
GetInterpreter()81   tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
SetInputs(int32_t x,int32_t y)82   void SetInputs(int32_t x, int32_t y) {
83     PopulateTensor(inputs_[0], {x});
84     PopulateTensor(inputs_[1], {y});
85   }
GetOutput()86   int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
87 
88  private:
89   int inputs_[2];
90   int output_;
91 };
92 
Init(const std::function<TfLiteRegistration * ()> & registration)93 void SimpleOpModel::Init(
94     const std::function<TfLiteRegistration*()>& registration) {
95   inputs_[0] = AddInput({TensorType_INT32, {1}});
96   inputs_[1] = AddInput({TensorType_INT32, {1}});
97   output_ = AddOutput({TensorType_INT32, {}});
98   SetCustomOp(kOpName, {}, registration);
99   BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
100 }
101 
// A summarizer that has processed no profiles still produces non-empty output.
TEST(ProfileSummarizerTest, Empty) {
  ProfileSummarizer summarizer;
  const std::string output = summarizer.GetOutputString();
  // empty() avoids comparing the unsigned size() against a signed literal.
  EXPECT_FALSE(output.empty());
}
107 
// Profiles a single Invoke() of the custom op. Checks that the summary
// mentions the op name and does not mention "Invoke".
TEST(ProfileSummarizerTest, Interpreter) {
  BufferedProfiler profiler(1024);
  SimpleOpModel m;
  m.Init(RegisterSimpleOp);
  auto* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the events once and reuse them, so the count check and the summary
  // operate on the same snapshot.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 2u);
  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval"));
  EXPECT_THAT(output, ::testing::Not(::testing::HasSubstr("Invoke")));
}
129 
// Same as the Interpreter test, but the op registers a profiling-string
// callback. Checks that the summary shows the op name joined with the detail
// string ("SimpleOpEval/Profile").
TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
  BufferedProfiler profiler(1024);
  SimpleOpModel m;
  m.Init(RegisterSimpleOpWithProfilingDetails);
  auto* interpreter = m.GetInterpreter();
  interpreter->SetProfiler(&profiler);
  profiler.StartProfiling();
  m.SetInputs(1, 2);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // 3 = 1 + 2
  EXPECT_EQ(m.GetOutput(), 3);
  profiler.StopProfiling();

  // Fetch the events once and reuse them, so the count check and the summary
  // operate on the same snapshot.
  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 2u);
  ProfileSummarizer summarizer;
  summarizer.ProcessProfiles(events, *interpreter);
  const std::string output = summarizer.GetOutputString();
  // TODO(shashishekhar): Add a better test here.
  EXPECT_THAT(output, ::testing::HasSubstr("SimpleOpEval/Profile"));
}
151 
152 // A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
153 // The computation is: `cond ? a + b : a * b`.
154 class ProfileSummarizerIfOpTest : public subgraph_test_util::ControlFlowOpTest {
155  protected:
SetUp()156   void SetUp() override {
157     AddSubgraphs(2);
158     builder_->BuildAddSubgraph(interpreter_->subgraph(1));
159     builder_->BuildMulSubgraph(interpreter_->subgraph(2));
160     builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
161 
162     interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
163     interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
164     interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
165     ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
166 
167     subgraph_test_util::FillIntTensor(
168         interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
169     subgraph_test_util::FillIntTensor(
170         interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
171   }
172 };
173 
// With cond == true only the ADD branch runs. The test expects two events
// tagged with subgraph 0, two with subgraph 1, and none with subgraph 2.
TEST_F(ProfileSummarizerIfOpTest, TestIfTrue) {
  BufferedProfiler profiler(1024);
  interpreter_->SetProfiler(&profiler);

  interpreter_->typed_input_tensor<bool>(0)[0] = true;
  profiler.StartProfiling();
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  profiler.StopProfiling();
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  subgraph_test_util::CheckIntTensor(output, {1, 2}, {6, 9});

  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 4u);
  // Counts the events whose extra_event_metadata equals `subgraph_index`.
  const auto count_for_subgraph = [&events](int subgraph_index) {
    return std::count_if(
        events.begin(), events.end(), [subgraph_index](const auto* event) {
          return event->extra_event_metadata == subgraph_index;
        });
  };
  EXPECT_EQ(count_for_subgraph(0), 2);
  EXPECT_EQ(count_for_subgraph(1), 2);
  EXPECT_EQ(count_for_subgraph(2), 0);
}
200 
// With cond == false only the MUL branch runs. The test expects two events
// tagged with subgraph 0, none with subgraph 1, and two with subgraph 2.
TEST_F(ProfileSummarizerIfOpTest, TestIfFalse) {
  BufferedProfiler profiler(1024);
  interpreter_->SetProfiler(&profiler);

  interpreter_->typed_input_tensor<bool>(0)[0] = false;
  profiler.StartProfiling();
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  profiler.StopProfiling();
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  subgraph_test_util::CheckIntTensor(output, {1, 2}, {5, 14});

  const auto events = profiler.GetProfileEvents();
  EXPECT_EQ(events.size(), 4u);
  // Counts the events whose extra_event_metadata equals `subgraph_index`.
  const auto count_for_subgraph = [&events](int subgraph_index) {
    return std::count_if(
        events.begin(), events.end(), [subgraph_index](const auto* event) {
          return event->extra_event_metadata == subgraph_index;
        });
  };
  EXPECT_EQ(count_for_subgraph(0), 2);
  EXPECT_EQ(count_for_subgraph(1), 0);
  EXPECT_EQ(count_for_subgraph(2), 2);
}
227 
228 }  // namespace
229 }  // namespace profiling
230 }  // namespace tflite
231