1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_ 16 #define TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_ 17 18 #include <cstdint> 19 #include <vector> 20 21 #include "tensorflow/lite/core/api/profiler.h" 22 #include "tensorflow/lite/profiling/profile_buffer.h" 23 24 namespace tflite { 25 namespace profiling { 26 27 // Controls whether profiling is enabled or disabled and collects profiles. 28 // TFLite is used on platforms that don't have posix threads, so the profiler is 29 // kept as simple as possible. It is designed to be used only on a single 30 // thread. 31 // 32 // Profiles are collected using Scoped*Profile objects that begin and end a 33 // profile event. 34 // An example usage is shown in the example below: 35 // 36 // Say Worker class has a DoWork method and we are interested in profiling 37 // the overall execution time for DoWork and time spent in Task1 and Task2 38 // functions. 39 // 40 // class Worker { 41 // public: 42 // void DoWork() { 43 // ScopedProfile(&controller, "DoWork"); 44 // Task1(); 45 // Task2(); 46 // ..... 47 // } 48 // 49 // void Task1() { 50 // ScopedProfile(&controller, "Task1"); 51 // .... 52 // } 53 // 54 // void Task2() { 55 // ScopedProfile(&controller, "Task2"); 56 // } 57 // 58 // Profiler profiler; 59 // } 60 // 61 // We instrument the functions that need to be profiled. 62 // 63 // Profile can be collected by enable profiling and then getting profile 64 // events. 65 // 66 // void ProfileWorker() { 67 // Worker worker; 68 // worker.profiler.EnableProfiling(); 69 // worker.DoWork(); 70 // worker.profiler.DisableProfiling(); 71 // // Profiling is complete, extract profiles. 72 // auto profile_events = worker.profiler.GetProfiles(); 73 // } 74 // 75 // 76 class BufferedProfiler : public tflite::Profiler { 77 public: BufferedProfiler(uint32_t max_num_entries)78 explicit BufferedProfiler(uint32_t max_num_entries) 79 : buffer_(max_num_entries, false), 80 supported_event_types_(~static_cast<uint64_t>( 81 EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT)) {} 82 BeginEvent(const char * tag,EventType event_type,int64_t event_metadata1,int64_t event_metadata2)83 uint32_t BeginEvent(const char* tag, EventType event_type, 84 int64_t event_metadata1, 85 int64_t event_metadata2) override { 86 if (!ShouldAddEvent(event_type)) return kInvalidEventHandle; 87 return buffer_.BeginEvent(tag, event_type, event_metadata1, 88 event_metadata2); 89 } 90 EndEvent(uint32_t event_handle)91 void EndEvent(uint32_t event_handle) override { 92 buffer_.EndEvent(event_handle); 93 } 94 EndEvent(uint32_t event_handle,int64_t event_metadata1,int64_t event_metadata2)95 void EndEvent(uint32_t event_handle, int64_t event_metadata1, 96 int64_t event_metadata2) override { 97 buffer_.EndEvent(event_handle, &event_metadata1, &event_metadata2); 98 } 99 AddEvent(const char * tag,EventType event_type,uint64_t start,uint64_t end,int64_t event_metadata1,int64_t event_metadata2)100 void AddEvent(const char* tag, EventType event_type, uint64_t start, 101 uint64_t end, int64_t event_metadata1, 102 int64_t event_metadata2) override { 103 if (!ShouldAddEvent(event_type)) return; 104 buffer_.AddEvent(tag, event_type, start, end, event_metadata1, 105 event_metadata2); 106 } 107 StartProfiling()108 void StartProfiling() { buffer_.SetEnabled(true); } StopProfiling()109 void StopProfiling() { buffer_.SetEnabled(false); } Reset()110 void Reset() { buffer_.Reset(); } GetProfileEvents()111 std::vector<const ProfileEvent*> GetProfileEvents() { 112 std::vector<const ProfileEvent*> profile_events; 113 profile_events.reserve(buffer_.Size()); 114 for (size_t i = 0; i < buffer_.Size(); i++) { 115 profile_events.push_back(buffer_.At(i)); 116 } 117 return profile_events; 118 } 119 120 protected: ShouldAddEvent(EventType event_type)121 bool ShouldAddEvent(EventType event_type) { 122 return (static_cast<uint64_t>(event_type) & supported_event_types_) != 0; 123 } 124 125 private: GetProfileBuffer()126 ProfileBuffer* GetProfileBuffer() { return &buffer_; } 127 ProfileBuffer buffer_; 128 const uint64_t supported_event_types_; 129 }; 130 131 } // namespace profiling 132 } // namespace tflite 133 134 #endif // TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_ 135