1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/profiling/root_profiler.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/core/api/profiler.h"
24
25 using ::testing::_;
26 using ::testing::StrictMock;
27
28 namespace tflite {
29 namespace profiling {
30
31 namespace {
32
// Event tag passed to every RootProfiler call in these tests. The mock
// expectations match `const char*` arguments by pointer value (gMock's
// default equality), so the same `kTag` object must be used on both sides.
constexpr char kTag[]{"tag"};
34
35 class MockProfiler : public Profiler {
36 public:
37 MOCK_METHOD(uint32_t, BeginEvent,
38 (const char* tag, EventType event_type, int64_t event_metadata1,
39 int64_t event_metadata2),
40 (override));
41 MOCK_METHOD(void, EndEvent, (uint32_t event_handle), (override));
42 MOCK_METHOD(void, EndEvent,
43 (uint32_t event_handle, int64_t event_metadata1,
44 int64_t event_metadata2),
45 (override));
46 MOCK_METHOD(void, AddEvent,
47 (const char* tag, EventType event_type, uint64_t metric,
48 int64_t event_metadata1, int64_t event_metadata2),
49 (override));
50 MOCK_METHOD(void, AddEventWithData,
51 (const char* tag, EventType event_type, const void* data),
52 (override));
53 };
54
55 using MockProfilerT = StrictMock<MockProfiler>;
56
// A profiler registered by raw pointer must receive each forwarded call with
// its arguments intact. The handle returned by RootProfiler::BeginEvent must
// lead to EndEvent being called on the child with the child's own handle.
TEST(RootProfilerTest, ChildProfilerTest) {
  using ::testing::Return;
  constexpr uint32_t kChildHandle = 42;

  // `root` is only given a raw pointer here, so the unique_ptr keeps
  // ownership. Declaring it first means it is destroyed after `root`.
  auto owned_mock = std::make_unique<MockProfilerT>();
  MockProfilerT* const child = owned_mock.get();
  RootProfiler root;
  root.AddProfiler(child);

  ON_CALL(*child, BeginEvent(_, _, _, _))
      .WillByDefault(Return(kChildHandle));

  EXPECT_CALL(*child, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*child, EndEvent(kChildHandle, 3, 4));
  EXPECT_CALL(*child,
              AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 5, 6,
                       7));
  EXPECT_CALL(*child,
              AddEventWithData(kTag, Profiler::EventType::DEFAULT, _));

  // Drive each forwarding entry point once, one after another.
  const uint32_t handle =
      root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(handle, 3, 4);
  root.AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 5, 6, 7);
  root.AddEventWithData(kTag, Profiler::EventType::DEFAULT, nullptr);
}
77
// A profiler handed over via unique_ptr must still receive the forwarded
// BeginEvent/EndEvent/AddEvent calls. The raw observer pointer stays valid
// because `root`, which now holds the unique_ptr, lives for the whole test.
TEST(RootProfilerTest, OwnedProfilerTest) {
  using ::testing::Return;
  constexpr uint32_t kChildHandle = 42;

  auto owned_mock = std::make_unique<MockProfilerT>();
  MockProfilerT* const child = owned_mock.get();
  RootProfiler root;
  root.AddProfiler(std::move(owned_mock));

  ON_CALL(*child, BeginEvent(_, _, _, _))
      .WillByDefault(Return(kChildHandle));

  EXPECT_CALL(*child, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*child, EndEvent(kChildHandle));
  EXPECT_CALL(*child,
              AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 3, 4,
                       5));

  // Drive each forwarding entry point once, one after another.
  const uint32_t handle =
      root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(handle);
  root.AddEvent(kTag, Profiler::EventType::OPERATOR_INVOKE_EVENT, 3, 4, 5);
}
96
// With two children registered, one root-level Begin/End pair must reach both
// of them. Each child must get back the handle that its own BeginEvent
// returned, even though the children return different handles.
TEST(RootProfilerTest, MultipleProfilerTest) {
  using ::testing::Return;
  constexpr uint32_t kFirstHandle = 42;
  constexpr uint32_t kSecondHandle = 24;

  auto owned_first = std::make_unique<MockProfilerT>();
  MockProfilerT* const first = owned_first.get();
  auto owned_second = std::make_unique<MockProfilerT>();
  MockProfilerT* const second = owned_second.get();
  RootProfiler root;
  root.AddProfiler(std::move(owned_first));
  root.AddProfiler(std::move(owned_second));

  // Deliberately distinct per-child handles, so that mixing them up fails.
  ON_CALL(*first, BeginEvent(_, _, _, _))
      .WillByDefault(Return(kFirstHandle));
  ON_CALL(*second, BeginEvent(_, _, _, _))
      .WillByDefault(Return(kSecondHandle));

  EXPECT_CALL(*first, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*first, EndEvent(kFirstHandle));
  EXPECT_CALL(*second, BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2));
  EXPECT_CALL(*second, EndEvent(kSecondHandle));

  // A single root-level Begin/End pair.
  const uint32_t handle =
      root.BeginEvent(kTag, Profiler::EventType::DEFAULT, 1, 2);
  root.EndEvent(handle);
}
119
120 } // namespace
121 } // namespace profiling
122 } // namespace tflite
123