1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/checkpoint_callback_manager.h"
16
17 #include <string>
18 #include <utility>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/mutex.h"
26 #include "tensorflow/core/platform/path.h"
27 #include "tensorflow/core/platform/regexp.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/statusor.h"
30 #include "tensorflow/core/platform/stringpiece.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace tensorflow {
34 namespace checkpoint {
35
36 const absl::string_view kCheckpointCallbackManagerResourceName =
37 "checkpoint_callback_manager";
38
39 namespace {
40
41 const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$";
42 const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$";
43 const absl::string_view kCheckpointDirRegex = "-[0-9]*$";
44 const absl::string_view kCheckpointTempDirSuffix = "_temp";
45
TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,absl::string_view checkpoint_dir,absl::string_view file_extension,SaveCallback callback)46 void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,
47 absl::string_view checkpoint_dir,
48 absl::string_view file_extension,
49 SaveCallback callback) {
50 const std::string file_path = io::JoinPath(
51 checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
52
53 // If the file already exists, we are done.
54 if (Env::Default()->FileExists(file_path).ok()) {
55 return;
56 }
57 LOG(INFO) << "Calling a save callback: file_extension = " << file_extension
58 << ", checkpoint_id = " << checkpoint_id;
59 // The callback should return a string to store.
60 StatusOr<std::string> save_content = callback(checkpoint_id);
61 if (!save_content.ok()) {
62 LOG(WARNING) << save_content.status();
63 return;
64 }
65
66 // An empty string means nothing to be saved.
67 if (save_content->empty()) {
68 return;
69 }
70
71 Status write_status =
72 WriteStringToFile(Env::Default(), file_path, *save_content);
73 if (!write_status.ok()) {
74 LOG(WARNING) << write_status;
75 } else {
76 LOG(INFO) << "A CheckpointCallbackManager has been written to "
77 << file_path;
78 }
79 }
80
TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id,absl::string_view checkpoint_dir,absl::string_view file_extension,RestoreCallback callback)81 void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id,
82 absl::string_view checkpoint_dir,
83 absl::string_view file_extension,
84 RestoreCallback callback) {
85 const std::string file_path = io::JoinPath(
86 checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
87 if (!Env::Default()->FileExists(file_path).ok()) {
88 return;
89 }
90 std::string payload;
91 Status read_status = ReadFileToString(Env::Default(), file_path, &payload);
92 if (!read_status.ok()) {
93 LOG(WARNING) << "Failed to read: " << read_status;
94 return;
95 }
96
97 LOG(INFO) << "Calling a restore callback: file_extension = " << file_extension
98 << ", checkpoint_id = " << checkpoint_id;
99 Status callback_status = callback(checkpoint_id, payload);
100 if (!callback_status.ok()) {
101 LOG(WARNING) << callback_status;
102 }
103 }
104
105 } // namespace
106
107 // Examples:
108 // "/foo/bar/checkpoint-1_temp/part-00000-of-00001" -->
109 // ("checkpoint-1", "/foo/bar");
110 // "/foo/bar/checkpoint-2/part-00000-of-00001" -->
111 // ("checkpoint-2", "/foo/bar");
112 // "/foo/bar/checkpoint-3" --> ("checkpoint-3", "/foo/bar");
113 // "/foo/bar" --> NotFound error
114 StatusOr<std::pair<std::string, std::string>>
GetCheckpointIdAndPathFromPrefix(absl::string_view prefix)115 CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix(
116 absl::string_view prefix) {
117 for (absl::string_view path = prefix;; path = io::Dirname(path)) {
118 absl::string_view basename = io::Basename(path);
119
120 // Failed to find checkpoint_id
121 if (basename.empty()) break;
122
123 // Skip known checkpoint file: e.g., part-00000-of-00001
124 if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue;
125
126 // With _temp suffix: e.g., checkpoint-1_temp
127 if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) {
128 // Trim suffix, "_temp".
129 return std::make_pair(
130 std::string(basename.substr(
131 0, basename.length() - kCheckpointTempDirSuffix.length())),
132 std::string(io::Dirname(path)));
133 }
134
135 // Without _temp suffix: e.g., checkpoint-1
136 if (RE2::PartialMatch(basename, kCheckpointDirRegex)) {
137 return std::make_pair(std::string(basename),
138 std::string(io::Dirname(path)));
139 }
140 }
141 return errors::NotFound(
142 absl::StrCat("Failed to find a checkpoint id. prefix = ", prefix));
143 }
144
RegisterSaveCallback(absl::string_view file_extension,SaveCallback callback)145 Status CheckpointCallbackManager::RegisterSaveCallback(
146 absl::string_view file_extension, SaveCallback callback) {
147 SaveCallback lazy_callback = nullptr;
148 std::string checkpoint_id;
149 std::string checkpoint_dir;
150 {
151 mutex_lock l(mu_);
152 if (!save_callbacks_.try_emplace(file_extension, std::move(callback))
153 .second) {
154 return errors::AlreadyExists("A callback already exists.");
155 }
156
157 // If last_saved_checkpoint_id_and_dir_ is not empty,
158 // tries to trigger save callback lazily.
159 if (!last_saved_checkpoint_id_and_dir_.first.empty()) {
160 lazy_callback = save_callbacks_[file_extension];
161 checkpoint_id = last_saved_checkpoint_id_and_dir_.first;
162 checkpoint_dir = last_saved_checkpoint_id_and_dir_.second;
163 }
164 }
165
166 if (lazy_callback != nullptr) {
167 TriggerSaveCallbackIfFileNotExist(checkpoint_id, checkpoint_dir,
168 file_extension, lazy_callback);
169 }
170 return OkStatus();
171 }
172
DoesSaveCallbackExist(absl::string_view file_extension)173 bool CheckpointCallbackManager::DoesSaveCallbackExist(
174 absl::string_view file_extension) {
175 tf_shared_lock l(mu_);
176 return save_callbacks_.contains(file_extension);
177 }
178
RegisterRestoreCallback(absl::string_view file_extension,RestoreCallback callback)179 Status CheckpointCallbackManager::RegisterRestoreCallback(
180 absl::string_view file_extension, RestoreCallback callback) {
181 RestoreCallback lazy_callback = nullptr;
182 std::string checkpoint_id;
183 std::string checkpoint_dir;
184 {
185 mutex_lock l(mu_);
186 if (!restore_callbacks_.try_emplace(file_extension, std::move(callback))
187 .second) {
188 return errors::AlreadyExists("A callback already exists.");
189 }
190
191 // If last_restored_checkpoint_id_and_dir_ is not empty,
192 // tries to trigger restore callback lazily.
193 if (!last_restored_checkpoint_id_and_dir_.first.empty()) {
194 lazy_callback = restore_callbacks_[file_extension];
195 checkpoint_id = last_restored_checkpoint_id_and_dir_.first;
196 checkpoint_dir = last_restored_checkpoint_id_and_dir_.second;
197 }
198 }
199
200 if (lazy_callback != nullptr) {
201 TriggerRestoreCallbackIfFileExists(checkpoint_id, checkpoint_dir,
202 file_extension, lazy_callback);
203 }
204 return OkStatus();
205 }
206
DoesRestoreCallbackExist(absl::string_view file_extension)207 bool CheckpointCallbackManager::DoesRestoreCallbackExist(
208 absl::string_view file_extension) {
209 tf_shared_lock l(mu_);
210 return restore_callbacks_.contains(file_extension);
211 }
212
Save(absl::string_view prefix)213 void CheckpointCallbackManager::Save(absl::string_view prefix) {
214 StatusOr<std::pair<std::string, std::string>> id_and_dir =
215 GetCheckpointIdAndPathFromPrefix(prefix);
216 if (!id_and_dir.ok()) {
217 return;
218 }
219
220 // Create a copy to avoid holding lock while calling a callback.
221 absl::flat_hash_map<std::string, SaveCallback> copy_of_save_callbacks;
222 {
223 mutex_lock l(mu_);
224 last_saved_checkpoint_id_and_dir_ = *id_and_dir;
225 copy_of_save_callbacks = save_callbacks_;
226 }
227
228 for (const auto& name_and_callback : copy_of_save_callbacks) {
229 TriggerSaveCallbackIfFileNotExist(id_and_dir->first, id_and_dir->second,
230 name_and_callback.first,
231 name_and_callback.second);
232 }
233 }
234
Restore(absl::string_view prefix)235 void CheckpointCallbackManager::Restore(absl::string_view prefix) {
236 StatusOr<std::pair<std::string, std::string>> id_and_dir =
237 GetCheckpointIdAndPathFromPrefix(prefix);
238 if (!id_and_dir.ok()) {
239 return;
240 }
241
242 // Create a copy to avoid holding lock while calling a callback.
243 absl::flat_hash_map<std::string, RestoreCallback> copy_of_restore_callbacks;
244 {
245 mutex_lock l(mu_);
246 last_restored_checkpoint_id_and_dir_ = *id_and_dir;
247 copy_of_restore_callbacks = restore_callbacks_;
248 }
249
250 for (const auto& name_and_callback : copy_of_restore_callbacks) {
251 TriggerRestoreCallbackIfFileExists(id_and_dir->first, id_and_dir->second,
252 name_and_callback.first,
253 name_and_callback.second);
254 }
255 }
256
257 } // namespace checkpoint
258 } // namespace tensorflow
259