• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/cloud/gcs_file_system.h"
17 #include <stdio.h>
18 #include <unistd.h>
19 #include <algorithm>
20 #include <cstdio>
21 #include <cstdlib>
22 #include <cstring>
23 #include <fstream>
24 #include <vector>
25 #ifdef _WIN32
26 #include <io.h>  // for _mktemp
27 #endif
28 #include "absl/base/macros.h"
29 #include "include/json/json.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/gtl/stl_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/stringprintf.h"
37 #include "tensorflow/core/platform/cloud/curl_http_request.h"
38 #include "tensorflow/core/platform/cloud/file_block_cache.h"
39 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
40 #include "tensorflow/core/platform/cloud/ram_file_block_cache.h"
41 #include "tensorflow/core/platform/cloud/retrying_utils.h"
42 #include "tensorflow/core/platform/cloud/time_util.h"
43 #include "tensorflow/core/platform/env.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/protobuf.h"
46 #include "tensorflow/core/platform/thread_annotations.h"
47 
48 #ifdef _WIN32
49 #ifdef DeleteFile
50 #undef DeleteFile
51 #endif
52 #endif
53 
54 namespace tensorflow {
55 namespace {
56 
57 constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
58 constexpr char kGcsUploadUriBase[] =
59     "https://www.googleapis.com/upload/storage/v1/";
60 constexpr char kStorageHost[] = "storage.googleapis.com";
61 constexpr char kBucketMetadataLocationKey[] = "location";
62 constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024;  // In bytes.
63 constexpr int kGetChildrenDefaultPageSize = 1000;
64 // The HTTP response code "308 Resume Incomplete".
65 constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
66 // The environment variable that overrides the size of the readahead buffer.
67 ABSL_DEPRECATED("Use GCS_READ_CACHE_BLOCK_SIZE_MB instead.")
68 constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
69 // The environment variable that disables the GCS block cache for reads.
70 // This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
71 // takes precedence over either of those environment variables.
72 constexpr char kReadCacheDisabled[] = "GCS_READ_CACHE_DISABLED";
73 // The environment variable that overrides the block size for aligned reads from
74 // GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
75 constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
76 constexpr size_t kDefaultBlockSize = 16 * 1024 * 1024;
77 // The environment variable that overrides the max size of the LRU cache of
78 // blocks read from GCS. Specified in MB.
79 constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
80 constexpr size_t kDefaultMaxCacheSize = kDefaultBlockSize;
81 // The environment variable that overrides the maximum staleness of cached file
82 // contents. Once any block of a file reaches this staleness, all cached blocks
83 // will be evicted on the next read.
84 constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
85 constexpr uint64 kDefaultMaxStaleness = 0;
86 // The environment variable that overrides the maximum age of entries in the
87 // Stat cache. A value of 0 (the default) means nothing is cached.
88 constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
89 constexpr uint64 kStatCacheDefaultMaxAge = 5;
90 // The environment variable that overrides the maximum number of entries in the
91 // Stat cache.
92 constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
93 constexpr size_t kStatCacheDefaultMaxEntries = 1024;
94 // The environment variable that overrides the maximum age of entries in the
95 // GetMatchingPaths cache. A value of 0 (the default) means nothing is cached.
96 constexpr char kMatchingPathsCacheMaxAge[] = "GCS_MATCHING_PATHS_CACHE_MAX_AGE";
97 constexpr uint64 kMatchingPathsCacheDefaultMaxAge = 0;
98 // The environment variable that overrides the maximum number of entries in the
99 // GetMatchingPaths cache.
100 constexpr char kMatchingPathsCacheMaxEntries[] =
101     "GCS_MATCHING_PATHS_CACHE_MAX_ENTRIES";
102 constexpr size_t kMatchingPathsCacheDefaultMaxEntries = 1024;
103 // Number of bucket locations cached, most workloads wont touch more than one
104 // bucket so this limit is set fairly low
105 constexpr size_t kBucketLocationCacheMaxEntries = 10;
106 // ExpiringLRUCache doesnt support any "cache forever" option
107 constexpr size_t kCacheNeverExpire = std::numeric_limits<uint64>::max();
108 // The file statistics returned by Stat() for directories.
109 const FileStatistics DIRECTORY_STAT(0, 0, true);
110 // Some environments exhibit unreliable DNS resolution. Set this environment
111 // variable to a positive integer describing the frequency used to refresh the
112 // userspace DNS cache.
113 constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS";
114 // The environment variable to configure the http request's connection timeout.
115 constexpr char kRequestConnectionTimeout[] =
116     "GCS_REQUEST_CONNECTION_TIMEOUT_SECS";
117 // The environment variable to configure the http request's idle timeout.
118 constexpr char kRequestIdleTimeout[] = "GCS_REQUEST_IDLE_TIMEOUT_SECS";
119 // The environment variable to configure the overall request timeout for
120 // metadata requests.
121 constexpr char kMetadataRequestTimeout[] = "GCS_METADATA_REQUEST_TIMEOUT_SECS";
122 // The environment variable to configure the overall request timeout for
123 // block reads requests.
124 constexpr char kReadRequestTimeout[] = "GCS_READ_REQUEST_TIMEOUT_SECS";
125 // The environment variable to configure the overall request timeout for
126 // upload requests.
127 constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS";
128 // The environment variable to configure an additional header to send with
129 // all requests to GCS (format HEADERNAME:HEADERCONTENT)
130 constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER";
131 // The environment variable to configure the throttle (format: <int64>)
132 constexpr char kThrottleRate[] = "GCS_THROTTLE_TOKEN_RATE";
133 // The environment variable to configure the token bucket size (format: <int64>)
134 constexpr char kThrottleBucket[] = "GCS_THROTTLE_BUCKET_SIZE";
135 // The environment variable that controls the number of tokens per request.
136 // (format: <int64>)
137 constexpr char kTokensPerRequest[] = "GCS_TOKENS_PER_REQUEST";
138 // The environment variable to configure the initial tokens (format: <int64>)
139 constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS";
140 
141 // The environment variable to customize which GCS bucket locations are allowed,
142 // if the list is empty defaults to using the region of the zone (format, comma
143 // delimited list). Requires 'storage.buckets.get' permission.
144 constexpr char kAllowedBucketLocations[] = "GCS_ALLOWED_BUCKET_LOCATIONS";
145 // When this value is passed as an allowed location detects the zone tensorflow
146 // is running in and restricts to buckets in that region.
147 constexpr char kDetectZoneSentinalValue[] = "auto";
148 
149 // TODO: DO NOT use a hardcoded path
GetTmpFilename(string * filename)150 Status GetTmpFilename(string* filename) {
151 #ifndef _WIN32
152   char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
153   int fd = mkstemp(buffer);
154   if (fd < 0) {
155     return errors::Internal("Failed to create a temporary file.");
156   }
157   close(fd);
158 #else
159   char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
160   char* ret = _mktemp(buffer);
161   if (ret == nullptr) {
162     return errors::Internal("Failed to create a temporary file.");
163   }
164 #endif
165   *filename = buffer;
166   return Status::OK();
167 }
168 
169 /// \brief Splits a GCS path to a bucket and an object.
170 ///
171 /// For example, "gs://bucket-name/path/to/file.txt" gets split into
172 /// "bucket-name" and "path/to/file.txt".
173 /// If fname only contains the bucket and empty_object_ok = true, the returned
174 /// object is empty.
ParseGcsPath(StringPiece fname,bool empty_object_ok,string * bucket,string * object)175 Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
176                     string* object) {
177   StringPiece scheme, bucketp, objectp;
178   io::ParseURI(fname, &scheme, &bucketp, &objectp);
179   if (scheme != "gs") {
180     return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
181                                    fname);
182   }
183   *bucket = string(bucketp);
184   if (bucket->empty() || *bucket == ".") {
185     return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
186                                    fname);
187   }
188   str_util::ConsumePrefix(&objectp, "/");
189   *object = string(objectp);
190   if (!empty_object_ok && object->empty()) {
191     return errors::InvalidArgument("GCS path doesn't contain an object name: ",
192                                    fname);
193   }
194   return Status::OK();
195 }
196 
197 /// Appends a trailing slash if the name doesn't already have one.
MaybeAppendSlash(const string & name)198 string MaybeAppendSlash(const string& name) {
199   if (name.empty()) {
200     return "/";
201   }
202   if (name.back() != '/') {
203     return strings::StrCat(name, "/");
204   }
205   return name;
206 }
207 
208 // io::JoinPath() doesn't work in cases when we want an empty subpath
209 // to result in an appended slash in order for directory markers
210 // to be processed correctly: "gs://a/b" + "" should give "gs://a/b/".
JoinGcsPath(const string & path,const string & subpath)211 string JoinGcsPath(const string& path, const string& subpath) {
212   return strings::StrCat(MaybeAppendSlash(path), subpath);
213 }
214 
215 /// \brief Returns the given paths appending all their subfolders.
216 ///
217 /// For every path X in the list, every subfolder in X is added to the
218 /// resulting list.
219 /// For example:
220 ///  - for 'a/b/c/d' it will append 'a', 'a/b' and 'a/b/c'
221 ///  - for 'a/b/c/' it will append 'a', 'a/b' and 'a/b/c'
AddAllSubpaths(const std::vector<string> & paths)222 std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
223   std::set<string> result;
224   result.insert(paths.begin(), paths.end());
225   for (const string& path : paths) {
226     StringPiece subpath = io::Dirname(path);
227     while (!subpath.empty()) {
228       result.emplace(string(subpath));
229       subpath = io::Dirname(subpath);
230     }
231   }
232   return result;
233 }
234 
ParseJson(StringPiece json,Json::Value * result)235 Status ParseJson(StringPiece json, Json::Value* result) {
236   Json::Reader reader;
237   if (!reader.parse(json.data(), json.data() + json.size(), *result)) {
238     return errors::Internal("Couldn't parse JSON response from GCS.");
239   }
240   return Status::OK();
241 }
242 
ParseJson(const std::vector<char> & json,Json::Value * result)243 Status ParseJson(const std::vector<char>& json, Json::Value* result) {
244   return ParseJson(StringPiece{json.data(), json.size()}, result);
245 }
246 
247 /// Reads a JSON value with the given name from a parent JSON value.
GetValue(const Json::Value & parent,const char * name,Json::Value * result)248 Status GetValue(const Json::Value& parent, const char* name,
249                 Json::Value* result) {
250   *result = parent.get(name, Json::Value::null);
251   if (result->isNull()) {
252     return errors::Internal("The field '", name,
253                             "' was expected in the JSON response.");
254   }
255   return Status::OK();
256 }
257 
258 /// Reads a string JSON value with the given name from a parent JSON value.
GetStringValue(const Json::Value & parent,const char * name,string * result)259 Status GetStringValue(const Json::Value& parent, const char* name,
260                       string* result) {
261   Json::Value result_value;
262   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
263   if (!result_value.isString()) {
264     return errors::Internal(
265         "The field '", name,
266         "' in the JSON response was expected to be a string.");
267   }
268   *result = result_value.asString();
269   return Status::OK();
270 }
271 
272 /// Reads a long JSON value with the given name from a parent JSON value.
GetInt64Value(const Json::Value & parent,const char * name,int64 * result)273 Status GetInt64Value(const Json::Value& parent, const char* name,
274                      int64* result) {
275   Json::Value result_value;
276   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
277   if (result_value.isNumeric()) {
278     *result = result_value.asInt64();
279     return Status::OK();
280   }
281   if (result_value.isString() &&
282       strings::safe_strto64(result_value.asCString(), result)) {
283     return Status::OK();
284   }
285   return errors::Internal(
286       "The field '", name,
287       "' in the JSON response was expected to be a number.");
288 }
289 
290 /// Reads a boolean JSON value with the given name from a parent JSON value.
GetBoolValue(const Json::Value & parent,const char * name,bool * result)291 Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) {
292   Json::Value result_value;
293   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
294   if (!result_value.isBool()) {
295     return errors::Internal(
296         "The field '", name,
297         "' in the JSON response was expected to be a boolean.");
298   }
299   *result = result_value.asBool();
300   return Status::OK();
301 }
302 
303 /// A GCS-based implementation of a random access file with an LRU block cache.
304 class GcsRandomAccessFile : public RandomAccessFile {
305  public:
306   using ReadFn =
307       std::function<Status(const string& filename, uint64 offset, size_t n,
308                            StringPiece* result, char* scratch)>;
309 
GcsRandomAccessFile(const string & filename,ReadFn read_fn)310   GcsRandomAccessFile(const string& filename, ReadFn read_fn)
311       : filename_(filename), read_fn_(std::move(read_fn)) {}
312 
Name(StringPiece * result) const313   Status Name(StringPiece* result) const override {
314     *result = filename_;
315     return Status::OK();
316   }
317 
318   /// The implementation of reads with an LRU block cache. Thread safe.
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const319   Status Read(uint64 offset, size_t n, StringPiece* result,
320               char* scratch) const override {
321     return read_fn_(filename_, offset, n, result, scratch);
322   }
323 
324  private:
325   /// The filename of this file.
326   const string filename_;
327   /// The implementation of the read operation (provided by the GCSFileSystem).
328   const ReadFn read_fn_;
329 };
330 
331 /// \brief GCS-based implementation of a writeable file.
332 ///
333 /// Since GCS objects are immutable, this implementation writes to a local
334 /// tmp file and copies it to GCS on flush/close.
335 class GcsWritableFile : public WritableFile {
336  public:
GcsWritableFile(const string & bucket,const string & object,GcsFileSystem * filesystem,GcsFileSystem::TimeoutConfig * timeouts,std::function<void ()> file_cache_erase,RetryConfig retry_config)337   GcsWritableFile(const string& bucket, const string& object,
338                   GcsFileSystem* filesystem,
339                   GcsFileSystem::TimeoutConfig* timeouts,
340                   std::function<void()> file_cache_erase,
341                   RetryConfig retry_config)
342       : bucket_(bucket),
343         object_(object),
344         filesystem_(filesystem),
345         timeouts_(timeouts),
346         file_cache_erase_(std::move(file_cache_erase)),
347         sync_needed_(true),
348         retry_config_(retry_config) {
349     // TODO: to make it safer, outfile_ should be constructed from an FD
350     if (GetTmpFilename(&tmp_content_filename_).ok()) {
351       outfile_.open(tmp_content_filename_,
352                     std::ofstream::binary | std::ofstream::app);
353     }
354   }
355 
356   /// \brief Constructs the writable file in append mode.
357   ///
358   /// tmp_content_filename should contain a path of an existing temporary file
359   /// with the content to be appended. The class takes onwnership of the
360   /// specified tmp file and deletes it on close.
GcsWritableFile(const string & bucket,const string & object,GcsFileSystem * filesystem,const string & tmp_content_filename,GcsFileSystem::TimeoutConfig * timeouts,std::function<void ()> file_cache_erase,RetryConfig retry_config)361   GcsWritableFile(const string& bucket, const string& object,
362                   GcsFileSystem* filesystem, const string& tmp_content_filename,
363                   GcsFileSystem::TimeoutConfig* timeouts,
364                   std::function<void()> file_cache_erase,
365                   RetryConfig retry_config)
366       : bucket_(bucket),
367         object_(object),
368         filesystem_(filesystem),
369         timeouts_(timeouts),
370         file_cache_erase_(std::move(file_cache_erase)),
371         sync_needed_(true),
372         retry_config_(retry_config) {
373     tmp_content_filename_ = tmp_content_filename;
374     outfile_.open(tmp_content_filename_,
375                   std::ofstream::binary | std::ofstream::app);
376   }
377 
~GcsWritableFile()378   ~GcsWritableFile() override { Close().IgnoreError(); }
379 
Append(StringPiece data)380   Status Append(StringPiece data) override {
381     TF_RETURN_IF_ERROR(CheckWritable());
382     sync_needed_ = true;
383     outfile_ << data;
384     if (!outfile_.good()) {
385       return errors::Internal(
386           "Could not append to the internal temporary file.");
387     }
388     return Status::OK();
389   }
390 
Close()391   Status Close() override {
392     if (outfile_.is_open()) {
393       TF_RETURN_IF_ERROR(Sync());
394       outfile_.close();
395       std::remove(tmp_content_filename_.c_str());
396     }
397     return Status::OK();
398   }
399 
Flush()400   Status Flush() override { return Sync(); }
401 
Name(StringPiece * result) const402   Status Name(StringPiece* result) const override {
403     return errors::Unimplemented("GCSWritableFile does not support Name()");
404   }
405 
Sync()406   Status Sync() override {
407     TF_RETURN_IF_ERROR(CheckWritable());
408     if (!sync_needed_) {
409       return Status::OK();
410     }
411     Status status = SyncImpl();
412     if (status.ok()) {
413       sync_needed_ = false;
414     }
415     return status;
416   }
417 
Tell(int64 * position)418   Status Tell(int64* position) override {
419     *position = outfile_.tellp();
420     if (*position == -1) {
421       return errors::Internal("tellp on the internal temporary file failed");
422     }
423     return Status::OK();
424   }
425 
426  private:
427   /// Copies the current version of the file to GCS.
428   ///
429   /// This SyncImpl() uploads the object to GCS.
430   /// In case of a failure, it resumes failed uploads as recommended by the GCS
431   /// resumable API documentation. When the whole upload needs to be
432   /// restarted, Sync() returns UNAVAILABLE and relies on RetryingFileSystem.
SyncImpl()433   Status SyncImpl() {
434     outfile_.flush();
435     if (!outfile_.good()) {
436       return errors::Internal(
437           "Could not write to the internal temporary file.");
438     }
439     string session_uri;
440     TF_RETURN_IF_ERROR(CreateNewUploadSession(&session_uri));
441     uint64 already_uploaded = 0;
442     bool first_attempt = true;
443     const Status upload_status = RetryingUtils::CallWithRetries(
444         [&first_attempt, &already_uploaded, &session_uri, this]() {
445           if (!first_attempt) {
446             bool completed;
447             TF_RETURN_IF_ERROR(RequestUploadSessionStatus(
448                 session_uri, &completed, &already_uploaded));
449             if (completed) {
450               // Erase the file from the file cache on every successful write.
451               file_cache_erase_();
452               // It's unclear why UploadToSession didn't return OK in the
453               // previous attempt, but GCS reports that the file is fully
454               // uploaded, so succeed.
455               return Status::OK();
456             }
457           }
458           first_attempt = false;
459           return UploadToSession(session_uri, already_uploaded);
460         },
461         retry_config_);
462     if (upload_status.code() == errors::Code::NOT_FOUND) {
463       // GCS docs recommend retrying the whole upload. We're relying on the
464       // RetryingFileSystem to retry the Sync() call.
465       return errors::Unavailable(
466           strings::StrCat("Upload to gs://", bucket_, "/", object_,
467                           " failed, caused by: ", upload_status.ToString()));
468     }
469     return upload_status;
470   }
471 
CheckWritable() const472   Status CheckWritable() const {
473     if (!outfile_.is_open()) {
474       return errors::FailedPrecondition(
475           "The internal temporary file is not writable.");
476     }
477     return Status::OK();
478   }
479 
GetCurrentFileSize(uint64 * size)480   Status GetCurrentFileSize(uint64* size) {
481     const auto tellp = outfile_.tellp();
482     if (tellp == static_cast<std::streampos>(-1)) {
483       return errors::Internal(
484           "Could not get the size of the internal temporary file.");
485     }
486     *size = tellp;
487     return Status::OK();
488   }
489 
490   /// Initiates a new resumable upload session.
CreateNewUploadSession(string * session_uri)491   Status CreateNewUploadSession(string* session_uri) {
492     uint64 file_size;
493     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
494 
495     std::vector<char> output_buffer;
496     std::unique_ptr<HttpRequest> request;
497     TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
498 
499     request->SetUri(strings::StrCat(
500         kGcsUploadUriBase, "b/", bucket_,
501         "/o?uploadType=resumable&name=", request->EscapeString(object_)));
502     request->AddHeader("X-Upload-Content-Length", std::to_string(file_size));
503     request->SetPostEmptyBody();
504     request->SetResultBuffer(&output_buffer);
505     request->SetTimeouts(timeouts_->connect, timeouts_->idle,
506                          timeouts_->metadata);
507     TF_RETURN_WITH_CONTEXT_IF_ERROR(
508         request->Send(), " when initiating an upload to ", GetGcsPath());
509     *session_uri = request->GetResponseHeader("Location");
510     if (session_uri->empty()) {
511       return errors::Internal("Unexpected response from GCS when writing to ",
512                               GetGcsPath(),
513                               ": 'Location' header not returned.");
514     }
515     return Status::OK();
516   }
517 
518   /// \brief Requests status of a previously initiated upload session.
519   ///
520   /// If the upload has already succeeded, sets 'completed' to true.
521   /// Otherwise sets 'completed' to false and 'uploaded' to the currently
522   /// uploaded size in bytes.
RequestUploadSessionStatus(const string & session_uri,bool * completed,uint64 * uploaded)523   Status RequestUploadSessionStatus(const string& session_uri, bool* completed,
524                                     uint64* uploaded) {
525     uint64 file_size;
526     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
527 
528     std::unique_ptr<HttpRequest> request;
529     TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
530     request->SetUri(session_uri);
531     request->SetTimeouts(timeouts_->connect, timeouts_->idle,
532                          timeouts_->metadata);
533     request->AddHeader("Content-Range", strings::StrCat("bytes */", file_size));
534     request->SetPutEmptyBody();
535     const Status& status = request->Send();
536     if (status.ok()) {
537       *completed = true;
538       return Status::OK();
539     }
540     *completed = false;
541     if (request->GetResponseCode() != HTTP_CODE_RESUME_INCOMPLETE) {
542       TF_RETURN_WITH_CONTEXT_IF_ERROR(status, " when resuming upload ",
543                                       GetGcsPath());
544     }
545     const string& received_range = request->GetResponseHeader("Range");
546     if (received_range.empty()) {
547       // This means GCS doesn't have any bytes of the file yet.
548       *uploaded = 0;
549     } else {
550       StringPiece range_piece(received_range);
551       str_util::ConsumePrefix(&range_piece,
552                               "bytes=");  // May or may not be present.
553       std::vector<int64> range_parts;
554       if (!str_util::SplitAndParseAsInts(range_piece, '-', &range_parts) ||
555           range_parts.size() != 2) {
556         return errors::Internal("Unexpected response from GCS when writing ",
557                                 GetGcsPath(), ": Range header '",
558                                 received_range, "' could not be parsed.");
559       }
560       if (range_parts[0] != 0) {
561         return errors::Internal("Unexpected response from GCS when writing to ",
562                                 GetGcsPath(), ": the returned range '",
563                                 received_range, "' does not start at zero.");
564       }
565       // If GCS returned "Range: 0-10", this means 11 bytes were uploaded.
566       *uploaded = range_parts[1] + 1;
567     }
568     return Status::OK();
569   }
570 
UploadToSession(const string & session_uri,uint64 start_offset)571   Status UploadToSession(const string& session_uri, uint64 start_offset) {
572     uint64 file_size;
573     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
574 
575     std::unique_ptr<HttpRequest> request;
576     TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request));
577     request->SetUri(session_uri);
578     if (file_size > 0) {
579       request->AddHeader("Content-Range",
580                          strings::StrCat("bytes ", start_offset, "-",
581                                          file_size - 1, "/", file_size));
582     }
583     request->SetTimeouts(timeouts_->connect, timeouts_->idle, timeouts_->write);
584 
585     TF_RETURN_IF_ERROR(
586         request->SetPutFromFile(tmp_content_filename_, start_offset));
587     TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ",
588                                     GetGcsPath());
589     // Erase the file from the file cache on every successful write.
590     file_cache_erase_();
591     return Status::OK();
592   }
593 
GetGcsPath() const594   string GetGcsPath() const {
595     return strings::StrCat("gs://", bucket_, "/", object_);
596   }
597 
598   string bucket_;
599   string object_;
600   GcsFileSystem* const filesystem_;  // Not owned.
601   string tmp_content_filename_;
602   std::ofstream outfile_;
603   GcsFileSystem::TimeoutConfig* timeouts_;
604   std::function<void()> file_cache_erase_;
605   bool sync_needed_;  // whether there is buffered data that needs to be synced
606   RetryConfig retry_config_;
607 };
608 
609 class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
610  public:
GcsReadOnlyMemoryRegion(std::unique_ptr<char[]> data,uint64 length)611   GcsReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
612       : data_(std::move(data)), length_(length) {}
data()613   const void* data() override { return reinterpret_cast<void*>(data_.get()); }
length()614   uint64 length() override { return length_; }
615 
616  private:
617   std::unique_ptr<char[]> data_;
618   uint64 length_;
619 };
620 
621 // Helper function to extract an environment variable and convert it into a
622 // value of type T.
623 template <typename T>
GetEnvVar(const char * varname,bool (* convert)(StringPiece,T *),T * value)624 bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*),
625                T* value) {
626   const char* env_value = std::getenv(varname);
627   if (!env_value) {
628     return false;
629   }
630   return convert(env_value, value);
631 }
632 
StringPieceIdentity(StringPiece str,StringPiece * value)633 bool StringPieceIdentity(StringPiece str, StringPiece* value) {
634   *value = str;
635   return true;
636 }
637 
638 /// \brief Utility function to split a comma delimited list of strings to an
639 /// unordered set, lowercasing all values.
SplitByCommaToLowercaseSet(StringPiece list,std::unordered_set<string> * set)640 bool SplitByCommaToLowercaseSet(StringPiece list,
641                                 std::unordered_set<string>* set) {
642   std::vector<string> vector =
643       str_util::Split(tensorflow::str_util::Lowercase(list), ",");
644   *set = std::unordered_set<string>(vector.begin(), vector.end());
645   return true;
646 }
647 
648 // \brief Convert Compute Engine zone to region
ZoneToRegion(string * zone)649 string ZoneToRegion(string* zone) {
650   return zone->substr(0, zone->find_last_of('-'));
651 }
652 
653 }  // namespace
654 
GcsFileSystem()655 GcsFileSystem::GcsFileSystem() {
656   uint64 value;
657   size_t block_size = kDefaultBlockSize;
658   size_t max_bytes = kDefaultMaxCacheSize;
659   uint64 max_staleness = kDefaultMaxStaleness;
660 
661   http_request_factory_ = std::make_shared<CurlHttpRequest::Factory>();
662   compute_engine_metadata_client_ =
663       std::make_shared<ComputeEngineMetadataClient>(http_request_factory_);
664   auth_provider_ = std::unique_ptr<AuthProvider>(
665       new GoogleAuthProvider(compute_engine_metadata_client_));
666   zone_provider_ = std::unique_ptr<ZoneProvider>(
667       new ComputeEngineZoneProvider(compute_engine_metadata_client_));
668 
669   // Apply the sys env override for the readahead buffer size if it's provided.
670   if (GetEnvVar(kReadaheadBufferSize, strings::safe_strtou64, &value)) {
671     block_size = value;
672   }
673   // Apply the overrides for the block size (MB), max bytes (MB), and max
674   // staleness (seconds) if provided.
675   if (GetEnvVar(kBlockSize, strings::safe_strtou64, &value)) {
676     block_size = value * 1024 * 1024;
677   }
678   if (GetEnvVar(kMaxCacheSize, strings::safe_strtou64, &value)) {
679     max_bytes = value * 1024 * 1024;
680   }
681   if (GetEnvVar(kMaxStaleness, strings::safe_strtou64, &value)) {
682     max_staleness = value;
683   }
684   if (std::getenv(kReadCacheDisabled)) {
685     // Setting either to 0 disables the cache; set both for good measure.
686     block_size = max_bytes = 0;
687   }
688   VLOG(1) << "GCS cache max size = " << max_bytes << " ; "
689           << "block size = " << block_size << " ; "
690           << "max staleness = " << max_staleness;
691   file_block_cache_ = MakeFileBlockCache(block_size, max_bytes, max_staleness);
692   // Apply overrides for the stat cache max age and max entries, if provided.
693   uint64 stat_cache_max_age = kStatCacheDefaultMaxAge;
694   size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
695   if (GetEnvVar(kStatCacheMaxAge, strings::safe_strtou64, &value)) {
696     stat_cache_max_age = value;
697   }
698   if (GetEnvVar(kStatCacheMaxEntries, strings::safe_strtou64, &value)) {
699     stat_cache_max_entries = value;
700   }
701   stat_cache_.reset(new ExpiringLRUCache<GcsFileStat>(stat_cache_max_age,
702                                                       stat_cache_max_entries));
703   // Apply overrides for the matching paths cache max age and max entries, if
704   // provided.
705   uint64 matching_paths_cache_max_age = kMatchingPathsCacheDefaultMaxAge;
706   size_t matching_paths_cache_max_entries =
707       kMatchingPathsCacheDefaultMaxEntries;
708   if (GetEnvVar(kMatchingPathsCacheMaxAge, strings::safe_strtou64, &value)) {
709     matching_paths_cache_max_age = value;
710   }
711   if (GetEnvVar(kMatchingPathsCacheMaxEntries, strings::safe_strtou64,
712                 &value)) {
713     matching_paths_cache_max_entries = value;
714   }
715   matching_paths_cache_.reset(new ExpiringLRUCache<std::vector<string>>(
716       matching_paths_cache_max_age, matching_paths_cache_max_entries));
717 
718   bucket_location_cache_.reset(new ExpiringLRUCache<string>(
719       kCacheNeverExpire, kBucketLocationCacheMaxEntries));
720 
721   int64 resolve_frequency_secs;
722   if (GetEnvVar(kResolveCacheSecs, strings::safe_strto64,
723                 &resolve_frequency_secs)) {
724     dns_cache_.reset(new GcsDnsCache(resolve_frequency_secs));
725     VLOG(1) << "GCS DNS cache is enabled.  " << kResolveCacheSecs << " = "
726             << resolve_frequency_secs;
727   } else {
728     VLOG(1) << "GCS DNS cache is disabled, because " << kResolveCacheSecs
729             << " = 0 (or is not set)";
730   }
731 
732   // Get the additional header
733   StringPiece add_header_contents;
734   if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity,
735                 &add_header_contents)) {
736     size_t split = add_header_contents.find(':', 0);
737 
738     if (split != StringPiece::npos) {
739       StringPiece header_name = add_header_contents.substr(0, split);
740       StringPiece header_value = add_header_contents.substr(split + 1);
741 
742       if (!header_name.empty() && !header_value.empty()) {
743         additional_header_.reset(new std::pair<const string, const string>(
744             string(header_name), string(header_value)));
745 
746         VLOG(1) << "GCS additional header ENABLED. "
747                 << "Name: " << additional_header_->first << ", "
748                 << "Value: " << additional_header_->second;
749       } else {
750         LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
751                    << add_header_contents;
752       }
753     } else {
754       LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
755                  << add_header_contents;
756     }
757   } else {
758     VLOG(1) << "GCS additional header DISABLED. No environment variable set.";
759   }
760 
761   // Apply the overrides for request timeouts
762   uint32 timeout_value;
763   if (GetEnvVar(kRequestConnectionTimeout, strings::safe_strtou32,
764                 &timeout_value)) {
765     timeouts_.connect = timeout_value;
766   }
767   if (GetEnvVar(kRequestIdleTimeout, strings::safe_strtou32, &timeout_value)) {
768     timeouts_.idle = timeout_value;
769   }
770   if (GetEnvVar(kMetadataRequestTimeout, strings::safe_strtou32,
771                 &timeout_value)) {
772     timeouts_.metadata = timeout_value;
773   }
774   if (GetEnvVar(kReadRequestTimeout, strings::safe_strtou32, &timeout_value)) {
775     timeouts_.read = timeout_value;
776   }
777   if (GetEnvVar(kWriteRequestTimeout, strings::safe_strtou32, &timeout_value)) {
778     timeouts_.write = timeout_value;
779   }
780 
781   int64 token_value;
782   if (GetEnvVar(kThrottleRate, strings::safe_strto64, &token_value)) {
783     GcsThrottleConfig config;
784     config.enabled = true;
785     config.token_rate = token_value;
786 
787     if (GetEnvVar(kThrottleBucket, strings::safe_strto64, &token_value)) {
788       config.bucket_size = token_value;
789     }
790 
791     if (GetEnvVar(kTokensPerRequest, strings::safe_strto64, &token_value)) {
792       config.tokens_per_request = token_value;
793     }
794 
795     if (GetEnvVar(kInitialTokens, strings::safe_strto64, &token_value)) {
796       config.initial_tokens = token_value;
797     }
798     throttle_.SetConfig(config);
799   }
800 
801   GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet,
802             &allowed_locations_);
803 }
804 
GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider,std::unique_ptr<HttpRequest::Factory> http_request_factory,std::unique_ptr<ZoneProvider> zone_provider,size_t block_size,size_t max_bytes,uint64 max_staleness,uint64 stat_cache_max_age,size_t stat_cache_max_entries,uint64 matching_paths_cache_max_age,size_t matching_paths_cache_max_entries,RetryConfig retry_config,TimeoutConfig timeouts,const std::unordered_set<string> & allowed_locations,std::pair<const string,const string> * additional_header)805 GcsFileSystem::GcsFileSystem(
806     std::unique_ptr<AuthProvider> auth_provider,
807     std::unique_ptr<HttpRequest::Factory> http_request_factory,
808     std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
809     size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
810     size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
811     size_t matching_paths_cache_max_entries, RetryConfig retry_config,
812     TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
813     std::pair<const string, const string>* additional_header)
814     : auth_provider_(std::move(auth_provider)),
815       http_request_factory_(std::move(http_request_factory)),
816       zone_provider_(std::move(zone_provider)),
817       file_block_cache_(
818           MakeFileBlockCache(block_size, max_bytes, max_staleness)),
819       stat_cache_(new StatCache(stat_cache_max_age, stat_cache_max_entries)),
820       matching_paths_cache_(new MatchingPathsCache(
821           matching_paths_cache_max_age, matching_paths_cache_max_entries)),
822       bucket_location_cache_(new BucketLocationCache(
823           kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
824       allowed_locations_(allowed_locations),
825       timeouts_(timeouts),
826       retry_config_(retry_config),
827       additional_header_(additional_header) {}
828 
NewRandomAccessFile(const string & fname,std::unique_ptr<RandomAccessFile> * result)829 Status GcsFileSystem::NewRandomAccessFile(
830     const string& fname, std::unique_ptr<RandomAccessFile>* result) {
831   string bucket, object;
832   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
833   TF_RETURN_IF_ERROR(CheckBucketLocationConstraint(bucket));
834   result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
835                                                    const string& fname,
836                                                    uint64 offset, size_t n,
837                                                    StringPiece* result,
838                                                    char* scratch) {
839     tf_shared_lock l(block_cache_lock_);
840     if (file_block_cache_->IsCacheEnabled()) {
841       GcsFileStat stat;
842       TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
843           fname, &stat,
844           [this, bucket, object](const string& fname, GcsFileStat* stat) {
845             return UncachedStatForObject(fname, bucket, object, stat);
846           }));
847       if (!file_block_cache_->ValidateAndUpdateFileSignature(
848               fname, stat.generation_number)) {
849         VLOG(1)
850             << "File signature has been changed. Refreshing the cache. Path: "
851             << fname;
852       }
853     }
854     *result = StringPiece();
855     size_t bytes_transferred;
856     TF_RETURN_IF_ERROR(
857         file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred));
858     *result = StringPiece(scratch, bytes_transferred);
859     if (bytes_transferred < n) {
860       return errors::OutOfRange("EOF reached, ", result->size(),
861                                 " bytes were read out of ", n,
862                                 " bytes requested.");
863     }
864     return Status::OK();
865   }));
866   return Status::OK();
867 }
868 
ResetFileBlockCache(size_t block_size_bytes,size_t max_bytes,uint64 max_staleness_secs)869 void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
870                                         size_t max_bytes,
871                                         uint64 max_staleness_secs) {
872   mutex_lock l(block_cache_lock_);
873   file_block_cache_ =
874       MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs);
875   if (stats_ != nullptr) {
876     stats_->Configure(this, &throttle_, file_block_cache_.get());
877   }
878 }
879 
880 // A helper function to build a FileBlockCache for GcsFileSystem.
MakeFileBlockCache(size_t block_size,size_t max_bytes,uint64 max_staleness)881 std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
882     size_t block_size, size_t max_bytes, uint64 max_staleness) {
883   std::unique_ptr<FileBlockCache> file_block_cache(new RamFileBlockCache(
884       block_size, max_bytes, max_staleness,
885       [this](const string& filename, size_t offset, size_t n, char* buffer,
886              size_t* bytes_transferred) {
887         return LoadBufferFromGCS(filename, offset, n, buffer,
888                                  bytes_transferred);
889       }));
890   return file_block_cache;
891 }
892 
893 // A helper function to actually read the data from GCS.
LoadBufferFromGCS(const string & filename,size_t offset,size_t n,char * buffer,size_t * bytes_transferred)894 Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
895                                         size_t n, char* buffer,
896                                         size_t* bytes_transferred) {
897   *bytes_transferred = 0;
898 
899   string bucket, object;
900   TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
901 
902   std::unique_ptr<HttpRequest> request;
903   TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
904                                   "when reading gs://", bucket, "/", object);
905 
906   request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/",
907                                   request->EscapeString(object)));
908   request->SetRange(offset, offset + n - 1);
909   request->SetResultBufferDirect(buffer, n);
910   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read);
911 
912   if (stats_ != nullptr) {
913     stats_->RecordBlockLoadRequest(filename, offset);
914   }
915 
916   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
917                                   bucket, "/", object);
918 
919   size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
920   *bytes_transferred = bytes_read;
921   VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
922           << offset << " of size: " << bytes_read;
923 
924   if (stats_ != nullptr) {
925     stats_->RecordBlockRetrieved(filename, offset, bytes_read);
926   }
927 
928   throttle_.RecordResponse(bytes_read);
929 
930   if (bytes_read < n) {
931     // Check stat cache to see if we encountered an interrupted read.
932     GcsFileStat stat;
933     if (stat_cache_->Lookup(filename, &stat)) {
934       if (offset + bytes_read < stat.base.length) {
935         return errors::Internal(strings::Printf(
936             "File contents are inconsistent for file: %s @ %lu.",
937             filename.c_str(), offset));
938       }
939       VLOG(2) << "Successful integrity check for: gs://" << bucket << "/"
940               << object << " @ " << offset;
941     }
942   }
943 
944   return Status::OK();
945 }
946 
ClearFileCaches(const string & fname)947 void GcsFileSystem::ClearFileCaches(const string& fname) {
948   tf_shared_lock l(block_cache_lock_);
949   file_block_cache_->RemoveFile(fname);
950   stat_cache_->Delete(fname);
951   // TODO(rxsang): Remove the patterns that matche the file in
952   // MatchingPathsCache as well.
953 }
954 
NewWritableFile(const string & fname,std::unique_ptr<WritableFile> * result)955 Status GcsFileSystem::NewWritableFile(const string& fname,
956                                       std::unique_ptr<WritableFile>* result) {
957   string bucket, object;
958   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
959   result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
960                                     [this, fname]() { ClearFileCaches(fname); },
961                                     retry_config_));
962   return Status::OK();
963 }
964 
965 // Reads the file from GCS in chunks and stores it in a tmp file,
966 // which is then passed to GcsWritableFile.
NewAppendableFile(const string & fname,std::unique_ptr<WritableFile> * result)967 Status GcsFileSystem::NewAppendableFile(const string& fname,
968                                         std::unique_ptr<WritableFile>* result) {
969   std::unique_ptr<RandomAccessFile> reader;
970   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader));
971   std::unique_ptr<char[]> buffer(new char[kReadAppendableFileBufferSize]);
972   Status status;
973   uint64 offset = 0;
974   StringPiece read_chunk;
975 
976   // Read the file from GCS in chunks and save it to a tmp file.
977   string old_content_filename;
978   TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename));
979   std::ofstream old_content(old_content_filename, std::ofstream::binary);
980   while (true) {
981     status = reader->Read(offset, kReadAppendableFileBufferSize, &read_chunk,
982                           buffer.get());
983     if (status.ok()) {
984       old_content << read_chunk;
985       offset += kReadAppendableFileBufferSize;
986     } else if (status.code() == error::OUT_OF_RANGE) {
987       // Expected, this means we reached EOF.
988       old_content << read_chunk;
989       break;
990     } else {
991       return status;
992     }
993   }
994   old_content.close();
995 
996   // Create a writable file and pass the old content to it.
997   string bucket, object;
998   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
999   result->reset(new GcsWritableFile(
1000       bucket, object, this, old_content_filename, &timeouts_,
1001       [this, fname]() { ClearFileCaches(fname); }, retry_config_));
1002   return Status::OK();
1003 }
1004 
NewReadOnlyMemoryRegionFromFile(const string & fname,std::unique_ptr<ReadOnlyMemoryRegion> * result)1005 Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
1006     const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
1007   uint64 size;
1008   TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
1009   std::unique_ptr<char[]> data(new char[size]);
1010 
1011   std::unique_ptr<RandomAccessFile> file;
1012   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file));
1013 
1014   StringPiece piece;
1015   TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
1016 
1017   result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size));
1018   return Status::OK();
1019 }
1020 
FileExists(const string & fname)1021 Status GcsFileSystem::FileExists(const string& fname) {
1022   string bucket, object;
1023   TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
1024   if (object.empty()) {
1025     bool result;
1026     TF_RETURN_IF_ERROR(BucketExists(bucket, &result));
1027     if (result) {
1028       return Status::OK();
1029     }
1030   }
1031 
1032   // Check if the object exists.
1033   GcsFileStat stat;
1034   const Status status = StatForObject(fname, bucket, object, &stat);
1035   if (status.code() != errors::Code::NOT_FOUND) {
1036     return status;
1037   }
1038 
1039   // Check if the folder exists.
1040   bool result;
1041   TF_RETURN_IF_ERROR(FolderExists(fname, &result));
1042   if (result) {
1043     return Status::OK();
1044   }
1045   return errors::NotFound("The specified path ", fname, " was not found.");
1046 }
1047 
ObjectExists(const string & fname,const string & bucket,const string & object,bool * result)1048 Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket,
1049                                    const string& object, bool* result) {
1050   GcsFileStat stat;
1051   const Status status = StatForObject(fname, bucket, object, &stat);
1052   switch (status.code()) {
1053     case errors::Code::OK:
1054       *result = !stat.base.is_directory;
1055       return Status::OK();
1056     case errors::Code::NOT_FOUND:
1057       *result = false;
1058       return Status::OK();
1059     default:
1060       return status;
1061   }
1062 }
1063 
UncachedStatForObject(const string & fname,const string & bucket,const string & object,GcsFileStat * stat)1064 Status GcsFileSystem::UncachedStatForObject(const string& fname,
1065                                             const string& bucket,
1066                                             const string& object,
1067                                             GcsFileStat* stat) {
1068   std::vector<char> output_buffer;
1069   std::unique_ptr<HttpRequest> request;
1070   TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
1071                                   " when reading metadata of gs://", bucket,
1072                                   "/", object);
1073 
1074   request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
1075                                   request->EscapeString(object),
1076                                   "?fields=size%2Cgeneration%2Cupdated"));
1077   request->SetResultBuffer(&output_buffer);
1078   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1079 
1080   if (stats_ != nullptr) {
1081     stats_->RecordStatObjectRequest();
1082   }
1083 
1084   TF_RETURN_WITH_CONTEXT_IF_ERROR(
1085       request->Send(), " when reading metadata of gs://", bucket, "/", object);
1086 
1087   Json::Value root;
1088   TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
1089 
1090   // Parse file size.
1091   TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->base.length));
1092 
1093   // Parse generation number.
1094   TF_RETURN_IF_ERROR(
1095       GetInt64Value(root, "generation", &stat->generation_number));
1096 
1097   // Parse file modification time.
1098   string updated;
1099   TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated));
1100   TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->base.mtime_nsec)));
1101 
1102   VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- "
1103           << " length: " << stat->base.length
1104           << " generation: " << stat->generation_number
1105           << "; mtime_nsec: " << stat->base.mtime_nsec
1106           << "; updated: " << updated;
1107 
1108   if (str_util::EndsWith(fname, "/")) {
1109     // In GCS a path can be both a directory and a file, both it is uncommon for
1110     // other file systems. To avoid the ambiguity, if a path ends with "/" in
1111     // GCS, we always regard it as a directory mark or a virtual directory.
1112     stat->base.is_directory = true;
1113   } else {
1114     stat->base.is_directory = false;
1115   }
1116   return Status::OK();
1117 }
1118 
StatForObject(const string & fname,const string & bucket,const string & object,GcsFileStat * stat)1119 Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
1120                                     const string& object, GcsFileStat* stat) {
1121   if (object.empty()) {
1122     return errors::InvalidArgument(strings::Printf(
1123         "'object' must be a non-empty string. (File: %s)", fname.c_str()));
1124   }
1125 
1126   TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
1127       fname, stat,
1128       [this, &bucket, &object](const string& fname, GcsFileStat* stat) {
1129         return UncachedStatForObject(fname, bucket, object, stat);
1130       }));
1131   return Status::OK();
1132 }
1133 
BucketExists(const string & bucket,bool * result)1134 Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
1135   const Status status = GetBucketMetadata(bucket, nullptr);
1136   switch (status.code()) {
1137     case errors::Code::OK:
1138       *result = true;
1139       return Status::OK();
1140     case errors::Code::NOT_FOUND:
1141       *result = false;
1142       return Status::OK();
1143     default:
1144       return status;
1145   }
1146 }
1147 
CheckBucketLocationConstraint(const string & bucket)1148 Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) {
1149   if (allowed_locations_.empty()) {
1150     return Status::OK();
1151   }
1152 
1153   // Avoid calling external API's in the constructor
1154   if (allowed_locations_.erase(kDetectZoneSentinalValue) == 1) {
1155     string zone;
1156     TF_RETURN_IF_ERROR(zone_provider_->GetZone(&zone));
1157     allowed_locations_.insert(ZoneToRegion(&zone));
1158   }
1159 
1160   string location;
1161   TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location));
1162   if (allowed_locations_.find(location) != allowed_locations_.end()) {
1163     return Status::OK();
1164   }
1165 
1166   return errors::FailedPrecondition(strings::Printf(
1167       "Bucket '%s' is in '%s' location, allowed locations are: (%s).",
1168       bucket.c_str(), location.c_str(),
1169       str_util::Join(allowed_locations_, ", ").c_str()));
1170 }
1171 
GetBucketLocation(const string & bucket,string * location)1172 Status GcsFileSystem::GetBucketLocation(const string& bucket,
1173                                         string* location) {
1174   auto compute_func = [this](const string& bucket, string* location) {
1175     std::vector<char> result_buffer;
1176     Status status = GetBucketMetadata(bucket, &result_buffer);
1177     Json::Value result;
1178     TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result));
1179     string bucket_location;
1180     TF_RETURN_IF_ERROR(
1181         GetStringValue(result, kBucketMetadataLocationKey, &bucket_location));
1182     // Lowercase the GCS location to be case insensitive for allowed locations.
1183     *location = tensorflow::str_util::Lowercase(bucket_location);
1184     return Status::OK();
1185   };
1186 
1187   TF_RETURN_IF_ERROR(
1188       bucket_location_cache_->LookupOrCompute(bucket, location, compute_func));
1189 
1190   return Status::OK();
1191 }
1192 
GetBucketMetadata(const string & bucket,std::vector<char> * result_buffer)1193 Status GcsFileSystem::GetBucketMetadata(const string& bucket,
1194                                         std::vector<char>* result_buffer) {
1195   std::unique_ptr<HttpRequest> request;
1196   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1197   request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
1198 
1199   if (result_buffer != nullptr) {
1200     request->SetResultBuffer(result_buffer);
1201   }
1202 
1203   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1204   return request->Send();
1205 }
1206 
FolderExists(const string & dirname,bool * result)1207 Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
1208   StatCache::ComputeFunc compute_func = [this](const string& dirname,
1209                                                GcsFileStat* stat) {
1210     std::vector<string> children;
1211     TF_RETURN_IF_ERROR(
1212         GetChildrenBounded(dirname, 1, &children, true /* recursively */,
1213                            true /* include_self_directory_marker */));
1214     if (!children.empty()) {
1215       stat->base = DIRECTORY_STAT;
1216       return Status::OK();
1217     } else {
1218       return errors::InvalidArgument("Not a directory!");
1219     }
1220   };
1221   GcsFileStat stat;
1222   Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat,
1223                                           compute_func);
1224   if (s.ok()) {
1225     *result = stat.base.is_directory;
1226     return Status::OK();
1227   }
1228   if (errors::IsInvalidArgument(s)) {
1229     *result = false;
1230     return Status::OK();
1231   }
1232   return s;
1233 }
1234 
GetChildren(const string & dirname,std::vector<string> * result)1235 Status GcsFileSystem::GetChildren(const string& dirname,
1236                                   std::vector<string>* result) {
1237   return GetChildrenBounded(dirname, UINT64_MAX, result,
1238                             false /* recursively */,
1239                             false /* include_self_directory_marker */);
1240 }
1241 
GetMatchingPaths(const string & pattern,std::vector<string> * results)1242 Status GcsFileSystem::GetMatchingPaths(const string& pattern,
1243                                        std::vector<string>* results) {
1244   MatchingPathsCache::ComputeFunc compute_func =
1245       [this](const string& pattern, std::vector<string>* results) {
1246         results->clear();
1247         // Find the fixed prefix by looking for the first wildcard.
1248         const string& fixed_prefix =
1249             pattern.substr(0, pattern.find_first_of("*?[\\"));
1250         const string dir(io::Dirname(fixed_prefix));
1251         if (dir.empty()) {
1252           return errors::InvalidArgument(
1253               "A GCS pattern doesn't have a bucket name: ", pattern);
1254         }
1255         std::vector<string> all_files;
1256         TF_RETURN_IF_ERROR(GetChildrenBounded(
1257             dir, UINT64_MAX, &all_files, true /* recursively */,
1258             false /* include_self_directory_marker */));
1259 
1260         const auto& files_and_folders = AddAllSubpaths(all_files);
1261 
1262         // Match all obtained paths to the input pattern.
1263         for (const auto& path : files_and_folders) {
1264           const string& full_path = io::JoinPath(dir, path);
1265           if (Env::Default()->MatchPath(full_path, pattern)) {
1266             results->push_back(full_path);
1267           }
1268         }
1269         return Status::OK();
1270       };
1271   TF_RETURN_IF_ERROR(
1272       matching_paths_cache_->LookupOrCompute(pattern, results, compute_func));
1273   return Status::OK();
1274 }
1275 
GetChildrenBounded(const string & dirname,uint64 max_results,std::vector<string> * result,bool recursive,bool include_self_directory_marker)1276 Status GcsFileSystem::GetChildrenBounded(const string& dirname,
1277                                          uint64 max_results,
1278                                          std::vector<string>* result,
1279                                          bool recursive,
1280                                          bool include_self_directory_marker) {
1281   if (!result) {
1282     return errors::InvalidArgument("'result' cannot be null");
1283   }
1284   string bucket, object_prefix;
1285   TF_RETURN_IF_ERROR(
1286       ParseGcsPath(MaybeAppendSlash(dirname), true, &bucket, &object_prefix));
1287 
1288   string nextPageToken;
1289   uint64 retrieved_results = 0;
1290   while (true) {  // A loop over multiple result pages.
1291     std::vector<char> output_buffer;
1292     std::unique_ptr<HttpRequest> request;
1293     TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1294     auto uri = strings::StrCat(kGcsUriBase, "b/", bucket, "/o");
1295     if (recursive) {
1296       uri = strings::StrCat(uri, "?fields=items%2Fname%2CnextPageToken");
1297     } else {
1298       // Set "/" as a delimiter to ask GCS to treat subfolders as children
1299       // and return them in "prefixes".
1300       uri = strings::StrCat(uri,
1301                             "?fields=items%2Fname%2Cprefixes%2CnextPageToken");
1302       uri = strings::StrCat(uri, "&delimiter=%2F");
1303     }
1304     if (!object_prefix.empty()) {
1305       uri = strings::StrCat(uri,
1306                             "&prefix=", request->EscapeString(object_prefix));
1307     }
1308     if (!nextPageToken.empty()) {
1309       uri = strings::StrCat(
1310           uri, "&pageToken=", request->EscapeString(nextPageToken));
1311     }
1312     if (max_results - retrieved_results < kGetChildrenDefaultPageSize) {
1313       uri =
1314           strings::StrCat(uri, "&maxResults=", max_results - retrieved_results);
1315     }
1316     request->SetUri(uri);
1317     request->SetResultBuffer(&output_buffer);
1318     request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1319 
1320     TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading ", dirname);
1321     Json::Value root;
1322     TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
1323     const auto items = root.get("items", Json::Value::null);
1324     if (!items.isNull()) {
1325       if (!items.isArray()) {
1326         return errors::Internal(
1327             "Expected an array 'items' in the GCS response.");
1328       }
1329       for (size_t i = 0; i < items.size(); i++) {
1330         const auto item = items.get(i, Json::Value::null);
1331         if (!item.isObject()) {
1332           return errors::Internal(
1333               "Unexpected JSON format: 'items' should be a list of objects.");
1334         }
1335         string name;
1336         TF_RETURN_IF_ERROR(GetStringValue(item, "name", &name));
1337         // The names should be relative to the 'dirname'. That means the
1338         // 'object_prefix', which is part of 'dirname', should be removed from
1339         // the beginning of 'name'.
1340         StringPiece relative_path(name);
1341         if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
1342           return errors::Internal(strings::StrCat(
1343               "Unexpected response: the returned file name ", name,
1344               " doesn't match the prefix ", object_prefix));
1345         }
1346         if (!relative_path.empty() || include_self_directory_marker) {
1347           result->emplace_back(relative_path);
1348         }
1349         if (++retrieved_results >= max_results) {
1350           return Status::OK();
1351         }
1352       }
1353     }
1354     const auto prefixes = root.get("prefixes", Json::Value::null);
1355     if (!prefixes.isNull()) {
1356       // Subfolders are returned for the non-recursive mode.
1357       if (!prefixes.isArray()) {
1358         return errors::Internal(
1359             "'prefixes' was expected to be an array in the GCS response.");
1360       }
1361       for (size_t i = 0; i < prefixes.size(); i++) {
1362         const auto prefix = prefixes.get(i, Json::Value::null);
1363         if (prefix.isNull() || !prefix.isString()) {
1364           return errors::Internal(
1365               "'prefixes' was expected to be an array of strings in the GCS "
1366               "response.");
1367         }
1368         const string& prefix_str = prefix.asString();
1369         StringPiece relative_path(prefix_str);
1370         if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
1371           return errors::Internal(
1372               "Unexpected response: the returned folder name ", prefix_str,
1373               " doesn't match the prefix ", object_prefix);
1374         }
1375         result->emplace_back(relative_path);
1376         if (++retrieved_results >= max_results) {
1377           return Status::OK();
1378         }
1379       }
1380     }
1381     const auto token = root.get("nextPageToken", Json::Value::null);
1382     if (token.isNull()) {
1383       return Status::OK();
1384     }
1385     if (!token.isString()) {
1386       return errors::Internal(
1387           "Unexpected response: nextPageToken is not a string");
1388     }
1389     nextPageToken = token.asString();
1390   }
1391 }
1392 
Stat(const string & fname,FileStatistics * stat)1393 Status GcsFileSystem::Stat(const string& fname, FileStatistics* stat) {
1394   if (!stat) {
1395     return errors::Internal("'stat' cannot be nullptr.");
1396   }
1397   string bucket, object;
1398   TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
1399   if (object.empty()) {
1400     bool is_bucket;
1401     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
1402     if (is_bucket) {
1403       *stat = DIRECTORY_STAT;
1404       return Status::OK();
1405     }
1406     return errors::NotFound("The specified bucket ", fname, " was not found.");
1407   }
1408 
1409   GcsFileStat gcs_stat;
1410   const Status status = StatForObject(fname, bucket, object, &gcs_stat);
1411   if (status.ok()) {
1412     *stat = gcs_stat.base;
1413     return Status::OK();
1414   }
1415   if (status.code() != errors::Code::NOT_FOUND) {
1416     return status;
1417   }
1418   bool is_folder;
1419   TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
1420   if (is_folder) {
1421     *stat = DIRECTORY_STAT;
1422     return Status::OK();
1423   }
1424   return errors::NotFound("The specified path ", fname, " was not found.");
1425 }
1426 
DeleteFile(const string & fname)1427 Status GcsFileSystem::DeleteFile(const string& fname) {
1428   string bucket, object;
1429   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
1430 
1431   std::unique_ptr<HttpRequest> request;
1432   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1433   request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
1434                                   request->EscapeString(object)));
1435   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1436   request->SetDeleteRequest();
1437 
1438   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname);
1439   ClearFileCaches(fname);
1440   return Status::OK();
1441 }
1442 
CreateDir(const string & dirname)1443 Status GcsFileSystem::CreateDir(const string& dirname) {
1444   string bucket, object;
1445   TF_RETURN_IF_ERROR(ParseGcsPath(dirname, true, &bucket, &object));
1446   if (object.empty()) {
1447     bool is_bucket;
1448     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
1449     return is_bucket ? Status::OK()
1450                      : errors::NotFound("The specified bucket ", dirname,
1451                                         " was not found.");
1452   }
1453 
1454   const string dirname_with_slash = MaybeAppendSlash(dirname);
1455 
1456   if (FileExists(dirname_with_slash).ok()) {
1457     return errors::AlreadyExists(dirname);
1458   }
1459 
1460   // Create a zero-length directory marker object.
1461   std::unique_ptr<WritableFile> file;
1462   TF_RETURN_IF_ERROR(NewWritableFile(dirname_with_slash, &file));
1463   TF_RETURN_IF_ERROR(file->Close());
1464   return Status::OK();
1465 }
1466 
1467 // Checks that the directory is empty (i.e no objects with this prefix exist).
1468 // Deletes the GCS directory marker if it exists.
DeleteDir(const string & dirname)1469 Status GcsFileSystem::DeleteDir(const string& dirname) {
1470   std::vector<string> children;
1471   // A directory is considered empty either if there are no matching objects
1472   // with the corresponding name prefix or if there is exactly one matching
1473   // object and it is the directory marker. Therefore we need to retrieve
1474   // at most two children for the prefix to detect if a directory is empty.
1475   TF_RETURN_IF_ERROR(
1476       GetChildrenBounded(dirname, 2, &children, true /* recursively */,
1477                          true /* include_self_directory_marker */));
1478 
1479   if (children.size() > 1 || (children.size() == 1 && !children[0].empty())) {
1480     return errors::FailedPrecondition("Cannot delete a non-empty directory.");
1481   }
1482   if (children.size() == 1 && children[0].empty()) {
1483     // This is the directory marker object. Delete it.
1484     return DeleteFile(MaybeAppendSlash(dirname));
1485   }
1486   return Status::OK();
1487 }
1488 
GetFileSize(const string & fname,uint64 * file_size)1489 Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) {
1490   if (!file_size) {
1491     return errors::Internal("'file_size' cannot be nullptr.");
1492   }
1493 
1494   // Only validate the name.
1495   string bucket, object;
1496   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
1497 
1498   FileStatistics stat;
1499   TF_RETURN_IF_ERROR(Stat(fname, &stat));
1500   *file_size = stat.length;
1501   return Status::OK();
1502 }
1503 
RenameFile(const string & src,const string & target)1504 Status GcsFileSystem::RenameFile(const string& src, const string& target) {
1505   if (!IsDirectory(src).ok()) {
1506     return RenameObject(src, target);
1507   }
1508   // Rename all individual objects in the directory one by one.
1509   std::vector<string> children;
1510   TF_RETURN_IF_ERROR(
1511       GetChildrenBounded(src, UINT64_MAX, &children, true /* recursively */,
1512                          true /* include_self_directory_marker */));
1513   for (const string& subpath : children) {
1514     TF_RETURN_IF_ERROR(
1515         RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath)));
1516   }
1517   return Status::OK();
1518 }
1519 
1520 // Uses a GCS API command to copy the object and then deletes the old one.
RenameObject(const string & src,const string & target)1521 Status GcsFileSystem::RenameObject(const string& src, const string& target) {
1522   string src_bucket, src_object, target_bucket, target_object;
1523   TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object));
1524   TF_RETURN_IF_ERROR(
1525       ParseGcsPath(target, false, &target_bucket, &target_object));
1526 
1527   std::unique_ptr<HttpRequest> request;
1528   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
1529   request->SetUri(strings::StrCat(kGcsUriBase, "b/", src_bucket, "/o/",
1530                                   request->EscapeString(src_object),
1531                                   "/rewriteTo/b/", target_bucket, "/o/",
1532                                   request->EscapeString(target_object)));
1533   request->SetPostEmptyBody();
1534   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
1535   std::vector<char> output_buffer;
1536   request->SetResultBuffer(&output_buffer);
1537   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when renaming ", src,
1538                                   " to ", target);
1539   // Flush the target from the caches.  The source will be flushed in the
1540   // DeleteFile call below.
1541   ClearFileCaches(target);
1542   Json::Value root;
1543   TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root));
1544   bool done;
1545   TF_RETURN_IF_ERROR(GetBoolValue(root, "done", &done));
1546   if (!done) {
1547     // If GCS didn't complete rewrite in one call, this means that a large file
1548     // is being copied to a bucket with a different storage class or location,
1549     // which requires multiple rewrite calls.
1550     // TODO(surkov): implement multi-step rewrites.
1551     return errors::Unimplemented(
1552         "Couldn't rename ", src, " to ", target,
1553         ": moving large files between buckets with different "
1554         "locations or storage classes is not supported.");
1555   }
1556 
1557   // In case the delete API call failed, but the deletion actually happened
1558   // on the server side, we can't just retry the whole RenameFile operation
1559   // because the source object is already gone.
1560   return RetryingUtils::DeleteWithRetries(
1561       [this, &src]() { return DeleteFile(src); }, retry_config_);
1562 }
1563 
IsDirectory(const string & fname)1564 Status GcsFileSystem::IsDirectory(const string& fname) {
1565   string bucket, object;
1566   TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
1567   if (object.empty()) {
1568     bool is_bucket;
1569     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
1570     if (is_bucket) {
1571       return Status::OK();
1572     }
1573     return errors::NotFound("The specified bucket gs://", bucket,
1574                             " was not found.");
1575   }
1576   bool is_folder;
1577   TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
1578   if (is_folder) {
1579     return Status::OK();
1580   }
1581   bool is_object;
1582   TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object));
1583   if (is_object) {
1584     return errors::FailedPrecondition("The specified path ", fname,
1585                                       " is not a directory.");
1586   }
1587   return errors::NotFound("The specified path ", fname, " was not found.");
1588 }
1589 
DeleteRecursively(const string & dirname,int64 * undeleted_files,int64 * undeleted_dirs)1590 Status GcsFileSystem::DeleteRecursively(const string& dirname,
1591                                         int64* undeleted_files,
1592                                         int64* undeleted_dirs) {
1593   if (!undeleted_files || !undeleted_dirs) {
1594     return errors::Internal(
1595         "'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
1596   }
1597   *undeleted_files = 0;
1598   *undeleted_dirs = 0;
1599   if (!IsDirectory(dirname).ok()) {
1600     *undeleted_dirs = 1;
1601     return Status(
1602         error::NOT_FOUND,
1603         strings::StrCat(dirname, " doesn't exist or not a directory."));
1604   }
1605   std::vector<string> all_objects;
1606   // Get all children in the directory recursively.
1607   TF_RETURN_IF_ERROR(GetChildrenBounded(
1608       dirname, UINT64_MAX, &all_objects, true /* recursively */,
1609       true /* include_self_directory_marker */));
1610   for (const string& object : all_objects) {
1611     const string& full_path = JoinGcsPath(dirname, object);
1612     // Delete all objects including directory markers for subfolders.
1613     // Since DeleteRecursively returns OK if individual file deletions fail,
1614     // and therefore RetryingFileSystem won't pay attention to the failures,
1615     // we need to make sure these failures are properly retried.
1616     const auto& delete_file_status = RetryingUtils::DeleteWithRetries(
1617         [this, &full_path]() { return DeleteFile(full_path); }, retry_config_);
1618     if (!delete_file_status.ok()) {
1619       if (IsDirectory(full_path).ok()) {
1620         // The object is a directory marker.
1621         (*undeleted_dirs)++;
1622       } else {
1623         (*undeleted_files)++;
1624       }
1625     }
1626   }
1627   return Status::OK();
1628 }
1629 
1630 // Flushes all caches for filesystem metadata and file contents. Useful for
1631 // reclaiming memory once filesystem operations are done (e.g. model is loaded),
1632 // or for resetting the filesystem to a consistent state.
FlushCaches()1633 void GcsFileSystem::FlushCaches() {
1634   tf_shared_lock l(block_cache_lock_);
1635   file_block_cache_->Flush();
1636   stat_cache_->Clear();
1637   matching_paths_cache_->Clear();
1638   bucket_location_cache_->Clear();
1639 }
1640 
SetStats(GcsStatsInterface * stats)1641 void GcsFileSystem::SetStats(GcsStatsInterface* stats) {
1642   CHECK(stats_ == nullptr) << "SetStats() has already been called.";
1643   CHECK(stats != nullptr);
1644   mutex_lock l(block_cache_lock_);
1645   stats_ = stats;
1646   stats_->Configure(this, &throttle_, file_block_cache_.get());
1647 }
1648 
SetAuthProvider(std::unique_ptr<AuthProvider> auth_provider)1649 void GcsFileSystem::SetAuthProvider(
1650     std::unique_ptr<AuthProvider> auth_provider) {
1651   mutex_lock l(mu_);
1652   auth_provider_ = std::move(auth_provider);
1653 }
1654 
1655 // Creates an HttpRequest and sets several parameters that are common to all
1656 // requests.  All code (in GcsFileSystem) that creates an HttpRequest should
1657 // go through this method, rather than directly using http_request_factory_.
CreateHttpRequest(std::unique_ptr<HttpRequest> * request)1658 Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
1659   std::unique_ptr<HttpRequest> new_request{http_request_factory_->Create()};
1660   if (dns_cache_) {
1661     dns_cache_->AnnotateRequest(new_request.get());
1662   }
1663 
1664   string auth_token;
1665   {
1666     tf_shared_lock l(mu_);
1667     TF_RETURN_IF_ERROR(
1668         AuthProvider::GetToken(auth_provider_.get(), &auth_token));
1669   }
1670 
1671   new_request->AddAuthBearerHeader(auth_token);
1672 
1673   if (additional_header_) {
1674     new_request->AddHeader(additional_header_->first,
1675                            additional_header_->second);
1676   }
1677 
1678   if (stats_ != nullptr) {
1679     new_request->SetRequestStats(stats_->HttpStats());
1680   }
1681 
1682   if (!throttle_.AdmitRequest()) {
1683     return errors::Unavailable("Request throttled");
1684   }
1685 
1686   *request = std::move(new_request);
1687   return Status::OK();
1688 }
1689 
1690 }  // namespace tensorflow
1691 
// Register RetryingGcsFileSystem as the filesystem implementation for paths
// using the "gs" scheme (gs://bucket/object).
REGISTER_FILE_SYSTEM("gs", ::tensorflow::RetryingGcsFileSystem);
1694