1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/platform/s3/s3_file_system.h"
16 #include "tensorflow/core/lib/io/path.h"
17 #include "tensorflow/core/lib/strings/str_util.h"
18 #include "tensorflow/core/platform/file_system_helper.h"
19 #include "tensorflow/core/platform/mutex.h"
20 #include "tensorflow/core/platform/s3/aws_crypto.h"
21 #include "tensorflow/core/platform/s3/aws_logging.h"
22
23 #include <aws/core/Aws.h>
24 #include <aws/core/config/AWSProfileConfigLoader.h>
25 #include <aws/core/utils/FileSystemUtils.h>
26 #include <aws/core/utils/StringUtils.h>
27 #include <aws/core/utils/logging/AWSLogging.h>
28 #include <aws/core/utils/logging/LogSystemInterface.h>
29 #include <aws/s3/S3Client.h>
30 #include <aws/s3/S3Errors.h>
31 #include <aws/s3/model/CopyObjectRequest.h>
32 #include <aws/s3/model/DeleteObjectRequest.h>
33 #include <aws/s3/model/GetObjectRequest.h>
34 #include <aws/s3/model/HeadBucketRequest.h>
35 #include <aws/s3/model/HeadObjectRequest.h>
36 #include <aws/s3/model/ListObjectsRequest.h>
37 #include <aws/s3/model/PutObjectRequest.h>
38
39 #include <cstdlib>
40
41 namespace tensorflow {
42
43 namespace {
// Allocation tag passed to the AWS SDK allocation helpers (Aws::New /
// Aws::MakeShared) for objects created by this file.
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
// Size in bytes (1024 * 1024) of the chunk buffer NewAppendableFile uses when
// copying the existing object into the new writable file.
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
// Maximum number of keys requested per ListObjects call in GetChildren and
// RenameFile.
static const int kS3GetChildrenMaxKeys = 100;
47
GetDefaultClientConfig()48 Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
49 static mutex cfg_lock(LINKER_INITIALIZED);
50 static bool init(false);
51 static Aws::Client::ClientConfiguration cfg;
52
53 std::lock_guard<mutex> lock(cfg_lock);
54
55 if (!init) {
56 const char* endpoint = getenv("S3_ENDPOINT");
57 if (endpoint) {
58 cfg.endpointOverride = Aws::String(endpoint);
59 }
60 const char* region = getenv("AWS_REGION");
61 if (!region) {
62 // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
63 region = getenv("S3_REGION");
64 }
65 if (region) {
66 cfg.region = Aws::String(region);
67 } else {
68 // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
69 // is set with a truthy value.
70 const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
71 string load_config =
72 load_config_env ? str_util::Lowercase(load_config_env) : "";
73 if (load_config == "true" || load_config == "1") {
74 Aws::String config_file;
75 // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
76 const char* config_file_env = getenv("AWS_CONFIG_FILE");
77 if (config_file_env) {
78 config_file = config_file_env;
79 } else {
80 const char* home_env = getenv("HOME");
81 if (home_env) {
82 config_file = home_env;
83 config_file += "/.aws/config";
84 }
85 }
86 Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
87 loader.Load();
88 auto profiles = loader.GetProfiles();
89 if (!profiles["default"].GetRegion().empty()) {
90 cfg.region = profiles["default"].GetRegion();
91 }
92 }
93 }
94 const char* use_https = getenv("S3_USE_HTTPS");
95 if (use_https) {
96 if (use_https[0] == '0') {
97 cfg.scheme = Aws::Http::Scheme::HTTP;
98 } else {
99 cfg.scheme = Aws::Http::Scheme::HTTPS;
100 }
101 }
102 const char* verify_ssl = getenv("S3_VERIFY_SSL");
103 if (verify_ssl) {
104 if (verify_ssl[0] == '0') {
105 cfg.verifySSL = false;
106 } else {
107 cfg.verifySSL = true;
108 }
109 }
110 const char* connect_timeout = getenv("S3_CONNECT_TIMEOUT_MSEC");
111 if (connect_timeout) {
112 int64 timeout;
113
114 if (strings::safe_strto64(connect_timeout, &timeout)) {
115 cfg.connectTimeoutMs = timeout;
116 }
117 }
118 const char* request_timeout = getenv("S3_REQUEST_TIMEOUT_MSEC");
119 if (request_timeout) {
120 int64 timeout;
121
122 if (strings::safe_strto64(request_timeout, &timeout)) {
123 cfg.requestTimeoutMs = timeout;
124 }
125 }
126
127 init = true;
128 }
129
130 return cfg;
131 };
132
ShutdownClient(Aws::S3::S3Client * s3_client)133 void ShutdownClient(Aws::S3::S3Client* s3_client) {
134 if (s3_client != nullptr) {
135 delete s3_client;
136 Aws::SDKOptions options;
137 Aws::ShutdownAPI(options);
138 AWSLogSystem::ShutdownAWSLogging();
139 }
140 }
141
ParseS3Path(const string & fname,bool empty_object_ok,string * bucket,string * object)142 Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
143 string* object) {
144 if (!bucket || !object) {
145 return errors::Internal("bucket and object cannot be null.");
146 }
147 StringPiece scheme, bucketp, objectp;
148 io::ParseURI(fname, &scheme, &bucketp, &objectp);
149 if (scheme != "s3") {
150 return errors::InvalidArgument("S3 path doesn't start with 's3://': ",
151 fname);
152 }
153 *bucket = string(bucketp);
154 if (bucket->empty() || *bucket == ".") {
155 return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
156 fname);
157 }
158 str_util::ConsumePrefix(&objectp, "/");
159 *object = string(objectp);
160 if (!empty_object_ok && object->empty()) {
161 return errors::InvalidArgument("S3 path doesn't contain an object name: ",
162 fname);
163 }
164 return Status::OK();
165 }
166
167 class S3RandomAccessFile : public RandomAccessFile {
168 public:
S3RandomAccessFile(const string & bucket,const string & object,std::shared_ptr<Aws::S3::S3Client> s3_client)169 S3RandomAccessFile(const string& bucket, const string& object,
170 std::shared_ptr<Aws::S3::S3Client> s3_client)
171 : bucket_(bucket), object_(object), s3_client_(s3_client) {}
172
Name(StringPiece * result) const173 Status Name(StringPiece* result) const override {
174 return errors::Unimplemented("S3RandomAccessFile does not support Name()");
175 }
176
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const177 Status Read(uint64 offset, size_t n, StringPiece* result,
178 char* scratch) const override {
179 Aws::S3::Model::GetObjectRequest getObjectRequest;
180 getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
181 string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
182 getObjectRequest.SetRange(bytes.c_str());
183 getObjectRequest.SetResponseStreamFactory([]() {
184 return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
185 });
186 auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
187 if (!getObjectOutcome.IsSuccess()) {
188 n = 0;
189 *result = StringPiece(scratch, n);
190 return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
191 }
192 n = getObjectOutcome.GetResult().GetContentLength();
193 getObjectOutcome.GetResult().GetBody().read(scratch, n);
194
195 *result = StringPiece(scratch, n);
196 return Status::OK();
197 }
198
199 private:
200 string bucket_;
201 string object_;
202 std::shared_ptr<Aws::S3::S3Client> s3_client_;
203 };
204
205 class S3WritableFile : public WritableFile {
206 public:
S3WritableFile(const string & bucket,const string & object,std::shared_ptr<Aws::S3::S3Client> s3_client)207 S3WritableFile(const string& bucket, const string& object,
208 std::shared_ptr<Aws::S3::S3Client> s3_client)
209 : bucket_(bucket),
210 object_(object),
211 s3_client_(s3_client),
212 sync_needed_(true),
213 outfile_(Aws::MakeShared<Aws::Utils::TempFile>(
214 kS3FileSystemAllocationTag, "/tmp/s3_filesystem_XXXXXX",
215 std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
216 std::ios_base::out)) {}
217
Append(StringPiece data)218 Status Append(StringPiece data) override {
219 if (!outfile_) {
220 return errors::FailedPrecondition(
221 "The internal temporary file is not writable.");
222 }
223 sync_needed_ = true;
224 outfile_->write(data.data(), data.size());
225 if (!outfile_->good()) {
226 return errors::Internal(
227 "Could not append to the internal temporary file.");
228 }
229 return Status::OK();
230 }
231
Close()232 Status Close() override {
233 if (outfile_) {
234 TF_RETURN_IF_ERROR(Sync());
235 outfile_.reset();
236 }
237 return Status::OK();
238 }
239
Flush()240 Status Flush() override { return Sync(); }
241
Name(StringPiece * result) const242 Status Name(StringPiece* result) const override {
243 return errors::Unimplemented("S3WritableFile does not support Name()");
244 }
245
Sync()246 Status Sync() override {
247 if (!outfile_) {
248 return errors::FailedPrecondition(
249 "The internal temporary file is not writable.");
250 }
251 if (!sync_needed_) {
252 return Status::OK();
253 }
254 Aws::S3::Model::PutObjectRequest putObjectRequest;
255 putObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
256 long offset = outfile_->tellp();
257 outfile_->seekg(0);
258 putObjectRequest.SetBody(outfile_);
259 putObjectRequest.SetContentLength(offset);
260 auto putObjectOutcome = this->s3_client_->PutObject(putObjectRequest);
261 outfile_->clear();
262 outfile_->seekp(offset);
263 if (!putObjectOutcome.IsSuccess()) {
264 return errors::Unknown(putObjectOutcome.GetError().GetExceptionName(),
265 ": ", putObjectOutcome.GetError().GetMessage());
266 }
267 return Status::OK();
268 }
269
270 private:
271 string bucket_;
272 string object_;
273 std::shared_ptr<Aws::S3::S3Client> s3_client_;
274 bool sync_needed_;
275 std::shared_ptr<Aws::Utils::TempFile> outfile_;
276 };
277
278 class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
279 public:
S3ReadOnlyMemoryRegion(std::unique_ptr<char[]> data,uint64 length)280 S3ReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
281 : data_(std::move(data)), length_(length) {}
data()282 const void* data() override { return reinterpret_cast<void*>(data_.get()); }
length()283 uint64 length() override { return length_; }
284
285 private:
286 std::unique_ptr<char[]> data_;
287 uint64 length_;
288 };
289
290 } // namespace
291
S3FileSystem()292 S3FileSystem::S3FileSystem()
293 : s3_client_(nullptr, ShutdownClient), client_lock_() {}
294
~S3FileSystem()295 S3FileSystem::~S3FileSystem() {}
296
297 // Initializes s3_client_, if needed, and returns it.
GetS3Client()298 std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
299 std::lock_guard<mutex> lock(this->client_lock_);
300
301 if (this->s3_client_.get() == nullptr) {
302 AWSLogSystem::InitializeAWSLogging();
303
304 Aws::SDKOptions options;
305 options.cryptoOptions.sha256Factory_create_fn = []() {
306 return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
307 };
308 options.cryptoOptions.sha256HMACFactory_create_fn = []() {
309 return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
310 };
311 Aws::InitAPI(options);
312
313 // The creation of S3Client disables virtual addressing:
314 // S3Client(clientConfiguration, signPayloads, useVirtualAdressing = true)
315 // The purpose is to address the issue encountered when there is an `.`
316 // in the bucket name. Due to TLS hostname validation or DNS rules,
317 // the bucket may not be resolved. Disabling of virtual addressing
318 // should address the issue. See GitHub issue 16397 for details.
319 this->s3_client_ = std::shared_ptr<Aws::S3::S3Client>(new Aws::S3::S3Client(
320 GetDefaultClientConfig(),
321 Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false));
322 }
323
324 return this->s3_client_;
325 }
326
NewRandomAccessFile(const string & fname,std::unique_ptr<RandomAccessFile> * result)327 Status S3FileSystem::NewRandomAccessFile(
328 const string& fname, std::unique_ptr<RandomAccessFile>* result) {
329 string bucket, object;
330 TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
331 result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
332 return Status::OK();
333 }
334
NewWritableFile(const string & fname,std::unique_ptr<WritableFile> * result)335 Status S3FileSystem::NewWritableFile(const string& fname,
336 std::unique_ptr<WritableFile>* result) {
337 string bucket, object;
338 TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
339 result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
340 return Status::OK();
341 }
342
NewAppendableFile(const string & fname,std::unique_ptr<WritableFile> * result)343 Status S3FileSystem::NewAppendableFile(const string& fname,
344 std::unique_ptr<WritableFile>* result) {
345 std::unique_ptr<RandomAccessFile> reader;
346 TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader));
347 std::unique_ptr<char[]> buffer(new char[kS3ReadAppendableFileBufferSize]);
348 Status status;
349 uint64 offset = 0;
350 StringPiece read_chunk;
351
352 string bucket, object;
353 TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
354 result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
355
356 while (true) {
357 status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
358 buffer.get());
359 if (status.ok()) {
360 (*result)->Append(read_chunk);
361 offset += kS3ReadAppendableFileBufferSize;
362 } else if (status.code() == error::OUT_OF_RANGE) {
363 (*result)->Append(read_chunk);
364 break;
365 } else {
366 (*result).reset();
367 return status;
368 }
369 }
370
371 return Status::OK();
372 }
373
NewReadOnlyMemoryRegionFromFile(const string & fname,std::unique_ptr<ReadOnlyMemoryRegion> * result)374 Status S3FileSystem::NewReadOnlyMemoryRegionFromFile(
375 const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
376 uint64 size;
377 TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
378 std::unique_ptr<char[]> data(new char[size]);
379
380 std::unique_ptr<RandomAccessFile> file;
381 TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file));
382
383 StringPiece piece;
384 TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
385
386 result->reset(new S3ReadOnlyMemoryRegion(std::move(data), size));
387 return Status::OK();
388 }
389
FileExists(const string & fname)390 Status S3FileSystem::FileExists(const string& fname) {
391 FileStatistics stats;
392 TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
393 return Status::OK();
394 }
395
GetChildren(const string & dir,std::vector<string> * result)396 Status S3FileSystem::GetChildren(const string& dir,
397 std::vector<string>* result) {
398 string bucket, prefix;
399 TF_RETURN_IF_ERROR(ParseS3Path(dir, false, &bucket, &prefix));
400
401 if (prefix.back() != '/') {
402 prefix.push_back('/');
403 }
404
405 Aws::S3::Model::ListObjectsRequest listObjectsRequest;
406 listObjectsRequest.WithBucket(bucket.c_str())
407 .WithPrefix(prefix.c_str())
408 .WithMaxKeys(kS3GetChildrenMaxKeys)
409 .WithDelimiter("/");
410 listObjectsRequest.SetResponseStreamFactory(
411 []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
412
413 Aws::S3::Model::ListObjectsResult listObjectsResult;
414 do {
415 auto listObjectsOutcome =
416 this->GetS3Client()->ListObjects(listObjectsRequest);
417 if (!listObjectsOutcome.IsSuccess()) {
418 return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
419 ": ", listObjectsOutcome.GetError().GetMessage());
420 }
421
422 listObjectsResult = listObjectsOutcome.GetResult();
423 for (const auto& object : listObjectsResult.GetCommonPrefixes()) {
424 Aws::String s = object.GetPrefix();
425 s.erase(s.length() - 1);
426 Aws::String entry = s.substr(strlen(prefix.c_str()));
427 if (entry.length() > 0) {
428 result->push_back(entry.c_str());
429 }
430 }
431 for (const auto& object : listObjectsResult.GetContents()) {
432 Aws::String s = object.GetKey();
433 Aws::String entry = s.substr(strlen(prefix.c_str()));
434 if (entry.length() > 0) {
435 result->push_back(entry.c_str());
436 }
437 }
438 listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
439 } while (listObjectsResult.GetIsTruncated());
440
441 return Status::OK();
442 }
443
Stat(const string & fname,FileStatistics * stats)444 Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
445 string bucket, object;
446 TF_RETURN_IF_ERROR(ParseS3Path(fname, true, &bucket, &object));
447
448 if (object.empty()) {
449 Aws::S3::Model::HeadBucketRequest headBucketRequest;
450 headBucketRequest.WithBucket(bucket.c_str());
451 auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
452 if (!headBucketOutcome.IsSuccess()) {
453 return errors::Unknown(headBucketOutcome.GetError().GetExceptionName(),
454 ": ", headBucketOutcome.GetError().GetMessage());
455 }
456 stats->length = 0;
457 stats->is_directory = 1;
458 return Status::OK();
459 }
460
461 bool found = false;
462
463 Aws::S3::Model::HeadObjectRequest headObjectRequest;
464 headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
465 headObjectRequest.SetResponseStreamFactory(
466 []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
467 auto headObjectOutcome = this->GetS3Client()->HeadObject(headObjectRequest);
468 if (headObjectOutcome.IsSuccess()) {
469 stats->length = headObjectOutcome.GetResult().GetContentLength();
470 stats->is_directory = 0;
471 stats->mtime_nsec =
472 headObjectOutcome.GetResult().GetLastModified().Millis() * 1e6;
473 found = true;
474 }
475 string prefix = object;
476 if (prefix.back() != '/') {
477 prefix.push_back('/');
478 }
479 Aws::S3::Model::ListObjectsRequest listObjectsRequest;
480 listObjectsRequest.WithBucket(bucket.c_str())
481 .WithPrefix(prefix.c_str())
482 .WithMaxKeys(1);
483 listObjectsRequest.SetResponseStreamFactory(
484 []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
485 auto listObjectsOutcome =
486 this->GetS3Client()->ListObjects(listObjectsRequest);
487 if (listObjectsOutcome.IsSuccess()) {
488 if (listObjectsOutcome.GetResult().GetContents().size() > 0) {
489 stats->length = 0;
490 stats->is_directory = 1;
491 found = true;
492 }
493 }
494 if (!found) {
495 return errors::NotFound("Object ", fname, " does not exist");
496 }
497 return Status::OK();
498 }
499
GetMatchingPaths(const string & pattern,std::vector<string> * results)500 Status S3FileSystem::GetMatchingPaths(const string& pattern,
501 std::vector<string>* results) {
502 return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
503 }
504
DeleteFile(const string & fname)505 Status S3FileSystem::DeleteFile(const string& fname) {
506 string bucket, object;
507 TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
508
509 Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
510 deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
511
512 auto deleteObjectOutcome =
513 this->GetS3Client()->DeleteObject(deleteObjectRequest);
514 if (!deleteObjectOutcome.IsSuccess()) {
515 return errors::Unknown(deleteObjectOutcome.GetError().GetExceptionName(),
516 ": ", deleteObjectOutcome.GetError().GetMessage());
517 }
518 return Status::OK();
519 }
520
CreateDir(const string & dirname)521 Status S3FileSystem::CreateDir(const string& dirname) {
522 string bucket, object;
523 TF_RETURN_IF_ERROR(ParseS3Path(dirname, true, &bucket, &object));
524
525 if (object.empty()) {
526 Aws::S3::Model::HeadBucketRequest headBucketRequest;
527 headBucketRequest.WithBucket(bucket.c_str());
528 auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
529 if (!headBucketOutcome.IsSuccess()) {
530 return errors::NotFound("The bucket ", bucket, " was not found.");
531 }
532 return Status::OK();
533 }
534 string filename = dirname;
535 if (filename.back() != '/') {
536 filename.push_back('/');
537 }
538 std::unique_ptr<WritableFile> file;
539 TF_RETURN_IF_ERROR(NewWritableFile(filename, &file));
540 TF_RETURN_IF_ERROR(file->Close());
541 return Status::OK();
542 }
543
DeleteDir(const string & dirname)544 Status S3FileSystem::DeleteDir(const string& dirname) {
545 string bucket, object;
546 TF_RETURN_IF_ERROR(ParseS3Path(dirname, false, &bucket, &object));
547
548 string prefix = object;
549 if (prefix.back() != '/') {
550 prefix.push_back('/');
551 }
552 Aws::S3::Model::ListObjectsRequest listObjectsRequest;
553 listObjectsRequest.WithBucket(bucket.c_str())
554 .WithPrefix(prefix.c_str())
555 .WithMaxKeys(2);
556 listObjectsRequest.SetResponseStreamFactory(
557 []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
558 auto listObjectsOutcome =
559 this->GetS3Client()->ListObjects(listObjectsRequest);
560 if (listObjectsOutcome.IsSuccess()) {
561 auto contents = listObjectsOutcome.GetResult().GetContents();
562 if (contents.size() > 1 ||
563 (contents.size() == 1 && contents[0].GetKey() != prefix.c_str())) {
564 return errors::FailedPrecondition("Cannot delete a non-empty directory.");
565 }
566 if (contents.size() == 1 && contents[0].GetKey() == prefix.c_str()) {
567 string filename = dirname;
568 if (filename.back() != '/') {
569 filename.push_back('/');
570 }
571 return DeleteFile(filename);
572 }
573 }
574 return Status::OK();
575 }
576
GetFileSize(const string & fname,uint64 * file_size)577 Status S3FileSystem::GetFileSize(const string& fname, uint64* file_size) {
578 FileStatistics stats;
579 TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
580 *file_size = stats.length;
581 return Status::OK();
582 }
583
RenameFile(const string & src,const string & target)584 Status S3FileSystem::RenameFile(const string& src, const string& target) {
585 string src_bucket, src_object, target_bucket, target_object;
586 TF_RETURN_IF_ERROR(ParseS3Path(src, false, &src_bucket, &src_object));
587 TF_RETURN_IF_ERROR(
588 ParseS3Path(target, false, &target_bucket, &target_object));
589 if (src_object.back() == '/') {
590 if (target_object.back() != '/') {
591 target_object.push_back('/');
592 }
593 } else {
594 if (target_object.back() == '/') {
595 target_object.pop_back();
596 }
597 }
598
599 Aws::S3::Model::CopyObjectRequest copyObjectRequest;
600 Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
601
602 Aws::S3::Model::ListObjectsRequest listObjectsRequest;
603 listObjectsRequest.WithBucket(src_bucket.c_str())
604 .WithPrefix(src_object.c_str())
605 .WithMaxKeys(kS3GetChildrenMaxKeys);
606 listObjectsRequest.SetResponseStreamFactory(
607 []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
608
609 Aws::S3::Model::ListObjectsResult listObjectsResult;
610 do {
611 auto listObjectsOutcome =
612 this->GetS3Client()->ListObjects(listObjectsRequest);
613 if (!listObjectsOutcome.IsSuccess()) {
614 return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(),
615 ": ", listObjectsOutcome.GetError().GetMessage());
616 }
617
618 listObjectsResult = listObjectsOutcome.GetResult();
619 for (const auto& object : listObjectsResult.GetContents()) {
620 Aws::String src_key = object.GetKey();
621 Aws::String target_key = src_key;
622 target_key.replace(0, src_object.length(), target_object.c_str());
623 Aws::String source = Aws::String(src_bucket.c_str()) + "/" +
624 Aws::Utils::StringUtils::URLEncode(src_key.c_str());
625
626 copyObjectRequest.SetBucket(target_bucket.c_str());
627 copyObjectRequest.SetKey(target_key);
628 copyObjectRequest.SetCopySource(source);
629
630 auto copyObjectOutcome =
631 this->GetS3Client()->CopyObject(copyObjectRequest);
632 if (!copyObjectOutcome.IsSuccess()) {
633 return errors::Unknown(copyObjectOutcome.GetError().GetExceptionName(),
634 ": ", copyObjectOutcome.GetError().GetMessage());
635 }
636
637 deleteObjectRequest.SetBucket(src_bucket.c_str());
638 deleteObjectRequest.SetKey(src_key.c_str());
639
640 auto deleteObjectOutcome =
641 this->GetS3Client()->DeleteObject(deleteObjectRequest);
642 if (!deleteObjectOutcome.IsSuccess()) {
643 return errors::Unknown(
644 deleteObjectOutcome.GetError().GetExceptionName(), ": ",
645 deleteObjectOutcome.GetError().GetMessage());
646 }
647 }
648 listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
649 } while (listObjectsResult.GetIsTruncated());
650
651 return Status::OK();
652 }
653
// Registers S3FileSystem under the "s3" scheme name; ParseS3Path in this file
// accepts only paths whose scheme is "s3" (i.e. "s3://bucket/object").
REGISTER_FILE_SYSTEM("s3", S3FileSystem);
655
656 } // namespace tensorflow
657