1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_UTILS_PROFILE_H_
18 #define MINDSPORE_CORE_UTILS_PROFILE_H_
19
20 #include <utility>
21 #include <map>
22 #include <string>
23 #include <vector>
24 #include <fstream>
25 #include <iomanip>
26 #include <sstream>
27 #include "utils/log_adapter.h"
28 #include "utils/compile_config.h"
29
30 namespace mindspore {
EnabledProfilePtr()31 inline bool *EnabledProfilePtr() {
32 static auto enabled_profile = common::GetCompileConfig("COMPILE_PROFILE") == "1";
33 return &enabled_profile;
34 }
35
SetEnabledProfile(bool enabled)36 inline void SetEnabledProfile(bool enabled) {
37 if (EnabledProfilePtr() != nullptr) {
38 *EnabledProfilePtr() = enabled;
39 }
40 }
41
EnabledProfile()42 inline bool EnabledProfile() {
43 #ifdef ENABLE_PROFILE
44 return true;
45 #else
46 return *EnabledProfilePtr();
47 #endif
48 }
49
50 struct TimeInfo;
51 using TimeInfoMap = std::map<std::string, const TimeInfo *>;
52
53 MS_CORE_API double GetTime();
54
55 class ProfileBase;
56
57 struct TimeInfo {
time_TimeInfo58 explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {}
59 TimeInfo(const TimeInfo &) = delete;
60 TimeInfo &operator=(const TimeInfo &) = delete;
61 ~TimeInfo();
62
63 double time_;
64 TimeInfoMap *dict_;
65 size_t actionNum_;
66 };
67
68 // Utility class for Profile.
69 class MS_CORE_API ProfContext {
70 friend class Profile;
71 friend class ProfileBase;
72 friend class ProfTransaction;
73
74 public:
75 ProfContext(const std::string &name, ProfileBase *const prof);
76 ~ProfContext();
77
78 ProfContext(const ProfContext &) = delete;
79 ProfContext &operator=(const ProfContext &) = delete;
80
81 void SetTime(double time) noexcept;
82 void Insert(const std::string &name, const TimeInfo *time) noexcept;
83 bool IsTopContext() const noexcept;
84
85 // Used for Execute break.
set_start_time(double start_time)86 void set_start_time(double start_time) { start_time_ = start_time; }
start_time()87 double start_time() { return start_time_; }
88
89 private:
90 std::string name_;
91 ProfileBase *prof_;
92 ProfContext *parent_;
93 TimeInfo *time_info_;
94 double start_time_{-1.0}; // Used for Execute break.
95 };
96
97 class MS_CORE_API ProfileBase {
98 friend class ProfContext;
99 friend class ProfTransaction;
100
101 public:
102 ProfileBase();
103 virtual ~ProfileBase();
104
Print()105 virtual void Print() {}
Step(const std::string &)106 virtual ProfContext *Step(const std::string &) { return nullptr; }
Lap(int)107 virtual ProfContext *Lap(int) { return nullptr; }
Pop()108 virtual void Pop() {}
109
110 // top level profile context
111 ProfContext context_;
112 // profile context pointer, act as a stack pointer
113 ProfContext *ctx_ptr_ = nullptr;
114 };
115
116 class MS_CORE_API Profile : public ProfileBase {
117 public:
118 Profile() = default;
119 ~Profile() override = default;
120 Profile(const Profile &) = delete;
121 Profile &operator=(const Profile &) = delete;
122
123 void Print() override;
124 ProfContext *Step(const std::string &name) override;
125 ProfContext *Lap(int count) override;
126 void Pop() noexcept override;
127 };
128
129 class MS_CORE_API ProfTransaction {
130 public:
131 explicit ProfTransaction(const ProfileBase *prof);
132 explicit ProfTransaction(ProfContext *const ctx);
133 ProfTransaction(const ProfTransaction &) = delete;
134 ProfTransaction &operator=(const ProfTransaction &) = delete;
135 ~ProfTransaction();
136
137 template <class Function>
Execute(const Function & func)138 void Execute(const Function &func) {
139 if (ctx_ == nullptr) {
140 func();
141 return;
142 }
143
144 double start_time = GetTime();
145 // Set for Execute break.
146 ctx_->set_start_time(start_time);
147 func();
148 double end_time = GetTime();
149 ctx_->SetTime(end_time - start_time);
150 }
151
152 private:
153 ProfContext *ctx_ = nullptr;
154 };
155
156 class NoProfTransaction {
157 public:
NoProfTransaction(ProfileBase *)158 explicit NoProfTransaction(ProfileBase *) {}
NoProfTransaction(ProfContext *)159 explicit NoProfTransaction(ProfContext *) {}
160 ~NoProfTransaction() = default;
161
162 template <class Function>
Execute(const Function & func)163 void Execute(const Function &func) const {
164 func();
165 }
166 };
167
168 class MS_CORE_API DumpTime {
169 public:
170 DumpTime(const DumpTime &) = delete;
171 DumpTime &operator=(const DumpTime &) = delete;
172 ~DumpTime();
173 static DumpTime &GetInstance();
set_file_path(const std::string & save_path)174 void set_file_path(const std::string &save_path) { file_path_ = save_path; }
175 void Record(const std::string &step_name, const double time, const bool is_start);
176 void Save();
177
178 private:
179 DumpTime() = default;
180 std::stringstream file_ss_;
181 std::ofstream file_out_;
182 std::string file_path_ = "./timeline.json";
183 };
184
// Accumulated duration and sample count for one statistic
// (the value type of MsProfile's per-id time table).
struct TimeStat {
  TimeStat() : time_(0.0), count_(0) {}
  ~TimeStat() = default;

  // Adds one sample of duration `t`. Returns *this, following the canonical
  // compound-assignment form so the result can be chained or inspected.
  TimeStat &operator+=(double t) {
    time_ += t;
    count_ += 1;
    return *this;
  }

  // Returns a copy with one extra sample of duration `t`. It does not modify
  // *this, so it is const and works on const TimeStat objects.
  TimeStat operator+(double t) const {
    TimeStat result = *this;
    result += t;
    return result;
  }

  double time_;
  int count_;
};
203
204 class MS_CORE_API MsProfile {
205 public:
206 ~MsProfile();
207
208 static void Reset();
209 static ProfileBase *GetProfile();
210 static void StatTime(const std::string &id, double time);
211 static void Print();
212
213 private:
214 MsProfile() = default;
215
216 static MsProfile &GetSingleton();
217
218 void Clear();
219
220 std::map<std::string, TimeStat> time_stat_; // record time and count info from some activity
221 ProfileBase *profile_ = nullptr; // record hierarchical profile info
222 };
223
224 template <typename Function>
ProfileExecute(ProfileBase * profile,const Function & func)225 void ProfileExecute(ProfileBase *profile, const Function &func) {
226 if (EnabledProfile()) {
227 ProfTransaction(profile).Execute(func);
228 } else {
229 NoProfTransaction(profile).Execute(func);
230 }
231 }
232
ProfileExecuteBreak(ProfileBase * profile)233 inline void ProfileExecuteBreak(ProfileBase *profile) {
234 if (!EnabledProfile()) {
235 return;
236 }
237
238 auto ctx = profile->ctx_ptr_;
239 if (ctx != nullptr && ctx->start_time() != -1.0) {
240 double end_time = GetTime();
241 ctx->SetTime(end_time - ctx->start_time());
242 }
243 }
244
245 template <typename Function>
ProfileExecute(ProfContext * profile_ctx,const Function & func)246 void ProfileExecute(ProfContext *profile_ctx, const Function &func) {
247 if (EnabledProfile()) {
248 ProfTransaction(profile_ctx).Execute(func);
249 } else {
250 NoProfTransaction(profile_ctx).Execute(func);
251 }
252 }
253 class MsProfileStatGuard {
254 public:
MsProfileStatGuard(std::string && state_name)255 explicit MsProfileStatGuard(std::string &&state_name) {
256 if (!EnabledProfile()) {
257 return;
258 }
259 start_ = GetTime();
260 state_name_ = std::move(state_name);
261 }
262
~MsProfileStatGuard()263 ~MsProfileStatGuard() {
264 if (!EnabledProfile()) {
265 return;
266 }
267 if (interrupted_) {
268 return;
269 }
270 auto end = GetTime();
271 MsProfile::StatTime(state_name_, end - start_);
272 }
273
Interrupt()274 void Interrupt() { interrupted_ = true; }
275
276 private:
277 std::string state_name_;
278 double start_;
279 bool interrupted_{false};
280 };
281
// One memory record kept by ProcessStatus between RecordStart() and RecordEnd().
// Memory fields default to -1, presumably meaning "not measured yet".
struct MemoryInfo {
  std::string name;
  int64_t start_memory = -1;
  int64_t end_memory = -1;
  size_t depth = 0;  // Nesting level of the record.
};
288
289 class MS_CORE_API ProcessStatus {
290 public:
291 ~ProcessStatus() = default;
292 static ProcessStatus &GetInstance();
293 // Get current process status by a key. Only useful for Linux.
294 int64_t GetMemoryCost(const std::string &key) const;
295 // Start to record memory increase info. It must be used with RecordEnd().
296 // If previous record not end, the next record will have indent when printed.
297 void RecordStart(const std::string &step_name);
298 // End to record memory increase info. It must be used with RecordStart().
299 void RecordEnd();
300 // Print memory increase info which are recorded.
301 void Print();
302 // Clear all records.
303 void Clear();
304
305 private:
306 ProcessStatus() = default;
307 std::vector<MemoryInfo> stack_;
308 std::vector<MemoryInfo> memory_used_;
309 };
310
311 } // namespace mindspore
312 #endif // MINDSPORE_CORE_UTILS_PROFILE_H_
313