• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/checkpoint_callback_manager.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/mutex.h"
26 #include "tensorflow/core/platform/path.h"
27 #include "tensorflow/core/platform/regexp.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/statusor.h"
30 #include "tensorflow/core/platform/stringpiece.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 namespace checkpoint {
35 
36 const absl::string_view kCheckpointCallbackManagerResourceName =
37     "checkpoint_callback_manager";
38 
39 namespace {
40 
41 const absl::string_view kCheckpointFileRegex = "^part-[0-9]*-of-[0-9]*$";
42 const absl::string_view kCheckpointTempDirRegex = "-[0-9]*_temp$";
43 const absl::string_view kCheckpointDirRegex = "-[0-9]*$";
44 const absl::string_view kCheckpointTempDirSuffix = "_temp";
45 
TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,absl::string_view checkpoint_dir,absl::string_view file_extension,SaveCallback callback)46 void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id,
47                                        absl::string_view checkpoint_dir,
48                                        absl::string_view file_extension,
49                                        SaveCallback callback) {
50   const std::string file_path = io::JoinPath(
51       checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
52 
53   // If the file already exists, we are done.
54   if (Env::Default()->FileExists(file_path).ok()) {
55     return;
56   }
57   LOG(INFO) << "Calling a save callback: file_extension = " << file_extension
58             << ", checkpoint_id = " << checkpoint_id;
59   // The callback should return a string to store.
60   StatusOr<std::string> save_content = callback(checkpoint_id);
61   if (!save_content.ok()) {
62     LOG(WARNING) << save_content.status();
63     return;
64   }
65 
66   // An empty string means nothing to be saved.
67   if (save_content->empty()) {
68     return;
69   }
70 
71   Status write_status =
72       WriteStringToFile(Env::Default(), file_path, *save_content);
73   if (!write_status.ok()) {
74     LOG(WARNING) << write_status;
75   } else {
76     LOG(INFO) << "A CheckpointCallbackManager has been written to "
77               << file_path;
78   }
79 }
80 
TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id,absl::string_view checkpoint_dir,absl::string_view file_extension,RestoreCallback callback)81 void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id,
82                                         absl::string_view checkpoint_dir,
83                                         absl::string_view file_extension,
84                                         RestoreCallback callback) {
85   const std::string file_path = io::JoinPath(
86       checkpoint_dir, absl::StrCat(checkpoint_id, ".", file_extension));
87   if (!Env::Default()->FileExists(file_path).ok()) {
88     return;
89   }
90   std::string payload;
91   Status read_status = ReadFileToString(Env::Default(), file_path, &payload);
92   if (!read_status.ok()) {
93     LOG(WARNING) << "Failed to read: " << read_status;
94     return;
95   }
96 
97   LOG(INFO) << "Calling a restore callback: file_extension = " << file_extension
98             << ", checkpoint_id = " << checkpoint_id;
99   Status callback_status = callback(checkpoint_id, payload);
100   if (!callback_status.ok()) {
101     LOG(WARNING) << callback_status;
102   }
103 }
104 
105 }  // namespace
106 
107 //  Examples:
108 //    "/foo/bar/checkpoint-1_temp/part-00000-of-00001" -->
109 //        ("checkpoint-1", "/foo/bar");
110 //    "/foo/bar/checkpoint-2/part-00000-of-00001" -->
111 //        ("checkpoint-2", "/foo/bar");
112 //    "/foo/bar/checkpoint-3" --> ("checkpoint-3", "/foo/bar");
113 //    "/foo/bar"              --> NotFound error
114 StatusOr<std::pair<std::string, std::string>>
GetCheckpointIdAndPathFromPrefix(absl::string_view prefix)115 CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix(
116     absl::string_view prefix) {
117   for (absl::string_view path = prefix;; path = io::Dirname(path)) {
118     absl::string_view basename = io::Basename(path);
119 
120     // Failed to find checkpoint_id
121     if (basename.empty()) break;
122 
123     // Skip known checkpoint file: e.g., part-00000-of-00001
124     if (RE2::PartialMatch(basename, kCheckpointFileRegex)) continue;
125 
126     // With _temp suffix: e.g., checkpoint-1_temp
127     if (RE2::PartialMatch(basename, kCheckpointTempDirRegex)) {
128       // Trim suffix, "_temp".
129       return std::make_pair(
130           std::string(basename.substr(
131               0, basename.length() - kCheckpointTempDirSuffix.length())),
132           std::string(io::Dirname(path)));
133     }
134 
135     // Without _temp suffix: e.g., checkpoint-1
136     if (RE2::PartialMatch(basename, kCheckpointDirRegex)) {
137       return std::make_pair(std::string(basename),
138                             std::string(io::Dirname(path)));
139     }
140   }
141   return errors::NotFound(
142       absl::StrCat("Failed to find a checkpoint id. prefix = ", prefix));
143 }
144 
RegisterSaveCallback(absl::string_view file_extension,SaveCallback callback)145 Status CheckpointCallbackManager::RegisterSaveCallback(
146     absl::string_view file_extension, SaveCallback callback) {
147   SaveCallback lazy_callback = nullptr;
148   std::string checkpoint_id;
149   std::string checkpoint_dir;
150   {
151     mutex_lock l(mu_);
152     if (!save_callbacks_.try_emplace(file_extension, std::move(callback))
153              .second) {
154       return errors::AlreadyExists("A callback already exists.");
155     }
156 
157     // If last_saved_checkpoint_id_and_dir_ is not empty,
158     // tries to trigger save callback lazily.
159     if (!last_saved_checkpoint_id_and_dir_.first.empty()) {
160       lazy_callback = save_callbacks_[file_extension];
161       checkpoint_id = last_saved_checkpoint_id_and_dir_.first;
162       checkpoint_dir = last_saved_checkpoint_id_and_dir_.second;
163     }
164   }
165 
166   if (lazy_callback != nullptr) {
167     TriggerSaveCallbackIfFileNotExist(checkpoint_id, checkpoint_dir,
168                                       file_extension, lazy_callback);
169   }
170   return OkStatus();
171 }
172 
DoesSaveCallbackExist(absl::string_view file_extension)173 bool CheckpointCallbackManager::DoesSaveCallbackExist(
174     absl::string_view file_extension) {
175   tf_shared_lock l(mu_);
176   return save_callbacks_.contains(file_extension);
177 }
178 
RegisterRestoreCallback(absl::string_view file_extension,RestoreCallback callback)179 Status CheckpointCallbackManager::RegisterRestoreCallback(
180     absl::string_view file_extension, RestoreCallback callback) {
181   RestoreCallback lazy_callback = nullptr;
182   std::string checkpoint_id;
183   std::string checkpoint_dir;
184   {
185     mutex_lock l(mu_);
186     if (!restore_callbacks_.try_emplace(file_extension, std::move(callback))
187              .second) {
188       return errors::AlreadyExists("A callback already exists.");
189     }
190 
191     // If last_restored_checkpoint_id_and_dir_ is not empty,
192     // tries to trigger restore callback lazily.
193     if (!last_restored_checkpoint_id_and_dir_.first.empty()) {
194       lazy_callback = restore_callbacks_[file_extension];
195       checkpoint_id = last_restored_checkpoint_id_and_dir_.first;
196       checkpoint_dir = last_restored_checkpoint_id_and_dir_.second;
197     }
198   }
199 
200   if (lazy_callback != nullptr) {
201     TriggerRestoreCallbackIfFileExists(checkpoint_id, checkpoint_dir,
202                                        file_extension, lazy_callback);
203   }
204   return OkStatus();
205 }
206 
DoesRestoreCallbackExist(absl::string_view file_extension)207 bool CheckpointCallbackManager::DoesRestoreCallbackExist(
208     absl::string_view file_extension) {
209   tf_shared_lock l(mu_);
210   return restore_callbacks_.contains(file_extension);
211 }
212 
Save(absl::string_view prefix)213 void CheckpointCallbackManager::Save(absl::string_view prefix) {
214   StatusOr<std::pair<std::string, std::string>> id_and_dir =
215       GetCheckpointIdAndPathFromPrefix(prefix);
216   if (!id_and_dir.ok()) {
217     return;
218   }
219 
220   // Create a copy to avoid holding lock while calling a callback.
221   absl::flat_hash_map<std::string, SaveCallback> copy_of_save_callbacks;
222   {
223     mutex_lock l(mu_);
224     last_saved_checkpoint_id_and_dir_ = *id_and_dir;
225     copy_of_save_callbacks = save_callbacks_;
226   }
227 
228   for (const auto& name_and_callback : copy_of_save_callbacks) {
229     TriggerSaveCallbackIfFileNotExist(id_and_dir->first, id_and_dir->second,
230                                       name_and_callback.first,
231                                       name_and_callback.second);
232   }
233 }
234 
Restore(absl::string_view prefix)235 void CheckpointCallbackManager::Restore(absl::string_view prefix) {
236   StatusOr<std::pair<std::string, std::string>> id_and_dir =
237       GetCheckpointIdAndPathFromPrefix(prefix);
238   if (!id_and_dir.ok()) {
239     return;
240   }
241 
242   // Create a copy to avoid holding lock while calling a callback.
243   absl::flat_hash_map<std::string, RestoreCallback> copy_of_restore_callbacks;
244   {
245     mutex_lock l(mu_);
246     last_restored_checkpoint_id_and_dir_ = *id_and_dir;
247     copy_of_restore_callbacks = restore_callbacks_;
248   }
249 
250   for (const auto& name_and_callback : copy_of_restore_callbacks) {
251     TriggerRestoreCallbackIfFileExists(id_and_dir->first, id_and_dir->second,
252                                        name_and_callback.first,
253                                        name_and_callback.second);
254   }
255 }
256 
257 }  // namespace checkpoint
258 }  // namespace tensorflow
259