• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
17 #define TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
18 
#include <cstdint>
#include <memory>
#include <thread>  // NOLINT(build/c++11)

#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/lite/profiling/memory_info.h"
26 
27 namespace tflite {
28 namespace profiling {
29 namespace memory {
30 
31 // This class could help to tell the peak memory footprint of a running program.
32 // It achieves this by spawning a thread to check the memory usage periodically
33 // at a pre-defined frequency.
34 class MemoryUsageMonitor {
35  public:
36   // A helper class that does memory usage sampling. This allows injecting an
37   // external dependency for the sake of testing or providing platform-specific
38   // implementations.
39   class Sampler {
40    public:
~Sampler()41     virtual ~Sampler() {}
IsSupported()42     virtual bool IsSupported() { return MemoryUsage::IsSupported(); }
GetMemoryUsage()43     virtual MemoryUsage GetMemoryUsage() {
44       return tflite::profiling::memory::GetMemoryUsage();
45     }
SleepFor(const absl::Duration & duration)46     virtual void SleepFor(const absl::Duration& duration) {
47       absl::SleepFor(duration);
48     }
49   };
50 
51   static constexpr float kInvalidMemUsageMB = -1.0f;
52 
53   explicit MemoryUsageMonitor(int sampling_interval_ms = 50)
MemoryUsageMonitor(sampling_interval_ms,std::unique_ptr<Sampler> (new Sampler ()))54       : MemoryUsageMonitor(sampling_interval_ms,
55                            std::unique_ptr<Sampler>(new Sampler())) {}
56   MemoryUsageMonitor(int sampling_interval_ms,
57                      std::unique_ptr<Sampler> sampler);
~MemoryUsageMonitor()58   ~MemoryUsageMonitor() { StopInternal(); }
59 
60   void Start();
61   void Stop();
62 
63   // For simplicity, we will return kInvalidMemUsageMB for the either following
64   // conditions:
65   // 1. getting memory usage isn't supported on the platform.
66   // 2. the memory usage is being monitored (i.e. we've created the
67   // 'check_memory_thd_'.
GetPeakMemUsageInMB()68   float GetPeakMemUsageInMB() const {
69     if (!is_supported_ || check_memory_thd_ != nullptr) {
70       return kInvalidMemUsageMB;
71     }
72     return peak_max_rss_kb_ / 1024.0;
73   }
74 
75   MemoryUsageMonitor(MemoryUsageMonitor&) = delete;
76   MemoryUsageMonitor& operator=(const MemoryUsageMonitor&) = delete;
77   MemoryUsageMonitor(MemoryUsageMonitor&&) = delete;
78   MemoryUsageMonitor& operator=(const MemoryUsageMonitor&&) = delete;
79 
80  private:
81   void StopInternal();
82 
83   std::unique_ptr<Sampler> sampler_ = nullptr;
84   bool is_supported_ = false;
85   std::unique_ptr<absl::Notification> stop_signal_ = nullptr;
86   absl::Duration sampling_interval_;
87   std::unique_ptr<std::thread> check_memory_thd_ = nullptr;
88   int64_t peak_max_rss_kb_ = static_cast<int64_t>(kInvalidMemUsageMB * 1024);
89 };
90 
91 }  // namespace memory
92 }  // namespace profiling
93 }  // namespace tflite
94 
95 #endif  // TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
96