1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/split_utils.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>
16
17 namespace tensorflow {
18 namespace data {
namespace {
// Checkpoint key names used by the Save/Restore methods in this file.

// Key under which IndexSplitProvider records the next index to emit.
constexpr char kIndex[] = "index";
// Key under which ShardingSplitProvider records its pending skip count.
constexpr char kNumToSkip[] = "num_to_skip";
// Prefix and separator used to nest the wrapped provider's keys beneath the
// ShardingSplitProvider's own keys ("split_provider/<inner key>").
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
}  // namespace
25
IndexSplitProvider(int64 n)26 IndexSplitProvider::IndexSplitProvider(int64 n) : i_(0), n_(n) {}
27
GetNext(Tensor * split,bool * end_of_splits)28 Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
29 mutex_lock l(mu_);
30 if (i_ >= n_) {
31 *end_of_splits = true;
32 return Status::OK();
33 }
34 *end_of_splits = false;
35 *split = Tensor(DT_INT64, TensorShape{});
36 split->scalar<int64>()() = i_++;
37 return Status::OK();
38 }
39
Reset()40 Status IndexSplitProvider::Reset() {
41 mutex_lock l(mu_);
42 i_ = 0;
43 return Status::OK();
44 }
45
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)46 Status IndexSplitProvider::Save(
47 std::function<std::string(std::string)> full_name,
48 IteratorStateWriter* writer) {
49 mutex_lock l(mu_);
50 return writer->WriteScalar(full_name(kIndex), i_);
51 }
52
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)53 Status IndexSplitProvider::Restore(
54 std::function<std::string(std::string)> full_name,
55 IteratorStateReader* reader) {
56 mutex_lock l(mu_);
57 return reader->ReadScalar(full_name(kIndex), &i_);
58 }
59
ShardingSplitProvider(int64 num_shards,int64 shard_index,std::shared_ptr<SplitProvider> split_provider)60 ShardingSplitProvider::ShardingSplitProvider(
61 int64 num_shards, int64 shard_index,
62 std::shared_ptr<SplitProvider> split_provider)
63 : num_shards_(num_shards),
64 shard_index_(shard_index),
65 split_provider_(split_provider),
66 num_to_skip_(shard_index_) {}
67
GetNext(Tensor * split,bool * end_of_splits)68 Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
69 mutex_lock l(mu_);
70 while (num_to_skip_ > 0) {
71 TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
72 if (*end_of_splits) {
73 return Status::OK();
74 }
75 num_to_skip_--;
76 }
77 num_to_skip_ = num_shards_ - 1;
78 TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
79 return Status::OK();
80 }
81
Reset()82 Status ShardingSplitProvider::Reset() {
83 mutex_lock l(mu_);
84 TF_RETURN_IF_ERROR(split_provider_->Reset());
85 num_to_skip_ = shard_index_;
86 return Status::OK();
87 }
88
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)89 Status ShardingSplitProvider::Save(
90 std::function<std::string(std::string)> full_name,
91 IteratorStateWriter* writer) {
92 mutex_lock l(mu_);
93 TF_RETURN_IF_ERROR(split_provider_->Save(
94 [&](const std::string& key) {
95 return full_name(absl::StrCat(kSplitProvider, kSlash, key));
96 },
97 writer));
98 return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
99 }
100
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)101 Status ShardingSplitProvider::Restore(
102 std::function<std::string(std::string)> full_name,
103 IteratorStateReader* reader) {
104 mutex_lock l(mu_);
105 TF_RETURN_IF_ERROR(split_provider_->Restore(
106 [&](const std::string& key) {
107 return full_name(absl::StrCat(kSplitProvider, kSlash, key));
108 },
109 reader));
110 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
111 return Status::OK();
112 }
113
114 } // namespace data
115 } // namespace tensorflow
116