• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/platform/s3/s3_file_system.h"
16 #include "tensorflow/core/lib/io/path.h"
17 #include "tensorflow/core/lib/strings/str_util.h"
18 #include "tensorflow/core/platform/file_system_helper.h"
19 #include "tensorflow/core/platform/mutex.h"
20 #include "tensorflow/core/platform/s3/aws_crypto.h"
21 #include "tensorflow/core/platform/s3/aws_logging.h"
22 
23 #include <aws/core/Aws.h>
24 #include <aws/core/config/AWSProfileConfigLoader.h>
25 #include <aws/core/utils/FileSystemUtils.h>
26 #include <aws/core/utils/StringUtils.h>
27 #include <aws/core/utils/logging/AWSLogging.h>
28 #include <aws/core/utils/logging/LogSystemInterface.h>
29 #include <aws/s3/S3Client.h>
30 #include <aws/s3/S3Errors.h>
31 #include <aws/s3/model/CopyObjectRequest.h>
32 #include <aws/s3/model/DeleteObjectRequest.h>
33 #include <aws/s3/model/GetObjectRequest.h>
34 #include <aws/s3/model/HeadBucketRequest.h>
35 #include <aws/s3/model/HeadObjectRequest.h>
36 #include <aws/s3/model/ListObjectsRequest.h>
37 #include <aws/s3/model/PutObjectRequest.h>
38 
39 #include <cstdlib>
40 
41 namespace tensorflow {
42 
43 namespace {
// Allocation tag passed to the AWS SDK allocation helpers (Aws::New /
// Aws::MakeShared) for objects created by this file.
constexpr const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
// Size in bytes (1 MiB) of the scratch buffer used to copy an existing object
// into a freshly opened appendable file.
constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
// MaxKeys requested per ListObjects page when enumerating objects.
constexpr int kS3GetChildrenMaxKeys = 100;
47 
GetDefaultClientConfig()48 Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
49   static mutex cfg_lock(LINKER_INITIALIZED);
50   static bool init(false);
51   static Aws::Client::ClientConfiguration cfg;
52 
53   std::lock_guard<mutex> lock(cfg_lock);
54 
55   if (!init) {
56     const char* endpoint = getenv("S3_ENDPOINT");
57     if (endpoint) {
58       cfg.endpointOverride = Aws::String(endpoint);
59     }
60     const char* region = getenv("AWS_REGION");
61     if (!region) {
62       // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
63       region = getenv("S3_REGION");
64     }
65     if (region) {
66       cfg.region = Aws::String(region);
67     } else {
68       // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
69       // is set with a truthy value.
70       const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
71       string load_config =
72           load_config_env ? str_util::Lowercase(load_config_env) : "";
73       if (load_config == "true" || load_config == "1") {
74         Aws::String config_file;
75         // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
76         const char* config_file_env = getenv("AWS_CONFIG_FILE");
77         if (config_file_env) {
78           config_file = config_file_env;
79         } else {
80           const char* home_env = getenv("HOME");
81           if (home_env) {
82             config_file = home_env;
83             config_file += "/.aws/config";
84           }
85         }
86         Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
87         loader.Load();
88         auto profiles = loader.GetProfiles();
89         if (!profiles["default"].GetRegion().empty()) {
90           cfg.region = profiles["default"].GetRegion();
91         }
92       }
93     }
94     const char* use_https = getenv("S3_USE_HTTPS");
95     if (use_https) {
96       if (use_https[0] == '0') {
97         cfg.scheme = Aws::Http::Scheme::HTTP;
98       } else {
99         cfg.scheme = Aws::Http::Scheme::HTTPS;
100       }
101     }
102     const char* verify_ssl = getenv("S3_VERIFY_SSL");
103     if (verify_ssl) {
104       if (verify_ssl[0] == '0') {
105         cfg.verifySSL = false;
106       } else {
107         cfg.verifySSL = true;
108       }
109     }
110     const char* connect_timeout = getenv("S3_CONNECT_TIMEOUT_MSEC");
111     if (connect_timeout) {
112       int64 timeout;
113 
114       if (strings::safe_strto64(connect_timeout, &timeout)) {
115         cfg.connectTimeoutMs = timeout;
116       }
117     }
118     const char* request_timeout = getenv("S3_REQUEST_TIMEOUT_MSEC");
119     if (request_timeout) {
120       int64 timeout;
121 
122       if (strings::safe_strto64(request_timeout, &timeout)) {
123         cfg.requestTimeoutMs = timeout;
124       }
125     }
126 
127     init = true;
128   }
129 
130   return cfg;
131 };
132 
ShutdownClient(Aws::S3::S3Client * s3_client)133 void ShutdownClient(Aws::S3::S3Client* s3_client) {
134   if (s3_client != nullptr) {
135     delete s3_client;
136     Aws::SDKOptions options;
137     Aws::ShutdownAPI(options);
138     AWSLogSystem::ShutdownAWSLogging();
139   }
140 }
141 
ParseS3Path(const string & fname,bool empty_object_ok,string * bucket,string * object)142 Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
143                    string* object) {
144   if (!bucket || !object) {
145     return errors::Internal("bucket and object cannot be null.");
146   }
147   StringPiece scheme, bucketp, objectp;
148   io::ParseURI(fname, &scheme, &bucketp, &objectp);
149   if (scheme != "s3") {
150     return errors::InvalidArgument("S3 path doesn't start with 's3://': ",
151                                    fname);
152   }
153   *bucket = string(bucketp);
154   if (bucket->empty() || *bucket == ".") {
155     return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
156                                    fname);
157   }
158   str_util::ConsumePrefix(&objectp, "/");
159   *object = string(objectp);
160   if (!empty_object_ok && object->empty()) {
161     return errors::InvalidArgument("S3 path doesn't contain an object name: ",
162                                    fname);
163   }
164   return Status::OK();
165 }
166 
167 class S3RandomAccessFile : public RandomAccessFile {
168  public:
S3RandomAccessFile(const string & bucket,const string & object,std::shared_ptr<Aws::S3::S3Client> s3_client)169   S3RandomAccessFile(const string& bucket, const string& object,
170                      std::shared_ptr<Aws::S3::S3Client> s3_client)
171       : bucket_(bucket), object_(object), s3_client_(s3_client) {}
172 
Name(StringPiece * result) const173   Status Name(StringPiece* result) const override {
174     return errors::Unimplemented("S3RandomAccessFile does not support Name()");
175   }
176 
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const177   Status Read(uint64 offset, size_t n, StringPiece* result,
178               char* scratch) const override {
179     Aws::S3::Model::GetObjectRequest getObjectRequest;
180     getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
181     string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
182     getObjectRequest.SetRange(bytes.c_str());
183     getObjectRequest.SetResponseStreamFactory([]() {
184       return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
185     });
186     auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
187     if (!getObjectOutcome.IsSuccess()) {
188       n = 0;
189       *result = StringPiece(scratch, n);
190       return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
191     }
192     n = getObjectOutcome.GetResult().GetContentLength();
193     getObjectOutcome.GetResult().GetBody().read(scratch, n);
194 
195     *result = StringPiece(scratch, n);
196     return Status::OK();
197   }
198 
199  private:
200   string bucket_;
201   string object_;
202   std::shared_ptr<Aws::S3::S3Client> s3_client_;
203 };
204 
205 class S3WritableFile : public WritableFile {
206  public:
S3WritableFile(const string & bucket,const string & object,std::shared_ptr<Aws::S3::S3Client> s3_client)207   S3WritableFile(const string& bucket, const string& object,
208                  std::shared_ptr<Aws::S3::S3Client> s3_client)
209       : bucket_(bucket),
210         object_(object),
211         s3_client_(s3_client),
212         sync_needed_(true),
213         outfile_(Aws::MakeShared<Aws::Utils::TempFile>(
214             kS3FileSystemAllocationTag, "/tmp/s3_filesystem_XXXXXX",
215             std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
216                 std::ios_base::out)) {}
217 
Append(StringPiece data)218   Status Append(StringPiece data) override {
219     if (!outfile_) {
220       return errors::FailedPrecondition(
221           "The internal temporary file is not writable.");
222     }
223     sync_needed_ = true;
224     outfile_->write(data.data(), data.size());
225     if (!outfile_->good()) {
226       return errors::Internal(
227           "Could not append to the internal temporary file.");
228     }
229     return Status::OK();
230   }
231 
Close()232   Status Close() override {
233     if (outfile_) {
234       TF_RETURN_IF_ERROR(Sync());
235       outfile_.reset();
236     }
237     return Status::OK();
238   }
239 
Flush()240   Status Flush() override { return Sync(); }
241 
Name(StringPiece * result) const242   Status Name(StringPiece* result) const override {
243     return errors::Unimplemented("S3WritableFile does not support Name()");
244   }
245 
Sync()246   Status Sync() override {
247     if (!outfile_) {
248       return errors::FailedPrecondition(
249           "The internal temporary file is not writable.");
250     }
251     if (!sync_needed_) {
252       return Status::OK();
253     }
254     Aws::S3::Model::PutObjectRequest putObjectRequest;
255     putObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
256     long offset = outfile_->tellp();
257     outfile_->seekg(0);
258     putObjectRequest.SetBody(outfile_);
259     putObjectRequest.SetContentLength(offset);
260     auto putObjectOutcome = this->s3_client_->PutObject(putObjectRequest);
261     outfile_->clear();
262     outfile_->seekp(offset);
263     if (!putObjectOutcome.IsSuccess()) {
264       return errors::Unknown(putObjectOutcome.GetError().GetExceptionName(),
265                              ": ", putObjectOutcome.GetError().GetMessage());
266     }
267     return Status::OK();
268   }
269 
270  private:
271   string bucket_;
272   string object_;
273   std::shared_ptr<Aws::S3::S3Client> s3_client_;
274   bool sync_needed_;
275   std::shared_ptr<Aws::Utils::TempFile> outfile_;
276 };
277 
278 class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
279  public:
S3ReadOnlyMemoryRegion(std::unique_ptr<char[]> data,uint64 length)280   S3ReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
281       : data_(std::move(data)), length_(length) {}
data()282   const void* data() override { return reinterpret_cast<void*>(data_.get()); }
length()283   uint64 length() override { return length_; }
284 
285  private:
286   std::unique_ptr<char[]> data_;
287   uint64 length_;
288 };
289 
290 }  // namespace
291 
S3FileSystem()292 S3FileSystem::S3FileSystem()
293     : s3_client_(nullptr, ShutdownClient), client_lock_() {}
294 
~S3FileSystem()295 S3FileSystem::~S3FileSystem() {}
296 
297 // Initializes s3_client_, if needed, and returns it.
GetS3Client()298 std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
299   std::lock_guard<mutex> lock(this->client_lock_);
300 
301   if (this->s3_client_.get() == nullptr) {
302     AWSLogSystem::InitializeAWSLogging();
303 
304     Aws::SDKOptions options;
305     options.cryptoOptions.sha256Factory_create_fn = []() {
306       return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
307     };
308     options.cryptoOptions.sha256HMACFactory_create_fn = []() {
309       return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
310     };
311     Aws::InitAPI(options);
312 
313     // The creation of S3Client disables virtual addressing:
314     //   S3Client(clientConfiguration, signPayloads, useVirtualAdressing = true)
315     // The purpose is to address the issue encountered when there is an `.`
316     // in the bucket name. Due to TLS hostname validation or DNS rules,
317     // the bucket may not be resolved. Disabling of virtual addressing
318     // should address the issue. See GitHub issue 16397 for details.
319     this->s3_client_ = std::shared_ptr<Aws::S3::S3Client>(new Aws::S3::S3Client(
320         GetDefaultClientConfig(),
321         Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false));
322   }
323 
324   return this->s3_client_;
325 }
326 
NewRandomAccessFile(const string & fname,std::unique_ptr<RandomAccessFile> * result)327 Status S3FileSystem::NewRandomAccessFile(
328     const string& fname, std::unique_ptr<RandomAccessFile>* result) {
329   string bucket, object;
330   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
331   result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
332   return Status::OK();
333 }
334 
NewWritableFile(const string & fname,std::unique_ptr<WritableFile> * result)335 Status S3FileSystem::NewWritableFile(const string& fname,
336                                      std::unique_ptr<WritableFile>* result) {
337   string bucket, object;
338   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
339   result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
340   return Status::OK();
341 }
342 
NewAppendableFile(const string & fname,std::unique_ptr<WritableFile> * result)343 Status S3FileSystem::NewAppendableFile(const string& fname,
344                                        std::unique_ptr<WritableFile>* result) {
345   std::unique_ptr<RandomAccessFile> reader;
346   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader));
347   std::unique_ptr<char[]> buffer(new char[kS3ReadAppendableFileBufferSize]);
348   Status status;
349   uint64 offset = 0;
350   StringPiece read_chunk;
351 
352   string bucket, object;
353   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
354   result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
355 
356   while (true) {
357     status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
358                           buffer.get());
359     if (status.ok()) {
360       (*result)->Append(read_chunk);
361       offset += kS3ReadAppendableFileBufferSize;
362     } else if (status.code() == error::OUT_OF_RANGE) {
363       (*result)->Append(read_chunk);
364       break;
365     } else {
366       (*result).reset();
367       return status;
368     }
369   }
370 
371   return Status::OK();
372 }
373 
NewReadOnlyMemoryRegionFromFile(const string & fname,std::unique_ptr<ReadOnlyMemoryRegion> * result)374 Status S3FileSystem::NewReadOnlyMemoryRegionFromFile(
375     const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
376   uint64 size;
377   TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
378   std::unique_ptr<char[]> data(new char[size]);
379 
380   std::unique_ptr<RandomAccessFile> file;
381   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file));
382 
383   StringPiece piece;
384   TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
385 
386   result->reset(new S3ReadOnlyMemoryRegion(std::move(data), size));
387   return Status::OK();
388 }
389 
FileExists(const string & fname)390 Status S3FileSystem::FileExists(const string& fname) {
391   FileStatistics stats;
392   TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
393   return Status::OK();
394 }
395 
GetChildren(const string & dir,std::vector<string> * result)396 Status S3FileSystem::GetChildren(const string& dir,
397                                  std::vector<string>* result) {
398   string bucket, prefix;
399   TF_RETURN_IF_ERROR(ParseS3Path(dir, false, &bucket, &prefix));
400 
401   if (prefix.back() != '/') {
402     prefix.push_back('/');
403   }
404 
405   Aws::S3::Model::ListObjectsRequest listObjectsRequest;
406   listObjectsRequest.WithBucket(bucket.c_str())
407       .WithPrefix(prefix.c_str())
408       .WithMaxKeys(kS3GetChildrenMaxKeys)
409       .WithDelimiter("/");
410   listObjectsRequest.SetResponseStreamFactory(
411       []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
412 
413   Aws::S3::Model::ListObjectsResult listObjectsResult;
414   do {
415     auto listObjectsOutcome =
416         this->GetS3Client()->ListObjects(listObjectsRequest);
417     if (!listObjectsOutcome.IsSuccess()) {
418       return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
419                              ": ", listObjectsOutcome.GetError().GetMessage());
420     }
421 
422     listObjectsResult = listObjectsOutcome.GetResult();
423     for (const auto& object : listObjectsResult.GetCommonPrefixes()) {
424       Aws::String s = object.GetPrefix();
425       s.erase(s.length() - 1);
426       Aws::String entry = s.substr(strlen(prefix.c_str()));
427       if (entry.length() > 0) {
428         result->push_back(entry.c_str());
429       }
430     }
431     for (const auto& object : listObjectsResult.GetContents()) {
432       Aws::String s = object.GetKey();
433       Aws::String entry = s.substr(strlen(prefix.c_str()));
434       if (entry.length() > 0) {
435         result->push_back(entry.c_str());
436       }
437     }
438     listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
439   } while (listObjectsResult.GetIsTruncated());
440 
441   return Status::OK();
442 }
443 
Stat(const string & fname,FileStatistics * stats)444 Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
445   string bucket, object;
446   TF_RETURN_IF_ERROR(ParseS3Path(fname, true, &bucket, &object));
447 
448   if (object.empty()) {
449     Aws::S3::Model::HeadBucketRequest headBucketRequest;
450     headBucketRequest.WithBucket(bucket.c_str());
451     auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
452     if (!headBucketOutcome.IsSuccess()) {
453       return errors::Unknown(headBucketOutcome.GetError().GetExceptionName(),
454                              ": ", headBucketOutcome.GetError().GetMessage());
455     }
456     stats->length = 0;
457     stats->is_directory = 1;
458     return Status::OK();
459   }
460 
461   bool found = false;
462 
463   Aws::S3::Model::HeadObjectRequest headObjectRequest;
464   headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
465   headObjectRequest.SetResponseStreamFactory(
466       []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
467   auto headObjectOutcome = this->GetS3Client()->HeadObject(headObjectRequest);
468   if (headObjectOutcome.IsSuccess()) {
469     stats->length = headObjectOutcome.GetResult().GetContentLength();
470     stats->is_directory = 0;
471     stats->mtime_nsec =
472         headObjectOutcome.GetResult().GetLastModified().Millis() * 1e6;
473     found = true;
474   }
475   string prefix = object;
476   if (prefix.back() != '/') {
477     prefix.push_back('/');
478   }
479   Aws::S3::Model::ListObjectsRequest listObjectsRequest;
480   listObjectsRequest.WithBucket(bucket.c_str())
481       .WithPrefix(prefix.c_str())
482       .WithMaxKeys(1);
483   listObjectsRequest.SetResponseStreamFactory(
484       []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
485   auto listObjectsOutcome =
486       this->GetS3Client()->ListObjects(listObjectsRequest);
487   if (listObjectsOutcome.IsSuccess()) {
488     if (listObjectsOutcome.GetResult().GetContents().size() > 0) {
489       stats->length = 0;
490       stats->is_directory = 1;
491       found = true;
492     }
493   }
494   if (!found) {
495     return errors::NotFound("Object ", fname, " does not exist");
496   }
497   return Status::OK();
498 }
499 
GetMatchingPaths(const string & pattern,std::vector<string> * results)500 Status S3FileSystem::GetMatchingPaths(const string& pattern,
501                                       std::vector<string>* results) {
502   return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
503 }
504 
DeleteFile(const string & fname)505 Status S3FileSystem::DeleteFile(const string& fname) {
506   string bucket, object;
507   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
508 
509   Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
510   deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
511 
512   auto deleteObjectOutcome =
513       this->GetS3Client()->DeleteObject(deleteObjectRequest);
514   if (!deleteObjectOutcome.IsSuccess()) {
515     return errors::Unknown(deleteObjectOutcome.GetError().GetExceptionName(),
516                            ": ", deleteObjectOutcome.GetError().GetMessage());
517   }
518   return Status::OK();
519 }
520 
CreateDir(const string & dirname)521 Status S3FileSystem::CreateDir(const string& dirname) {
522   string bucket, object;
523   TF_RETURN_IF_ERROR(ParseS3Path(dirname, true, &bucket, &object));
524 
525   if (object.empty()) {
526     Aws::S3::Model::HeadBucketRequest headBucketRequest;
527     headBucketRequest.WithBucket(bucket.c_str());
528     auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
529     if (!headBucketOutcome.IsSuccess()) {
530       return errors::NotFound("The bucket ", bucket, " was not found.");
531     }
532     return Status::OK();
533   }
534   string filename = dirname;
535   if (filename.back() != '/') {
536     filename.push_back('/');
537   }
538   std::unique_ptr<WritableFile> file;
539   TF_RETURN_IF_ERROR(NewWritableFile(filename, &file));
540   TF_RETURN_IF_ERROR(file->Close());
541   return Status::OK();
542 }
543 
DeleteDir(const string & dirname)544 Status S3FileSystem::DeleteDir(const string& dirname) {
545   string bucket, object;
546   TF_RETURN_IF_ERROR(ParseS3Path(dirname, false, &bucket, &object));
547 
548   string prefix = object;
549   if (prefix.back() != '/') {
550     prefix.push_back('/');
551   }
552   Aws::S3::Model::ListObjectsRequest listObjectsRequest;
553   listObjectsRequest.WithBucket(bucket.c_str())
554       .WithPrefix(prefix.c_str())
555       .WithMaxKeys(2);
556   listObjectsRequest.SetResponseStreamFactory(
557       []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
558   auto listObjectsOutcome =
559       this->GetS3Client()->ListObjects(listObjectsRequest);
560   if (listObjectsOutcome.IsSuccess()) {
561     auto contents = listObjectsOutcome.GetResult().GetContents();
562     if (contents.size() > 1 ||
563         (contents.size() == 1 && contents[0].GetKey() != prefix.c_str())) {
564       return errors::FailedPrecondition("Cannot delete a non-empty directory.");
565     }
566     if (contents.size() == 1 && contents[0].GetKey() == prefix.c_str()) {
567       string filename = dirname;
568       if (filename.back() != '/') {
569         filename.push_back('/');
570       }
571       return DeleteFile(filename);
572     }
573   }
574   return Status::OK();
575 }
576 
GetFileSize(const string & fname,uint64 * file_size)577 Status S3FileSystem::GetFileSize(const string& fname, uint64* file_size) {
578   FileStatistics stats;
579   TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
580   *file_size = stats.length;
581   return Status::OK();
582 }
583 
RenameFile(const string & src,const string & target)584 Status S3FileSystem::RenameFile(const string& src, const string& target) {
585   string src_bucket, src_object, target_bucket, target_object;
586   TF_RETURN_IF_ERROR(ParseS3Path(src, false, &src_bucket, &src_object));
587   TF_RETURN_IF_ERROR(
588       ParseS3Path(target, false, &target_bucket, &target_object));
589   if (src_object.back() == '/') {
590     if (target_object.back() != '/') {
591       target_object.push_back('/');
592     }
593   } else {
594     if (target_object.back() == '/') {
595       target_object.pop_back();
596     }
597   }
598 
599   Aws::S3::Model::CopyObjectRequest copyObjectRequest;
600   Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
601 
602   Aws::S3::Model::ListObjectsRequest listObjectsRequest;
603   listObjectsRequest.WithBucket(src_bucket.c_str())
604       .WithPrefix(src_object.c_str())
605       .WithMaxKeys(kS3GetChildrenMaxKeys);
606   listObjectsRequest.SetResponseStreamFactory(
607       []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
608 
609   Aws::S3::Model::ListObjectsResult listObjectsResult;
610   do {
611     auto listObjectsOutcome =
612         this->GetS3Client()->ListObjects(listObjectsRequest);
613     if (!listObjectsOutcome.IsSuccess()) {
614       return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
615                              ": ", listObjectsOutcome.GetError().GetMessage());
616     }
617 
618     listObjectsResult = listObjectsOutcome.GetResult();
619     for (const auto& object : listObjectsResult.GetContents()) {
620       Aws::String src_key = object.GetKey();
621       Aws::String target_key = src_key;
622       target_key.replace(0, src_object.length(), target_object.c_str());
623       Aws::String source = Aws::String(src_bucket.c_str()) + "/" +
624                            Aws::Utils::StringUtils::URLEncode(src_key.c_str());
625 
626       copyObjectRequest.SetBucket(target_bucket.c_str());
627       copyObjectRequest.SetKey(target_key);
628       copyObjectRequest.SetCopySource(source);
629 
630       auto copyObjectOutcome =
631           this->GetS3Client()->CopyObject(copyObjectRequest);
632       if (!copyObjectOutcome.IsSuccess()) {
633         return errors::Unknown(copyObjectOutcome.GetError().GetExceptionName(),
634                                ": ", copyObjectOutcome.GetError().GetMessage());
635       }
636 
637       deleteObjectRequest.SetBucket(src_bucket.c_str());
638       deleteObjectRequest.SetKey(src_key.c_str());
639 
640       auto deleteObjectOutcome =
641           this->GetS3Client()->DeleteObject(deleteObjectRequest);
642       if (!deleteObjectOutcome.IsSuccess()) {
643         return errors::Unknown(
644             deleteObjectOutcome.GetError().GetExceptionName(), ": ",
645             deleteObjectOutcome.GetError().GetMessage());
646       }
647     }
648     listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
649   } while (listObjectsResult.GetIsTruncated());
650 
651   return Status::OK();
652 }
653 
654 REGISTER_FILE_SYSTEM("s3", S3FileSystem);
655 
656 }  // namespace tensorflow
657