1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/cloud/gcs_file_system.h"
17 #include <stdio.h>
18 #include <unistd.h>
19 #include <algorithm>
20 #include <cstdio>
21 #include <cstdlib>
22 #include <cstring>
23 #include <fstream>
24 #include <vector>
25 #ifdef _WIN32
26 #include <io.h> // for _mktemp
27 #endif
28 #include "absl/base/macros.h"
29 #include "include/json/json.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/gtl/stl_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/stringprintf.h"
37 #include "tensorflow/core/platform/cloud/curl_http_request.h"
38 #include "tensorflow/core/platform/cloud/file_block_cache.h"
39 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
40 #include "tensorflow/core/platform/cloud/ram_file_block_cache.h"
41 #include "tensorflow/core/platform/cloud/retrying_utils.h"
42 #include "tensorflow/core/platform/cloud/time_util.h"
43 #include "tensorflow/core/platform/env.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/protobuf.h"
46 #include "tensorflow/core/platform/thread_annotations.h"
47
48 #ifdef _WIN32
49 #ifdef DeleteFile
50 #undef DeleteFile
51 #endif
52 #endif
53
54 namespace tensorflow {
55 namespace {
56
57 constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
58 constexpr char kGcsUploadUriBase[] =
59 "https://www.googleapis.com/upload/storage/v1/";
60 constexpr char kStorageHost[] = "storage.googleapis.com";
61 constexpr char kBucketMetadataLocationKey[] = "location";
62 constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes.
63 constexpr int kGetChildrenDefaultPageSize = 1000;
64 // The HTTP response code "308 Resume Incomplete".
65 constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
66 // The environment variable that overrides the size of the readahead buffer.
67 ABSL_DEPRECATED("Use GCS_READ_CACHE_BLOCK_SIZE_MB instead.")
68 constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
69 // The environment variable that disables the GCS block cache for reads.
70 // This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
71 // takes precedence over either of those environment variables.
72 constexpr char kReadCacheDisabled[] = "GCS_READ_CACHE_DISABLED";
73 // The environment variable that overrides the block size for aligned reads from
74 // GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
75 constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
76 constexpr size_t kDefaultBlockSize = 16 * 1024 * 1024;
77 // The environment variable that overrides the max size of the LRU cache of
78 // blocks read from GCS. Specified in MB.
79 constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
80 constexpr size_t kDefaultMaxCacheSize = kDefaultBlockSize;
81 // The environment variable that overrides the maximum staleness of cached file
82 // contents. Once any block of a file reaches this staleness, all cached blocks
83 // will be evicted on the next read.
84 constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
85 constexpr uint64 kDefaultMaxStaleness = 0;
86 // The environment variable that overrides the maximum age of entries in the
87 // Stat cache. A value of 0 means nothing is cached.
88 constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
89 constexpr uint64 kStatCacheDefaultMaxAge = 5;
90 // The environment variable that overrides the maximum number of entries in the
91 // Stat cache.
92 constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
93 constexpr size_t kStatCacheDefaultMaxEntries = 1024;
94 // The environment variable that overrides the maximum age of entries in the
95 // GetMatchingPaths cache. A value of 0 (the default) means nothing is cached.
96 constexpr char kMatchingPathsCacheMaxAge[] = "GCS_MATCHING_PATHS_CACHE_MAX_AGE";
97 constexpr uint64 kMatchingPathsCacheDefaultMaxAge = 0;
98 // The environment variable that overrides the maximum number of entries in the
99 // GetMatchingPaths cache.
100 constexpr char kMatchingPathsCacheMaxEntries[] =
101 "GCS_MATCHING_PATHS_CACHE_MAX_ENTRIES";
102 constexpr size_t kMatchingPathsCacheDefaultMaxEntries = 1024;
103 // Number of bucket locations cached; most workloads won't touch more than one
104 // bucket, so this limit is set fairly low.
105 constexpr size_t kBucketLocationCacheMaxEntries = 10;
106 // ExpiringLRUCache doesn't support a "cache forever" option.
107 constexpr size_t kCacheNeverExpire = std::numeric_limits<uint64>::max();
108 // The file statistics returned by Stat() for directories.
109 const FileStatistics DIRECTORY_STAT(0, 0, true);
110 // Some environments exhibit unreliable DNS resolution. Set this environment
111 // variable to a positive integer giving the refresh period, in seconds, of
112 // the userspace DNS cache.
113 constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS";
114 // The environment variable to configure the http request's connection timeout.
115 constexpr char kRequestConnectionTimeout[] =
116 "GCS_REQUEST_CONNECTION_TIMEOUT_SECS";
117 // The environment variable to configure the http request's idle timeout.
118 constexpr char kRequestIdleTimeout[] = "GCS_REQUEST_IDLE_TIMEOUT_SECS";
119 // The environment variable to configure the overall request timeout for
120 // metadata requests.
121 constexpr char kMetadataRequestTimeout[] = "GCS_METADATA_REQUEST_TIMEOUT_SECS";
122 // The environment variable to configure the overall request timeout for
123 // block reads requests.
124 constexpr char kReadRequestTimeout[] = "GCS_READ_REQUEST_TIMEOUT_SECS";
125 // The environment variable to configure the overall request timeout for
126 // upload requests.
127 constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS";
128 // The environment variable to configure an additional header to send with
129 // all requests to GCS (format: HEADERNAME:HEADERCONTENT).
130 constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER";
131 // The environment variable to configure the throttle token rate (format: <int64>)
132 constexpr char kThrottleRate[] = "GCS_THROTTLE_TOKEN_RATE";
133 // The environment variable to configure the token bucket size (format: <int64>)
134 constexpr char kThrottleBucket[] = "GCS_THROTTLE_BUCKET_SIZE";
135 // The environment variable that controls the number of tokens per request.
136 // (format: <int64>)
137 constexpr char kTokensPerRequest[] = "GCS_TOKENS_PER_REQUEST";
138 // The environment variable to configure the initial tokens (format: <int64>)
139 constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS";
140
141 // The environment variable that customizes which GCS bucket locations are
142 // allowed (format: comma-delimited list). If the list is empty, it defaults to
143 // the region of the zone TensorFlow is running in. Requires 'storage.buckets.get' permission.
144 constexpr char kAllowedBucketLocations[] = "GCS_ALLOWED_BUCKET_LOCATIONS";
145 // When this value is passed as an allowed location, the zone TensorFlow is
146 // running in is detected and access is restricted to buckets in that region.
147 constexpr char kDetectZoneSentinalValue[] = "auto";
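// Example (illustrative, with hypothetical values): setting
// GCS_READ_CACHE_BLOCK_SIZE_MB=8 and GCS_READ_CACHE_MAX_SIZE_MB=64 makes reads
// go through 8 MiB blocks (8 * 1024 * 1024 bytes) with at most 64 MiB of
// blocks cached; setting GCS_READ_CACHE_DISABLED (to any value) takes
// precedence and disables the block cache entirely.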
148
149 // TODO: DO NOT use a hardcoded path
150 Status GetTmpFilename(string* filename) {
151 #ifndef _WIN32
152 char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
153 int fd = mkstemp(buffer);
154 if (fd < 0) {
155 return errors::Internal("Failed to create a temporary file.");
156 }
157 close(fd);
158 #else
159 char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
160 char* ret = _mktemp(buffer);
161 if (ret == nullptr) {
162 return errors::Internal("Failed to create a temporary file.");
163 }
164 #endif
165 *filename = buffer;
166 return Status::OK();
167 }
168
169 /// \brief Splits a GCS path to a bucket and an object.
170 ///
171 /// For example, "gs://bucket-name/path/to/file.txt" gets split into
172 /// "bucket-name" and "path/to/file.txt".
173 /// If fname only contains the bucket and empty_object_ok = true, the returned
174 /// object is empty.
175 Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
176 string* object) {
177 StringPiece scheme, bucketp, objectp;
178 io::ParseURI(fname, &scheme, &bucketp, &objectp);
179 if (scheme != "gs") {
180 return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
181 fname);
182 }
183 *bucket = string(bucketp);
184 if (bucket->empty() || *bucket == ".") {
185 return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
186 fname);
187 }
188 str_util::ConsumePrefix(&objectp, "/");
189 *object = string(objectp);
190 if (!empty_object_ok && object->empty()) {
191 return errors::InvalidArgument("GCS path doesn't contain an object name: ",
192 fname);
193 }
194 return Status::OK();
195 }
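// Example usage (illustrative; the bucket and object names are hypothetical):
//
//   string bucket, object;
//   // "gs://my-bucket/dir/file.txt" -> bucket "my-bucket", object "dir/file.txt".
//   ParseGcsPath("gs://my-bucket/dir/file.txt", false, &bucket, &object);
//   // With empty_object_ok = true, "gs://my-bucket" yields an empty object.
//   ParseGcsPath("gs://my-bucket", true, &bucket, &object);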
196
197 /// Appends a trailing slash if the name doesn't already have one.
198 string MaybeAppendSlash(const string& name) {
199 if (name.empty()) {
200 return "/";
201 }
202 if (name.back() != '/') {
203 return strings::StrCat(name, "/");
204 }
205 return name;
206 }
207
208 // io::JoinPath() doesn't work in cases when we want an empty subpath
209 // to result in an appended slash in order for directory markers
210 // to be processed correctly: "gs://a/b" + "" should give "gs://a/b/".
211 string JoinGcsPath(const string& path, const string& subpath) {
212 return strings::StrCat(MaybeAppendSlash(path), subpath);
213 }
214
215 /// \brief Returns the given paths appending all their subfolders.
216 ///
217 /// For every path X in the list, every subfolder in X is added to the
218 /// resulting list.
219 /// For example:
220 /// - for 'a/b/c/d' it will append 'a', 'a/b' and 'a/b/c'
221 /// - for 'a/b/c/' it will append 'a', 'a/b' and 'a/b/c'
222 std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
223 std::set<string> result;
224 result.insert(paths.begin(), paths.end());
225 for (const string& path : paths) {
226 StringPiece subpath = io::Dirname(path);
227 while (!subpath.empty()) {
228 result.emplace(string(subpath));
229 subpath = io::Dirname(subpath);
230 }
231 }
232 return result;
233 }
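// Example (illustrative): AddAllSubpaths({"a/b/c", "d/e"}) returns the set
// {"a", "a/b", "a/b/c", "d", "d/e"}: the original paths plus every parent
// directory of each path.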
234
235 Status ParseJson(StringPiece json, Json::Value* result) {
236 Json::Reader reader;
237 if (!reader.parse(json.data(), json.data() + json.size(), *result)) {
238 return errors::Internal("Couldn't parse JSON response from GCS.");
239 }
240 return Status::OK();
241 }
242
243 Status ParseJson(const std::vector<char>& json, Json::Value* result) {
244 return ParseJson(StringPiece{json.data(), json.size()}, result);
245 }
246
247 /// Reads a JSON value with the given name from a parent JSON value.
248 Status GetValue(const Json::Value& parent, const char* name,
249 Json::Value* result) {
250 *result = parent.get(name, Json::Value::null);
251 if (result->isNull()) {
252 return errors::Internal("The field '", name,
253 "' was expected in the JSON response.");
254 }
255 return Status::OK();
256 }
257
258 /// Reads a string JSON value with the given name from a parent JSON value.
259 Status GetStringValue(const Json::Value& parent, const char* name,
260 string* result) {
261 Json::Value result_value;
262 TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
263 if (!result_value.isString()) {
264 return errors::Internal(
265 "The field '", name,
266 "' in the JSON response was expected to be a string.");
267 }
268 *result = result_value.asString();
269 return Status::OK();
270 }
271
272 /// Reads a long JSON value with the given name from a parent JSON value.
273 Status GetInt64Value(const Json::Value& parent, const char* name,
274 int64* result) {
275 Json::Value result_value;
276 TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
277 if (result_value.isNumeric()) {
278 *result = result_value.asInt64();
279 return Status::OK();
280 }
281 if (result_value.isString() &&
282 strings::safe_strto64(result_value.asCString(), result)) {
283 return Status::OK();
284 }
285 return errors::Internal(
286 "The field '", name,
287 "' in the JSON response was expected to be a number.");
288 }
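// Example (illustrative): GCS returns some numeric fields as JSON strings, so
// both {"size": 1024} and {"size": "1024"} yield 1024 here:
//
//   int64 size;
//   TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &size));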
289
290 /// Reads a boolean JSON value with the given name from a parent JSON value.
291 Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) {
292 Json::Value result_value;
293 TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
294 if (!result_value.isBool()) {
295 return errors::Internal(
296 "The field '", name,
297 "' in the JSON response was expected to be a boolean.");
298 }
299 *result = result_value.asBool();
300 return Status::OK();
301 }
302
303 /// A GCS-based implementation of a random access file with an LRU block cache.
304 class GcsRandomAccessFile : public RandomAccessFile {
305 public:
306 using ReadFn =
307 std::function<Status(const string& filename, uint64 offset, size_t n,
308 StringPiece* result, char* scratch)>;
309
310 GcsRandomAccessFile(const string& filename, ReadFn read_fn)
311 : filename_(filename), read_fn_(std::move(read_fn)) {}
312
313 Status Name(StringPiece* result) const override {
314 *result = filename_;
315 return Status::OK();
316 }
317
318 /// The implementation of reads with an LRU block cache. Thread safe.
319 Status Read(uint64 offset, size_t n, StringPiece* result,
320 char* scratch) const override {
321 return read_fn_(filename_, offset, n, result, scratch);
322 }
323
324 private:
325 /// The filename of this file.
326 const string filename_;
327 /// The implementation of the read operation (provided by the GCSFileSystem).
328 const ReadFn read_fn_;
329 };
330
331 /// \brief GCS-based implementation of a writeable file.
332 ///
333 /// Since GCS objects are immutable, this implementation writes to a local
334 /// tmp file and copies it to GCS on flush/close.
335 class GcsWritableFile : public WritableFile {
336 public:
337 GcsWritableFile(const string& bucket, const string& object,
338 GcsFileSystem* filesystem,
339 GcsFileSystem::TimeoutConfig* timeouts,
340 std::function<void()> file_cache_erase,
341 RetryConfig retry_config)
342 : bucket_(bucket),
343 object_(object),
344 filesystem_(filesystem),
345 timeouts_(timeouts),
346 file_cache_erase_(std::move(file_cache_erase)),
347 sync_needed_(true),
348 retry_config_(retry_config) {
349 // TODO: to make it safer, outfile_ should be constructed from an FD
350 if (GetTmpFilename(&tmp_content_filename_).ok()) {
351 outfile_.open(tmp_content_filename_,
352 std::ofstream::binary | std::ofstream::app);
353 }
354 }
355
356 /// \brief Constructs the writable file in append mode.
357 ///
358 /// tmp_content_filename should contain a path of an existing temporary file
359 /// with the content to be appended. The class takes ownership of the
360 /// specified tmp file and deletes it on close.
361 GcsWritableFile(const string& bucket, const string& object,
362 GcsFileSystem* filesystem, const string& tmp_content_filename,
363 GcsFileSystem::TimeoutConfig* timeouts,
364 std::function<void()> file_cache_erase,
365 RetryConfig retry_config)
366 : bucket_(bucket),
367 object_(object),
368 filesystem_(filesystem),
369 timeouts_(timeouts),
370 file_cache_erase_(std::move(file_cache_erase)),
371 sync_needed_(true),
372 retry_config_(retry_config) {
373 tmp_content_filename_ = tmp_content_filename;
374 outfile_.open(tmp_content_filename_,
375 std::ofstream::binary | std::ofstream::app);
376 }
377
378 ~GcsWritableFile() override { Close().IgnoreError(); }
379
380 Status Append(StringPiece data) override {
381 TF_RETURN_IF_ERROR(CheckWritable());
382 sync_needed_ = true;
383 outfile_ << data;
384 if (!outfile_.good()) {
385 return errors::Internal(
386 "Could not append to the internal temporary file.");
387 }
388 return Status::OK();
389 }
390
391 Status Close() override {
392 if (outfile_.is_open()) {
393 TF_RETURN_IF_ERROR(Sync());
394 outfile_.close();
395 std::remove(tmp_content_filename_.c_str());
396 }
397 return Status::OK();
398 }
399
400 Status Flush() override { return Sync(); }
401
402 Status Name(StringPiece* result) const override {
403 return errors::Unimplemented("GCSWritableFile does not support Name()");
404 }
405
406 Status Sync() override {
407 TF_RETURN_IF_ERROR(CheckWritable());
408 if (!sync_needed_) {
409 return Status::OK();
410 }
411 Status status = SyncImpl();
412 if (status.ok()) {
413 sync_needed_ = false;
414 }
415 return status;
416 }
417
418 Status Tell(int64* position) override {
419 *position = outfile_.tellp();
420 if (*position == -1) {
421 return errors::Internal("tellp on the internal temporary file failed");
422 }
423 return Status::OK();
424 }
425
426 private:
427 /// Copies the current version of the file to GCS.
428 ///
429 /// This SyncImpl() uploads the object to GCS.
430 /// In case of a failure, it resumes failed uploads as recommended by the GCS
431 /// resumable API documentation. When the whole upload needs to be
432 /// restarted, Sync() returns UNAVAILABLE and relies on RetryingFileSystem.
433 Status SyncImpl() {
434 outfile_.flush();
435 if (!outfile_.good()) {
436 return errors::Internal(
437 "Could not write to the internal temporary file.");
438 }
439 string session_uri;
440 TF_RETURN_IF_ERROR(CreateNewUploadSession(&session_uri));
441 uint64 already_uploaded = 0;
442 bool first_attempt = true;
443 const Status upload_status = RetryingUtils::CallWithRetries(
444 [&first_attempt, &already_uploaded, &session_uri, this]() {
445 if (!first_attempt) {
446 bool completed;
447 TF_RETURN_IF_ERROR(RequestUploadSessionStatus(
448 session_uri, &completed, &already_uploaded));
449 if (completed) {
450 // Erase the file from the file cache on every successful write.
451 file_cache_erase_();
452 // It's unclear why UploadToSession didn't return OK in the
453 // previous attempt, but GCS reports that the file is fully
454 // uploaded, so succeed.
455 return Status::OK();
456 }
457 }
458 first_attempt = false;
459 return UploadToSession(session_uri, already_uploaded);
460 },
461 retry_config_);
462 if (upload_status.code() == errors::Code::NOT_FOUND) {
463 // GCS docs recommend retrying the whole upload. We're relying on the
464 // RetryingFileSystem to retry the Sync() call.
465 return errors::Unavailable(
466 strings::StrCat("Upload to gs://", bucket_, "/", object_,
467 " failed, caused by: ", upload_status.ToString()));
468 }
469 return upload_status;
470 }
471
472 Status CheckWritable() const {
473 if (!outfile_.is_open()) {
474 return errors::FailedPrecondition(
475 "The internal temporary file is not writable.");
476 }
477 return Status::OK();
478 }
479
480 Status GetCurrentFileSize(uint64* size) {
481 const auto tellp = outfile_.tellp();
482 if (tellp == static_cast<std::streampos>(-1)) {
483 return errors::Internal(
484 "Could not get the size of the internal temporary file.");
485 }
486 *size = tellp;
487 return Status::OK();
488 }
489
490 /// Initiates a new resumable upload session.
491 Status CreateNewUploadSession(string* session_uri) {
492 uint64 file_size;
493 TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
494
495 std::vector<char> output_buffer;
496 std::unique_ptr<HttpRequest> request;
497 TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
498
499 request->SetUri(strings::StrCat(
500 kGcsUploadUriBase, "b/", bucket_,
501 "/o?uploadType=resumable&name=", request->EscapeString(object_)));
502 request->AddHeader("X-Upload-Content-Length", std::to_string(file_size));
503 request->SetPostEmptyBody();
504 request->SetResultBuffer(&output_buffer);
505 request->SetTimeouts(timeouts_->connect, timeouts_->idle,
506 timeouts_->metadata);
507 TF_RETURN_WITH_CONTEXT_IF_ERROR(
508 request->Send(), " when initiating an upload to ", GetGcsPath());
509 *session_uri = request->GetResponseHeader("Location");
510 if (session_uri->empty()) {
511 return errors::Internal("Unexpected response from GCS when writing to ",
512 GetGcsPath(),
513 ": 'Location' header not returned.");
514 }
515 return Status::OK();
516 }
517
518 /// \brief Requests status of a previously initiated upload session.
519 ///
520 /// If the upload has already succeeded, sets 'completed' to true.
521 /// Otherwise sets 'completed' to false and 'uploaded' to the currently
522 /// uploaded size in bytes.
523 Status RequestUploadSessionStatus(const string& session_uri, bool* completed,
524 uint64* uploaded) {
525 uint64 file_size;
526 TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
527
528 std::unique_ptr<HttpRequest> request;
529 TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
530 request->SetUri(session_uri);
531 request->SetTimeouts(timeouts_->connect, timeouts_->idle,
532 timeouts_->metadata);
533 request->AddHeader("Content-Range", strings::StrCat("bytes */", file_size));
534 request->SetPutEmptyBody();
535 const Status& status = request->Send();
536 if (status.ok()) {
537 *completed = true;
538 return Status::OK();
539 }
540 *completed = false;
541 if (request->GetResponseCode() != HTTP_CODE_RESUME_INCOMPLETE) {
542 TF_RETURN_WITH_CONTEXT_IF_ERROR(status, " when resuming upload ",
543 GetGcsPath());
544 }
545 const string& received_range = request->GetResponseHeader("Range");
546 if (received_range.empty()) {
547 // This means GCS doesn't have any bytes of the file yet.
548 *uploaded = 0;
549 } else {
550 StringPiece range_piece(received_range);
551 str_util::ConsumePrefix(&range_piece,
552 "bytes="); // May or may not be present.
553 std::vector<int64> range_parts;
554 if (!str_util::SplitAndParseAsInts(range_piece, '-', &range_parts) ||
555 range_parts.size() != 2) {
556 return errors::Internal("Unexpected response from GCS when writing ",
557 GetGcsPath(), ": Range header '",
558 received_range, "' could not be parsed.");
559 }
560 if (range_parts[0] != 0) {
561 return errors::Internal("Unexpected response from GCS when writing to ",
562 GetGcsPath(), ": the returned range '",
563 received_range, "' does not start at zero.");
564 }
565 // If GCS returned "Range: 0-10", this means 11 bytes were uploaded.
566 *uploaded = range_parts[1] + 1;
567 }
568 return Status::OK();
569 }
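// Example (illustrative): if the status request returns
// "308 Resume Incomplete" with a "Range: bytes=0-1048575" header, 'completed'
// is set to false and 'uploaded' to 1048576, so the retry loop in SyncImpl()
// resumes UploadToSession() from byte offset 1048576.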
570
571 Status UploadToSession(const string& session_uri, uint64 start_offset) {
572 uint64 file_size;
573 TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
574
575 std::unique_ptr<HttpRequest> request;
576 TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
577 request->SetUri(session_uri);
578 if (file_size > 0) {
579 request->AddHeader("Content-Range",
580 strings::StrCat("bytes ", start_offset, "-",
581 file_size - 1, "/", file_size));
582 }
583 request->SetTimeouts(timeouts_->connect, timeouts_->idle, timeouts_->write);
584
585 TF_RETURN_IF_ERROR(
586 request->SetPutFromFile(tmp_content_filename_, start_offset));
587 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ",
588 GetGcsPath());
589 // Erase the file from the file cache on every successful write.
590 file_cache_erase_();
591 return Status::OK();
592 }
593
594 string GetGcsPath() const {
595 return strings::StrCat("gs://", bucket_, "/", object_);
596 }
597
598 string bucket_;
599 string object_;
600 GcsFileSystem* const filesystem_; // Not owned.
601 string tmp_content_filename_;
602 std::ofstream outfile_;
603 GcsFileSystem::TimeoutConfig* timeouts_;
604 std::function<void()> file_cache_erase_;
605 bool sync_needed_; // whether there is buffered data that needs to be synced
606 RetryConfig retry_config_;
607 };
608
609 class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
610 public:
611 GcsReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
612 : data_(std::move(data)), length_(length) {}
613 const void* data() override { return reinterpret_cast<void*>(data_.get()); }
614 uint64 length() override { return length_; }
615
616 private:
617 std::unique_ptr<char[]> data_;
618 uint64 length_;
619 };
620
621 // Helper function to extract an environment variable and convert it into a
622 // value of type T.
623 template <typename T>
624 bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*),
625 T* value) {
626 const char* env_value = std::getenv(varname);
627 if (!env_value) {
628 return false;
629 }
630 return convert(env_value, value);
631 }
632
633 bool StringPieceIdentity(StringPiece str, StringPiece* value) {
634 *value = str;
635 return true;
636 }
637
638 /// \brief Utility function to split a comma delimited list of strings to an
639 /// unordered set, lowercasing all values.
640 bool SplitByCommaToLowercaseSet(StringPiece list,
641 std::unordered_set<string>* set) {
642 std::vector<string> vector =
643 str_util::Split(tensorflow::str_util::Lowercase(list), ",");
644 *set = std::unordered_set<string>(vector.begin(), vector.end());
645 return true;
646 }
647
648 // \brief Convert Compute Engine zone to region
649 string ZoneToRegion(string* zone) {
650 return zone->substr(0, zone->find_last_of('-'));
651 }
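// Example: ZoneToRegion() maps a Compute Engine zone such as "us-central1-a"
// to its region "us-central1" by dropping the trailing zone suffix.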
652
653 } // namespace
654
655 GcsFileSystem::GcsFileSystem() {
656 uint64 value;
657 size_t block_size = kDefaultBlockSize;
658 size_t max_bytes = kDefaultMaxCacheSize;
659 uint64 max_staleness = kDefaultMaxStaleness;
660
661 http_request_factory_ = std::make_shared<CurlHttpRequest::Factory>();
662 compute_engine_metadata_client_ =
663 std::make_shared<ComputeEngineMetadataClient>(http_request_factory_);
664 auth_provider_ = std::unique_ptr<AuthProvider>(
665 new GoogleAuthProvider(compute_engine_metadata_client_));
666 zone_provider_ = std::unique_ptr<ZoneProvider>(
667 new ComputeEngineZoneProvider(compute_engine_metadata_client_));
668
669 // Apply the sys env override for the readahead buffer size if it's provided.
670 if (GetEnvVar(kReadaheadBufferSize, strings::safe_strtou64, &value)) {
671 block_size = value;
672 }
673 // Apply the overrides for the block size (MB), max bytes (MB), and max
674 // staleness (seconds) if provided.
675 if (GetEnvVar(kBlockSize, strings::safe_strtou64, &value)) {
676 block_size = value * 1024 * 1024;
677 }
678 if (GetEnvVar(kMaxCacheSize, strings::safe_strtou64, &value)) {
679 max_bytes = value * 1024 * 1024;
680 }
681 if (GetEnvVar(kMaxStaleness, strings::safe_strtou64, &value)) {
682 max_staleness = value;
683 }
684 if (std::getenv(kReadCacheDisabled)) {
685 // Setting either to 0 disables the cache; set both for good measure.
686 block_size = max_bytes = 0;
687 }
688 VLOG(1) << "GCS cache max size = " << max_bytes << " ; "
689 << "block size = " << block_size << " ; "
690 << "max staleness = " << max_staleness;
691 file_block_cache_ = MakeFileBlockCache(block_size, max_bytes, max_staleness);
692 // Apply overrides for the stat cache max age and max entries, if provided.
693 uint64 stat_cache_max_age = kStatCacheDefaultMaxAge;
694 size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
695 if (GetEnvVar(kStatCacheMaxAge, strings::safe_strtou64, &value)) {
696 stat_cache_max_age = value;
697 }
698 if (GetEnvVar(kStatCacheMaxEntries, strings::safe_strtou64, &value)) {
699 stat_cache_max_entries = value;
700 }
701 stat_cache_.reset(new ExpiringLRUCache<GcsFileStat>(stat_cache_max_age,
702 stat_cache_max_entries));
703 // Apply overrides for the matching paths cache max age and max entries, if
704 // provided.
705 uint64 matching_paths_cache_max_age = kMatchingPathsCacheDefaultMaxAge;
706 size_t matching_paths_cache_max_entries =
707 kMatchingPathsCacheDefaultMaxEntries;
708 if (GetEnvVar(kMatchingPathsCacheMaxAge, strings::safe_strtou64, &value)) {
709 matching_paths_cache_max_age = value;
710 }
711 if (GetEnvVar(kMatchingPathsCacheMaxEntries, strings::safe_strtou64,
712 &value)) {
713 matching_paths_cache_max_entries = value;
714 }
715 matching_paths_cache_.reset(new ExpiringLRUCache<std::vector<string>>(
716 matching_paths_cache_max_age, matching_paths_cache_max_entries));
717
718 bucket_location_cache_.reset(new ExpiringLRUCache<string>(
719 kCacheNeverExpire, kBucketLocationCacheMaxEntries));
720
721 int64 resolve_frequency_secs;
722 if (GetEnvVar(kResolveCacheSecs, strings::safe_strto64,
723 &resolve_frequency_secs)) {
724 dns_cache_.reset(new GcsDnsCache(resolve_frequency_secs));
725 VLOG(1) << "GCS DNS cache is enabled. " << kResolveCacheSecs << " = "
726 << resolve_frequency_secs;
727 } else {
728 VLOG(1) << "GCS DNS cache is disabled, because " << kResolveCacheSecs
729 << " = 0 (or is not set)";
730 }
731
732 // Get the additional header
733 StringPiece add_header_contents;
734 if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity,
735 &add_header_contents)) {
736 size_t split = add_header_contents.find(':', 0);
737
738 if (split != StringPiece::npos) {
739 StringPiece header_name = add_header_contents.substr(0, split);
740 StringPiece header_value = add_header_contents.substr(split + 1);
741
742 if (!header_name.empty() && !header_value.empty()) {
743 additional_header_.reset(new std::pair<const string, const string>(
744 string(header_name), string(header_value)));
745
746 VLOG(1) << "GCS additional header ENABLED. "
747 << "Name: " << additional_header_->first << ", "
748 << "Value: " << additional_header_->second;
749 } else {
750 LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
751 << add_header_contents;
752 }
753 } else {
754 LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
755 << add_header_contents;
756 }
757 } else {
758 VLOG(1) << "GCS additional header DISABLED. No environment variable set.";
759 }
760
761 // Apply the overrides for request timeouts
762 uint32 timeout_value;
763 if (GetEnvVar(kRequestConnectionTimeout, strings::safe_strtou32,
764 &timeout_value)) {
765 timeouts_.connect = timeout_value;
766 }
767 if (GetEnvVar(kRequestIdleTimeout, strings::safe_strtou32, &timeout_value)) {
768 timeouts_.idle = timeout_value;
769 }
770 if (GetEnvVar(kMetadataRequestTimeout, strings::safe_strtou32,
771 &timeout_value)) {
772 timeouts_.metadata = timeout_value;
773 }
774 if (GetEnvVar(kReadRequestTimeout, strings::safe_strtou32, &timeout_value)) {
775 timeouts_.read = timeout_value;
776 }
777 if (GetEnvVar(kWriteRequestTimeout, strings::safe_strtou32, &timeout_value)) {
778 timeouts_.write = timeout_value;
779 }
780
781 int64 token_value;
782 if (GetEnvVar(kThrottleRate, strings::safe_strto64, &token_value)) {
783 GcsThrottleConfig config;
784 config.enabled = true;
785 config.token_rate = token_value;
786
787 if (GetEnvVar(kThrottleBucket, strings::safe_strto64, &token_value)) {
788 config.bucket_size = token_value;
789 }
790
791 if (GetEnvVar(kTokensPerRequest, strings::safe_strto64, &token_value)) {
792 config.tokens_per_request = token_value;
793 }
794
795 if (GetEnvVar(kInitialTokens, strings::safe_strto64, &token_value)) {
796 config.initial_tokens = token_value;
797 }
798 throttle_.SetConfig(config);
799 }
800
801 GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet,
802 &allowed_locations_);
803 }
804
805 GcsFileSystem::GcsFileSystem(
806 std::unique_ptr<AuthProvider> auth_provider,
807 std::unique_ptr<HttpRequest::Factory> http_request_factory,
808 std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
809 size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
810 size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
811 size_t matching_paths_cache_max_entries, RetryConfig retry_config,
812 TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
813 std::pair<const string, const string>* additional_header)
814 : auth_provider_(std::move(auth_provider)),
815 http_request_factory_(std::move(http_request_factory)),
816 zone_provider_(std::move(zone_provider)),
817 file_block_cache_(
818 MakeFileBlockCache(block_size, max_bytes, max_staleness)),
819 stat_cache_(new StatCache(stat_cache_max_age, stat_cache_max_entries)),
820 matching_paths_cache_(new MatchingPathsCache(
821 matching_paths_cache_max_age, matching_paths_cache_max_entries)),
822 bucket_location_cache_(new BucketLocationCache(
823 kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
824 allowed_locations_(allowed_locations),
825 timeouts_(timeouts),
826 retry_config_(retry_config),
827 additional_header_(additional_header) {}
828
829 Status GcsFileSystem::NewRandomAccessFile(
830 const string& fname, std::unique_ptr<RandomAccessFile>* result) {
831 string bucket, object;
832 TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
833 TF_RETURN_IF_ERROR(CheckBucketLocationConstraint(bucket));
834 result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
835 const string& fname,
836 uint64 offset, size_t n,
837 StringPiece* result,
838 char* scratch) {
839 tf_shared_lock l(block_cache_lock_);
840 if (file_block_cache_->IsCacheEnabled()) {
841 GcsFileStat stat;
842 TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
843 fname, &stat,
844 [this, bucket, object](const string& fname, GcsFileStat* stat) {
845 return UncachedStatForObject(fname, bucket, object, stat);
846 }));
847 if (!file_block_cache_->ValidateAndUpdateFileSignature(
848 fname, stat.generation_number)) {
849 VLOG(1)
850 << "File signature has been changed. Refreshing the cache. Path: "
851 << fname;
852 }
853 }
854 *result = StringPiece();
855 size_t bytes_transferred;
856 TF_RETURN_IF_ERROR(
857 file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred));
858 *result = StringPiece(scratch, bytes_transferred);
859 if (bytes_transferred < n) {
860 return errors::OutOfRange("EOF reached, ", result->size(),
861 " bytes were read out of ", n,
862 " bytes requested.");
863 }
864 return Status::OK();
865 }));
866 return Status::OK();
867 }
868
869 void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
870 size_t max_bytes,
871 uint64 max_staleness_secs) {
872 mutex_lock l(block_cache_lock_);
873 file_block_cache_ =
874 MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs);
875 if (stats_ != nullptr) {
876 stats_->Configure(this, &throttle_, file_block_cache_.get());
877 }
878 }
879
880 // A helper function to build a FileBlockCache for GcsFileSystem.
881 std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
882 size_t block_size, size_t max_bytes, uint64 max_staleness) {
883 std::unique_ptr<FileBlockCache> file_block_cache(new RamFileBlockCache(
884 block_size, max_bytes, max_staleness,
885 [this](const string& filename, size_t offset, size_t n, char* buffer,
886 size_t* bytes_transferred) {
887 return LoadBufferFromGCS(filename, offset, n, buffer,
888 bytes_transferred);
889 }));
890 return file_block_cache;
891 }
892
893 // A helper function to actually read the data from GCS.
894 Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
895 size_t n, char* buffer,
896 size_t* bytes_transferred) {
897 *bytes_transferred = 0;
898
899 string bucket, object;
900 TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
901
902 std::unique_ptr<HttpRequest> request;
903 TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
904 "when reading gs://", bucket, "/", object);
905
906 request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/",
907 request->EscapeString(object)));
908 request->SetRange(offset, offset + n - 1);
909 request->SetResultBufferDirect(buffer, n);
910 request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read);
911
912 if (stats_ != nullptr) {
913 stats_->RecordBlockLoadRequest(filename, offset);
914 }
915
916 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
917 bucket, "/", object);
918
919 size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
920 *bytes_transferred = bytes_read;
921 VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
922 << offset << " of size: " << bytes_read;
923
924 if (stats_ != nullptr) {
925 stats_->RecordBlockRetrieved(filename, offset, bytes_read);
926 }
927
928 throttle_.RecordResponse(bytes_read);
929
930 if (bytes_read < n) {
931 // Check stat cache to see if we encountered an interrupted read.
932 GcsFileStat stat;
933 if (stat_cache_->Lookup(filename, &stat)) {
934 if (offset + bytes_read < stat.base.length) {
935 return errors::Internal(strings::Printf(
936 "File contents are inconsistent for file: %s @ %lu.",
937 filename.c_str(), offset));
938 }
939 VLOG(2) << "Successful integrity check for: gs://" << bucket << "/"
940 << object << " @ " << offset;
941 }
942 }
943
944 return Status::OK();
945 }
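// Example (illustrative): a cache miss on the second 16 MiB block of
// gs://bucket/object issues a single ranged GET for bytes 16777216-33554431 of
// https://storage.googleapis.com/bucket/object; a short read is only treated
// as an error if the cached stat says more data should have been available at
// that offset.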
946
947 void GcsFileSystem::ClearFileCaches(const string& fname) {
948 tf_shared_lock l(block_cache_lock_);
949 file_block_cache_->RemoveFile(fname);
950 stat_cache_->Delete(fname);
951 // TODO(rxsang): Remove the patterns that match the file in
952 // MatchingPathsCache as well.
953 }
954
955 Status GcsFileSystem::NewWritableFile(const string& fname,
956 std::unique_ptr<WritableFile>* result) {
957 string bucket, object;
958 TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
959 result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
960 [this, fname]() { ClearFileCaches(fname); },
961 retry_config_));
962 return Status::OK();
963 }
964
965 // Reads the file from GCS in chunks and stores it in a tmp file,
966 // which is then passed to GcsWritableFile.
967 Status GcsFileSystem::NewAppendableFile(const string& fname,
968 std::unique_ptr<WritableFile>* result) {
969 std::unique_ptr<RandomAccessFile> reader;
970 TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader));
971 std::unique_ptr<char[]> buffer(new char[kReadAppendableFileBufferSize]);
972 Status status;
973 uint64 offset = 0;
974 StringPiece read_chunk;
975
976 // Read the file from GCS in chunks and save it to a tmp file.
977 string old_content_filename;
978 TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename));
979 std::ofstream old_content(old_content_filename, std::ofstream::binary);
980 while (true) {
981 status = reader->Read(offset, kReadAppendableFileBufferSize, &read_chunk,
982 buffer.get());
983 if (status.ok()) {
984 old_content << read_chunk;
985 offset += kReadAppendableFileBufferSize;
986 } else if (status.code() == error::OUT_OF_RANGE) {
987 // Expected, this means we reached EOF.
988 old_content << read_chunk;
989 break;
990 } else {
991 return status;
992 }
993 }
994 old_content.close();
995
996 // Create a writable file and pass the old content to it.
997 string bucket, object;
998 TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
999 result->reset(new GcsWritableFile(
1000 bucket, object, this, old_content_filename, &timeouts_,
1001 [this, fname]() { ClearFileCaches(fname); }, retry_config_));
1002 return Status::OK();
1003 }
1004
1005 Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
1006 const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
1007 uint64 size;
1008 TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
1009 std::unique_ptr<char[]> data(new char[size]);
1010
1011 std::unique_ptr<RandomAccessFile> file;
1012 TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file));
1013
1014 StringPiece piece;
1015 TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
1016
1017 result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size));
1018 return Status::OK();
1019 }
1020
1021 Status GcsFileSystem::FileExists(const string& fname) {
1022 string bucket, object;
1023 TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
1024 if (object.empty()) {
1025 bool result;
1026 TF_RETURN_IF_ERROR(BucketExists(bucket, &result));
1027 if (result) {
1028 return Status::OK();
1029 }
1030 }
1031
1032 // Check if the object exists.
1033 GcsFileStat stat;
1034 const Status status = StatForObject(fname, bucket, object, &stat);
1035 if (status.code() != errors::Code::NOT_FOUND) {
1036 return status;
1037 }
1038
1039 // Check if the folder exists.
1040 bool result;
1041 TF_RETURN_IF_ERROR(FolderExists(fname, &result));
1042 if (result) {
1043 return Status::OK();
1044 }
1045 return errors::NotFound("The specified path ", fname, " was not found.");
1046 }
1047
1048 Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket,
1049 const string& object, bool* result) {
1050 GcsFileStat stat;
1051 const Status status = StatForObject(fname, bucket, object, &stat);
1052 switch (status.code()) {
1053 case errors::Code::OK:
1054 *result = !stat.base.is_directory;
1055 return Status::OK();
1056 case errors::Code::NOT_FOUND:
1057 *result = false;
1058 return Status::OK();
1059 default:
1060 return status;
1061 }
1062 }
1063
1064 Status GcsFileSystem::UncachedStatForObject(const string& fname,
1065 const string& bucket,
1066 const string& object,
1067 GcsFileStat* stat) {
1068 std::vector<char> output_buffer;
1069 std::unique_ptr<HttpRequest> request;
1070 TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
1071 " when reading metadata of gs://", bucket,
1072 "/", object);
1073
1074 request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
1075 request->EscapeString(object),
1076 "?fields=size%2Cgeneration%2Cupdated"));
1077 request->SetResultBuffer(&output_buffer);
1078 request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1079
1080 if (stats_ != nullptr) {
1081 stats_->RecordStatObjectRequest();
1082 }
1083
1084 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1085 request->Send(), " when reading metadata of gs://", bucket, "/", object);
1086
1087 Json::Value root;
1088 TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
1089
1090 // Parse file size.
1091 TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->base.length));
1092
1093 // Parse generation number.
1094 TF_RETURN_IF_ERROR(
1095 GetInt64Value(root, "generation", &stat->generation_number));
1096
1097 // Parse file modification time.
1098 string updated;
1099 TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated));
1100 TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->base.mtime_nsec)));
1101
1102 VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- "
1103 << " length: " << stat->base.length
1104 << " generation: " << stat->generation_number
1105 << "; mtime_nsec: " << stat->base.mtime_nsec
1106 << "; updated: " << updated;
1107
1108 if (str_util::EndsWith(fname, "/")) {
1109 // In GCS a path can be both a directory and a file, which is uncommon in
1110 // other file systems. To avoid the ambiguity, if a path ends with "/" in
1111 // GCS, we always regard it as a directory marker or a virtual directory.
1112 stat->base.is_directory = true;
1113 } else {
1114 stat->base.is_directory = false;
1115 }
1116 return Status::OK();
1117 }
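// Illustrative metadata response for the request above (field values are
// hypothetical):
//
//   {"size": "1024", "generation": "1566571321068421",
//    "updated": "2016-04-29T23:15:24.896Z"}
//
// "size" and "generation" may arrive as strings and are parsed with
// GetInt64Value(); "updated" is an RFC 3339 timestamp that ParseRfc3339Time()
// converts to nanoseconds for mtime_nsec.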
1118
1119 Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
1120 const string& object, GcsFileStat* stat) {
1121 if (object.empty()) {
1122 return errors::InvalidArgument(strings::Printf(
1123 "'object' must be a non-empty string. (File: %s)", fname.c_str()));
1124 }
1125
1126 TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
1127 fname, stat,
1128 [this, &bucket, &object](const string& fname, GcsFileStat* stat) {
1129 return UncachedStatForObject(fname, bucket, object, stat);
1130 }));
1131 return Status::OK();
1132 }
1133
1134 Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
1135 const Status status = GetBucketMetadata(bucket, nullptr);
1136 switch (status.code()) {
1137 case errors::Code::OK:
1138 *result = true;
1139 return Status::OK();
1140 case errors::Code::NOT_FOUND:
1141 *result = false;
1142 return Status::OK();
1143 default:
1144 return status;
1145 }
1146 }
1147
1148 Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) {
1149 if (allowed_locations_.empty()) {
1150 return Status::OK();
1151 }
1152
1153 // Avoid calling external APIs in the constructor.
1154 if (allowed_locations_.erase(kDetectZoneSentinalValue) == 1) {
1155 string zone;
1156 TF_RETURN_IF_ERROR(zone_provider_->GetZone(&zone));
1157 allowed_locations_.insert(ZoneToRegion(&zone));
1158 }
1159
1160 string location;
1161 TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location));
1162 if (allowed_locations_.find(location) != allowed_locations_.end()) {
1163 return Status::OK();
1164 }
1165
1166 return errors::FailedPrecondition(strings::Printf(
1167 "Bucket '%s' is in '%s' location, allowed locations are: (%s).",
1168 bucket.c_str(), location.c_str(),
1169 str_util::Join(allowed_locations_, ", ").c_str()));
1170 }
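// Example (illustrative): with GCS_ALLOWED_BUCKET_LOCATIONS="us-east1,auto",
// the "auto" entry is replaced by the region of the zone this process runs in
// (e.g. "us-central1" for zone "us-central1-a"), and a bucket located anywhere
// else fails with FAILED_PRECONDITION.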
1171
1172 Status GcsFileSystem::GetBucketLocation(const string& bucket,
1173 string* location) {
1174 auto compute_func = [this](const string& bucket, string* location) {
1175 std::vector<char> result_buffer;
1176 Status status = GetBucketMetadata(bucket, &result_buffer);
1177 Json::Value result;
1178 TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result));
1179 string bucket_location;
1180 TF_RETURN_IF_ERROR(
1181 GetStringValue(result, kBucketMetadataLocationKey, &bucket_location));
1182 // Lowercase the GCS location to be case insensitive for allowed locations.
1183 *location = tensorflow::str_util::Lowercase(bucket_location);
1184 return Status::OK();
1185 };
1186
1187 TF_RETURN_IF_ERROR(
1188 bucket_location_cache_->LookupOrCompute(bucket, location, compute_func));
1189
1190 return Status::OK();
1191 }
1192
1193 Status GcsFileSystem::GetBucketMetadata(const string& bucket,
1194 std::vector<char>* result_buffer) {
1195 std::unique_ptr<HttpRequest> request;
1196 TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1197 request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
1198
1199 if (result_buffer != nullptr) {
1200 request->SetResultBuffer(result_buffer);
1201 }
1202
1203 request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1204 return request->Send();
1205 }
1206
1207 Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
1208 StatCache::ComputeFunc compute_func = [this](const string& dirname,
1209 GcsFileStat* stat) {
1210 std::vector<string> children;
1211 TF_RETURN_IF_ERROR(
1212 GetChildrenBounded(dirname, 1, &children, true /* recursively */,
1213 true /* include_self_directory_marker */));
1214 if (!children.empty()) {
1215 stat->base = DIRECTORY_STAT;
1216 return Status::OK();
1217 } else {
1218 return errors::InvalidArgument("Not a directory!");
1219 }
1220 };
1221 GcsFileStat stat;
1222 Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat,
1223 compute_func);
1224 if (s.ok()) {
1225 *result = stat.base.is_directory;
1226 return Status::OK();
1227 }
1228 if (errors::IsInvalidArgument(s)) {
1229 *result = false;
1230 return Status::OK();
1231 }
1232 return s;
1233 }
1234
1235 Status GcsFileSystem::GetChildren(const string& dirname,
1236 std::vector<string>* result) {
1237 return GetChildrenBounded(dirname, UINT64_MAX, result,
1238 false /* recursively */,
1239 false /* include_self_directory_marker */);
1240 }
1241
1242 Status GcsFileSystem::GetMatchingPaths(const string& pattern,
1243 std::vector<string>* results) {
1244 MatchingPathsCache::ComputeFunc compute_func =
1245 [this](const string& pattern, std::vector<string>* results) {
1246 results->clear();
1247 // Find the fixed prefix by looking for the first wildcard.
1248 const string& fixed_prefix =
1249 pattern.substr(0, pattern.find_first_of("*?[\\"));
1250 const string dir(io::Dirname(fixed_prefix));
1251 if (dir.empty()) {
1252 return errors::InvalidArgument(
1253 "A GCS pattern doesn't have a bucket name: ", pattern);
1254 }
1255 std::vector<string> all_files;
1256 TF_RETURN_IF_ERROR(GetChildrenBounded(
1257 dir, UINT64_MAX, &all_files, true /* recursively */,
1258 false /* include_self_directory_marker */));
1259
1260 const auto& files_and_folders = AddAllSubpaths(all_files);
1261
1262 // Match all obtained paths to the input pattern.
1263 for (const auto& path : files_and_folders) {
1264 const string& full_path = io::JoinPath(dir, path);
1265 if (Env::Default()->MatchPath(full_path, pattern)) {
1266 results->push_back(full_path);
1267 }
1268 }
1269 return Status::OK();
1270 };
1271 TF_RETURN_IF_ERROR(
1272 matching_paths_cache_->LookupOrCompute(pattern, results, compute_func));
1273 return Status::OK();
1274 }
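// Example (illustrative): for the pattern "gs://bucket/data/*.tfrecord" the
// fixed prefix is "gs://bucket/data/", so the recursive listing is rooted at
// "gs://bucket/data"; each object and implied subdirectory under it is then
// matched against the full pattern via Env::Default()->MatchPath().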
1275
1276 Status GcsFileSystem::GetChildrenBounded(const string& dirname,
1277 uint64 max_results,
1278 std::vector<string>* result,
1279 bool recursive,
1280 bool include_self_directory_marker) {
1281 if (!result) {
1282 return errors::InvalidArgument("'result' cannot be null");
1283 }
1284 string bucket, object_prefix;
1285 TF_RETURN_IF_ERROR(
1286 ParseGcsPath(MaybeAppendSlash(dirname), true, &bucket, &object_prefix));
1287
1288 string nextPageToken;
1289 uint64 retrieved_results = 0;
1290 while (true) { // A loop over multiple result pages.
1291 std::vector<char> output_buffer;
1292 std::unique_ptr<HttpRequest> request;
1293 TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1294 auto uri = strings::StrCat(kGcsUriBase, "b/", bucket, "/o");
1295 if (recursive) {
1296 uri = strings::StrCat(uri, "?fields=items%2Fname%2CnextPageToken");
1297 } else {
1298 // Set "/" as a delimiter to ask GCS to treat subfolders as children
1299 // and return them in "prefixes".
1300 uri = strings::StrCat(uri,
1301 "?fields=items%2Fname%2Cprefixes%2CnextPageToken");
1302 uri = strings::StrCat(uri, "&delimiter=%2F");
1303 }
1304 if (!object_prefix.empty()) {
1305 uri = strings::StrCat(uri,
1306 "&prefix=", request->EscapeString(object_prefix));
1307 }
1308 if (!nextPageToken.empty()) {
1309 uri = strings::StrCat(
1310 uri, "&pageToken=", request->EscapeString(nextPageToken));
1311 }
1312 if (max_results - retrieved_results < kGetChildrenDefaultPageSize) {
1313 uri =
1314 strings::StrCat(uri, "&maxResults=", max_results - retrieved_results);
1315 }
1316 request->SetUri(uri);
1317 request->SetResultBuffer(&output_buffer);
1318 request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1319
1320 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading ", dirname);
1321 Json::Value root;
1322 TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
1323 const auto items = root.get("items", Json::Value::null);
1324 if (!items.isNull()) {
1325 if (!items.isArray()) {
1326 return errors::Internal(
1327 "Expected an array 'items' in the GCS response.");
1328 }
1329 for (size_t i = 0; i < items.size(); i++) {
1330 const auto item = items.get(i, Json::Value::null);
1331 if (!item.isObject()) {
1332 return errors::Internal(
1333 "Unexpected JSON format: 'items' should be a list of objects.");
1334 }
1335 string name;
1336 TF_RETURN_IF_ERROR(GetStringValue(item, "name", &name));
1337 // The names should be relative to the 'dirname'. That means the
1338 // 'object_prefix', which is part of 'dirname', should be removed from
1339 // the beginning of 'name'.
        StringPiece relative_path(name);
        if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
          return errors::Internal(strings::StrCat(
              "Unexpected response: the returned file name ", name,
              " doesn't match the prefix ", object_prefix));
        }
        if (!relative_path.empty() || include_self_directory_marker) {
          result->emplace_back(relative_path);
        }
        if (++retrieved_results >= max_results) {
          return Status::OK();
        }
      }
    }
    const auto prefixes = root.get("prefixes", Json::Value::null);
    if (!prefixes.isNull()) {
      // Subfolders are returned for the non-recursive mode.
      if (!prefixes.isArray()) {
        return errors::Internal(
            "'prefixes' was expected to be an array in the GCS response.");
      }
      for (size_t i = 0; i < prefixes.size(); i++) {
        const auto prefix = prefixes.get(i, Json::Value::null);
        if (prefix.isNull() || !prefix.isString()) {
          return errors::Internal(
              "'prefixes' was expected to be an array of strings in the GCS "
              "response.");
        }
        const string& prefix_str = prefix.asString();
        StringPiece relative_path(prefix_str);
        if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
          return errors::Internal(
              "Unexpected response: the returned folder name ", prefix_str,
              " doesn't match the prefix ", object_prefix);
        }
        result->emplace_back(relative_path);
        if (++retrieved_results >= max_results) {
          return Status::OK();
        }
      }
    }
    const auto token = root.get("nextPageToken", Json::Value::null);
    if (token.isNull()) {
      return Status::OK();
    }
    if (!token.isString()) {
      return errors::Internal(
          "Unexpected response: nextPageToken is not a string");
    }
    nextPageToken = token.asString();
  }
}

Status GcsFileSystem::Stat(const string& fname, FileStatistics* stat) {
  if (!stat) {
    return errors::Internal("'stat' cannot be nullptr.");
  }
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
  if (object.empty()) {
    bool is_bucket;
    TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
    if (is_bucket) {
      *stat = DIRECTORY_STAT;
      return Status::OK();
    }
    return errors::NotFound("The specified bucket ", fname, " was not found.");
  }

  GcsFileStat gcs_stat;
  const Status status = StatForObject(fname, bucket, object, &gcs_stat);
  if (status.ok()) {
    *stat = gcs_stat.base;
    return Status::OK();
  }
  if (status.code() != errors::Code::NOT_FOUND) {
    return status;
  }
  bool is_folder;
  TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
  if (is_folder) {
    *stat = DIRECTORY_STAT;
    return Status::OK();
  }
  return errors::NotFound("The specified path ", fname, " was not found.");
}

Status GcsFileSystem::DeleteFile(const string& fname) {
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));

  std::unique_ptr<HttpRequest> request;
  TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
  request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
                                  request->EscapeString(object)));
  request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
  request->SetDeleteRequest();

  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname);
  ClearFileCaches(fname);
  return Status::OK();
}

Status GcsFileSystem::CreateDir(const string& dirname) {
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(dirname, true, &bucket, &object));
  if (object.empty()) {
    bool is_bucket;
    TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
    return is_bucket ? Status::OK()
                     : errors::NotFound("The specified bucket ", dirname,
                                        " was not found.");
  }

  const string dirname_with_slash = MaybeAppendSlash(dirname);

  if (FileExists(dirname_with_slash).ok()) {
    return errors::AlreadyExists(dirname);
  }

  // Create a zero-length directory marker object.
  std::unique_ptr<WritableFile> file;
  TF_RETURN_IF_ERROR(NewWritableFile(dirname_with_slash, &file));
  TF_RETURN_IF_ERROR(file->Close());
  return Status::OK();
}

// Checks that the directory is empty (i.e., no objects with this prefix
// exist). Deletes the GCS directory marker if it exists.
Status GcsFileSystem::DeleteDir(const string& dirname) {
  std::vector<string> children;
  // A directory is considered empty either if there are no matching objects
  // with the corresponding name prefix or if there is exactly one matching
  // object and it is the directory marker. Therefore we need to retrieve
  // at most two children for the prefix to detect if a directory is empty.
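  // For example (illustrative): for gs://bucket/dir the bounded listing below
  // can return [] (nothing under the prefix), [""] (only the directory marker
  // object, which maps to an empty relative name), or additional non-empty
  // names, in which case the directory is not empty.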
  TF_RETURN_IF_ERROR(
      GetChildrenBounded(dirname, 2, &children, true /* recursively */,
                         true /* include_self_directory_marker */));

  if (children.size() > 1 || (children.size() == 1 && !children[0].empty())) {
    return errors::FailedPrecondition("Cannot delete a non-empty directory.");
  }
  if (children.size() == 1 && children[0].empty()) {
    // This is the directory marker object. Delete it.
    return DeleteFile(MaybeAppendSlash(dirname));
  }
  return Status::OK();
}

Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) {
  if (!file_size) {
    return errors::Internal("'file_size' cannot be nullptr.");
  }

  // Only validate the name.
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));

  FileStatistics stat;
  TF_RETURN_IF_ERROR(Stat(fname, &stat));
  *file_size = stat.length;
  return Status::OK();
}

Status GcsFileSystem::RenameFile(const string& src, const string& target) {
  if (!IsDirectory(src).ok()) {
    return RenameObject(src, target);
  }
  // Rename all individual objects in the directory one by one.
  std::vector<string> children;
  TF_RETURN_IF_ERROR(
      GetChildrenBounded(src, UINT64_MAX, &children, true /* recursively */,
                         true /* include_self_directory_marker */));
  for (const string& subpath : children) {
    TF_RETURN_IF_ERROR(
        RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath)));
  }
  return Status::OK();
}

// Uses a GCS API command to copy the object and then deletes the old one.
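// The rewrite request below takes the general form (illustrative):
//   POST .../b/<src_bucket>/o/<src_object>/rewriteTo/b/<target_bucket>/o/<target_object>
// and the response carries a "done" flag indicating whether the copy finished
// in a single call.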
Status GcsFileSystem::RenameObject(const string& src, const string& target) {
  string src_bucket, src_object, target_bucket, target_object;
  TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object));
  TF_RETURN_IF_ERROR(
      ParseGcsPath(target, false, &target_bucket, &target_object));

  std::unique_ptr<HttpRequest> request;
  TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
  request->SetUri(strings::StrCat(kGcsUriBase, "b/", src_bucket, "/o/",
                                  request->EscapeString(src_object),
                                  "/rewriteTo/b/", target_bucket, "/o/",
                                  request->EscapeString(target_object)));
  request->SetPostEmptyBody();
  request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
  std::vector<char> output_buffer;
  request->SetResultBuffer(&output_buffer);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when renaming ", src,
                                  " to ", target);
  // Flush the target from the caches. The source will be flushed in the
  // DeleteFile call below.
  ClearFileCaches(target);
  Json::Value root;
  TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
  bool done;
  TF_RETURN_IF_ERROR(GetBoolValue(root, "done", &done));
  if (!done) {
    // If GCS didn't complete rewrite in one call, this means that a large file
    // is being copied to a bucket with a different storage class or location,
    // which requires multiple rewrite calls.
    // TODO(surkov): implement multi-step rewrites.
    return errors::Unimplemented(
        "Couldn't rename ", src, " to ", target,
        ": moving large files between buckets with different "
        "locations or storage classes is not supported.");
  }

  // In case the delete API call failed, but the deletion actually happened
  // on the server side, we can't just retry the whole RenameFile operation
  // because the source object is already gone.
  return RetryingUtils::DeleteWithRetries(
      [this, &src]() { return DeleteFile(src); }, retry_config_);
}

Status GcsFileSystem::IsDirectory(const string& fname) {
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
  if (object.empty()) {
    bool is_bucket;
    TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
    if (is_bucket) {
      return Status::OK();
    }
    return errors::NotFound("The specified bucket gs://", bucket,
                            " was not found.");
  }
  bool is_folder;
  TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
  if (is_folder) {
    return Status::OK();
  }
  bool is_object;
  TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object));
  if (is_object) {
    return errors::FailedPrecondition("The specified path ", fname,
                                      " is not a directory.");
  }
  return errors::NotFound("The specified path ", fname, " was not found.");
}

Status GcsFileSystem::DeleteRecursively(const string& dirname,
                                        int64* undeleted_files,
                                        int64* undeleted_dirs) {
  if (!undeleted_files || !undeleted_dirs) {
    return errors::Internal(
        "'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
  }
  *undeleted_files = 0;
  *undeleted_dirs = 0;
  if (!IsDirectory(dirname).ok()) {
    *undeleted_dirs = 1;
    return Status(
        error::NOT_FOUND,
        strings::StrCat(dirname, " doesn't exist or is not a directory."));
  }
  std::vector<string> all_objects;
  // Get all children in the directory recursively.
  TF_RETURN_IF_ERROR(GetChildrenBounded(
      dirname, UINT64_MAX, &all_objects, true /* recursively */,
      true /* include_self_directory_marker */));
  for (const string& object : all_objects) {
    const string& full_path = JoinGcsPath(dirname, object);
    // Delete all objects including directory markers for subfolders.
    // Since DeleteRecursively returns OK if individual file deletions fail,
    // and therefore RetryingFileSystem won't pay attention to the failures,
    // we need to make sure these failures are properly retried.
    const auto& delete_file_status = RetryingUtils::DeleteWithRetries(
        [this, &full_path]() { return DeleteFile(full_path); }, retry_config_);
    if (!delete_file_status.ok()) {
      if (IsDirectory(full_path).ok()) {
        // The object is a directory marker.
        (*undeleted_dirs)++;
      } else {
        (*undeleted_files)++;
      }
    }
  }
  return Status::OK();
}

// Flushes all caches for filesystem metadata and file contents. Useful for
// reclaiming memory once filesystem operations are done (e.g. model is loaded),
// or for resetting the filesystem to a consistent state.
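// Note that flushing only drops the cached blocks and cached metadata entries;
// the cache objects themselves (and their configured sizes) remain in place.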
void GcsFileSystem::FlushCaches() {
  tf_shared_lock l(block_cache_lock_);
  file_block_cache_->Flush();
  stat_cache_->Clear();
  matching_paths_cache_->Clear();
  bucket_location_cache_->Clear();
}

void GcsFileSystem::SetStats(GcsStatsInterface* stats) {
  CHECK(stats_ == nullptr) << "SetStats() has already been called.";
  CHECK(stats != nullptr);
  mutex_lock l(block_cache_lock_);
  stats_ = stats;
  stats_->Configure(this, &throttle_, file_block_cache_.get());
}

void GcsFileSystem::SetAuthProvider(
    std::unique_ptr<AuthProvider> auth_provider) {
  mutex_lock l(mu_);
  auth_provider_ = std::move(auth_provider);
}

// Creates an HttpRequest and sets several parameters that are common to all
// requests. All code (in GcsFileSystem) that creates an HttpRequest should
// go through this method, rather than directly using http_request_factory_.
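// Typical usage within this file (see e.g. DeleteFile above):
//   std::unique_ptr<HttpRequest> request;
//   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
//   request->SetUri(...);  // plus timeouts, result buffers, etc.
//   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), ...);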
Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
  std::unique_ptr<HttpRequest> new_request{http_request_factory_->Create()};
  if (dns_cache_) {
    dns_cache_->AnnotateRequest(new_request.get());
  }

  string auth_token;
  {
    tf_shared_lock l(mu_);
    TF_RETURN_IF_ERROR(
        AuthProvider::GetToken(auth_provider_.get(), &auth_token));
  }

  new_request->AddAuthBearerHeader(auth_token);

  if (additional_header_) {
    new_request->AddHeader(additional_header_->first,
                           additional_header_->second);
  }

  if (stats_ != nullptr) {
    new_request->SetRequestStats(stats_->HttpStats());
  }

  if (!throttle_.AdmitRequest()) {
    return errors::Unavailable("Request throttled");
  }

  *request = std::move(new_request);
  return Status::OK();
}

}  // namespace tensorflow

// Initialize gcs_file_system
REGISTER_FILE_SYSTEM("gs", ::tensorflow::RetryingGcsFileSystem);