• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/profiling/root_profiler.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/core/api/profiler.h"
24 
25 using ::testing::_;
26 using ::testing::StrictMock;
27 
28 namespace tflite {
29 namespace profiling {
30 
31 namespace {
32 
33 constexpr char kTag[] = "tag";
34 
35 class MockProfiler : public Profiler {
36  public:
37   MOCK_METHOD(uint32_t, BeginEvent,
38               (const char* tag, EventType event_type, int64_t event_metadata1,
39                int64_t event_metadata2),
40               (override));
41   MOCK_METHOD(void, EndEvent, (uint32_t event_handle), (override));
42   MOCK_METHOD(void, EndEvent,
43               (uint32_t event_handle, int64_t event_metadata1,
44                int64_t event_metadata2),
45               (override));
46   MOCK_METHOD(void, AddEvent,
47               (const char* tag, EventType event_type, uint64_t metric,
48                int64_t event_metadata1, int64_t event_metadata2),
49               (override));
50   MOCK_METHOD(void, AddEventWithData,
51               (const char* tag, EventType event_type, const void* data),
52               (override));
53 };
54 
55 using MockProfilerT = StrictMock<MockProfiler>;
56 
// A child registered by raw pointer (the test keeps ownership through a
// unique_ptr) receives every forwarded call with the original arguments.
TEST(RootProfilerTest, ChildProfilerTest) {
  auto owned_child = std::make_unique<MockProfilerT>();
  MockProfilerT& child = *owned_child;
  RootProfiler root;
  root.AddProfiler(owned_child.get());

  // The child hands out handle 42 for any BeginEvent call.
  ON_CALL(child, BeginEvent(_, _, _, _)).WillByDefault(testing::Return(42));

  EXPECT_CALL(child, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  // The child must get back its own handle (42), whatever id root returns.
  EXPECT_CALL(child, EndEvent(42, 3, 4));
  EXPECT_CALL(child, AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT,
                              5, 6, 7));
  EXPECT_CALL(child, AddEventWithData(kTag, Profiler::EventType::DEFAULT, _));

  // Exercise each forwarding entry point once, in order.
  const auto handle = root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(handle, 3, 4);
  root.AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 5, 6, 7);
  root.AddEventWithData(kTag, Profiler::EventType::DEFAULT, nullptr);
}
77 
// A child whose unique_ptr is moved into RootProfiler still receives every
// forwarded call with the original arguments.
TEST(RootProfilerTest, OwnedProfilerTest) {
  auto owned_child = std::make_unique<MockProfilerT>();
  // Grab a raw alias before ownership moves into `root`.
  MockProfilerT& child = *owned_child;
  RootProfiler root;
  root.AddProfiler(std::move(owned_child));

  // The child hands out handle 42 for any BeginEvent call.
  ON_CALL(child, BeginEvent(_, _, _, _)).WillByDefault(testing::Return(42));

  EXPECT_CALL(child, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  // The child must get back its own handle (42), whatever id root returns.
  EXPECT_CALL(child, EndEvent(42));
  EXPECT_CALL(child, AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT,
                              3, 4, 5));

  // Exercise each forwarding entry point once, in order.
  const auto handle = root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(handle);
  root.AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 3, 4, 5);
}
96 
// With several children registered, every child receives each forwarded
// call. EndEvent must reach each child with the handle that particular child
// returned from BeginEvent.
TEST(RootProfilerTest, MultipleProfilerTest) {
  auto mock_profiler0 = std::make_unique<MockProfilerT>();
  auto* mock0 = mock_profiler0.get();
  auto mock_profiler1 = std::make_unique<MockProfilerT>();
  auto* mock1 = mock_profiler1.get();
  RootProfiler root;
  root.AddProfiler(std::move(mock_profiler0));
  root.AddProfiler(std::move(mock_profiler1));

  // Different child profilers might return different event ids; giving each
  // a distinct id makes a mix-up in the handle mapping observable.
  ON_CALL(*mock0, BeginEvent(_, _, _, _)).WillByDefault(testing::Return(42));
  ON_CALL(*mock1, BeginEvent(_, _, _, _)).WillByDefault(testing::Return(24));

  EXPECT_CALL(*mock0, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*mock0, EndEvent(42));
  EXPECT_CALL(*mock1, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*mock1, EndEvent(24));
  // AddEvent carries no handle to translate; both children must see it with
  // the original arguments (previously untested in the multi-child case).
  EXPECT_CALL(*mock0, AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT,
                               3, 4, 5));
  EXPECT_CALL(*mock1, AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT,
                               3, 4, 5));

  // Calls each method sequentially.
  auto begin = root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(begin);
  root.AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 3, 4, 5);
}
119 
120 }  // namespace
121 }  // namespace profiling
122 }  // namespace tflite
123