1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/split_utils.h"
16
17 #include <functional>
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/core/platform/errors.h"
22
23 namespace tensorflow {
24 namespace data {
25 namespace {
26 constexpr char kNumToSkip[] = "num_to_skip";
27 constexpr char kSplitProvider[] = "split_provider";
28 constexpr char kSlash[] = "/";
29 constexpr char kIndex[] = "index";
30 } // namespace
31
// Provides the integer splits 0, 1, ..., n-1 in order; `i_` is the next
// index to emit.
IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {}
33
GetNext(Tensor * split,bool * end_of_splits)34 Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
35 mutex_lock l(mu_);
36 if (i_ >= n_) {
37 *end_of_splits = true;
38 return Status::OK();
39 }
40 *end_of_splits = false;
41 *split = Tensor(DT_INT64, TensorShape{});
42 split->scalar<int64>()() = i_++;
43 return Status::OK();
44 }
45
// Rewinds the provider so iteration restarts from index 0.
Status IndexSplitProvider::Reset() {
  mutex_lock l(mu_);
  i_ = 0;
  return Status::OK();
}
51
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)52 Status IndexSplitProvider::Save(
53 std::function<std::string(std::string)> full_name,
54 IteratorStateWriter* writer) {
55 mutex_lock l(mu_);
56 return writer->WriteScalar(full_name(kIndex), i_);
57 }
58
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)59 Status IndexSplitProvider::Restore(
60 std::function<std::string(std::string)> full_name,
61 IteratorStateReader* reader) {
62 mutex_lock l(mu_);
63 return reader->ReadScalar(full_name(kIndex), &i_);
64 }
65
ShardingSplitProvider(int64_t num_shards,int64_t shard_index,std::shared_ptr<SplitProvider> split_provider)66 ShardingSplitProvider::ShardingSplitProvider(
67 int64_t num_shards, int64_t shard_index,
68 std::shared_ptr<SplitProvider> split_provider)
69 : num_shards_(num_shards),
70 shard_index_(shard_index),
71 split_provider_(split_provider),
72 num_to_skip_(shard_index_) {}
73
// Returns the next split belonging to this shard, discarding the splits that
// belong to other shards.
Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  // Skip other shards' splits. On the first call after construction/Reset
  // this skips `shard_index_` splits; on subsequent calls it skips
  // `num_shards_ - 1` (set below).
  while (num_to_skip_ > 0) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    if (*end_of_splits) {
      // Underlying provider ran out while skipping; propagate end-of-input.
      return Status::OK();
    }
    num_to_skip_--;
  }
  // Before the next split of ours, the other `num_shards_ - 1` shards' splits
  // must be skipped.
  num_to_skip_ = num_shards_ - 1;
  TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
  return Status::OK();
}
87
// Resets the wrapped provider and restores the initial skip count so this
// shard again starts at offset `shard_index_`.
Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Reset());
  num_to_skip_ = shard_index_;
  return Status::OK();
}
94
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)95 Status ShardingSplitProvider::Save(
96 std::function<std::string(std::string)> full_name,
97 IteratorStateWriter* writer) {
98 mutex_lock l(mu_);
99 TF_RETURN_IF_ERROR(split_provider_->Save(
100 [&](const std::string& key) {
101 return full_name(absl::StrCat(kSplitProvider, kSlash, key));
102 },
103 writer));
104 return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
105 }
106
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)107 Status ShardingSplitProvider::Restore(
108 std::function<std::string(std::string)> full_name,
109 IteratorStateReader* reader) {
110 mutex_lock l(mu_);
111 TF_RETURN_IF_ERROR(split_provider_->Restore(
112 [&](const std::string& key) {
113 return full_name(absl::StrCat(kSplitProvider, kSlash, key));
114 },
115 reader));
116 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
117 return Status::OK();
118 }
119
GetSingleSplitProvider(IteratorContext * ctx,const DatasetBase * dataset)120 StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
121 IteratorContext* ctx, const DatasetBase* dataset) {
122 if (ctx->split_providers().size() != 1) {
123 return errors::FailedPrecondition(
124 "Failed to get single split provider for dataset ",
125 dataset->DebugString(), ". Found ", ctx->split_providers().size(),
126 " split providers");
127 }
128 return ctx->split_providers()[0];
129 }
130
GetSplitProviders(const DatasetBase * dataset)131 StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
132 const DatasetBase* dataset) {
133 std::vector<std::unique_ptr<SplitProvider>> result;
134 std::vector<const DatasetBase*> inputs;
135 TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
136 for (const auto& input : inputs) {
137 std::vector<std::unique_ptr<SplitProvider>> providers;
138 TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
139 for (auto& provider : providers) {
140 result.push_back(std::move(provider));
141 }
142 }
143 return result;
144 }
145
CreateInputIteratorContexts(IteratorContext * ctx,const DatasetBase * dataset)146 StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
147 IteratorContext* ctx, const DatasetBase* dataset) {
148 std::vector<const DatasetBase*> inputs;
149 TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
150 std::vector<IteratorContext> result;
151 if (ctx->split_providers().empty()) {
152 for (int i = 0; i < inputs.size(); ++i) {
153 result.emplace_back(ctx);
154 }
155 return result;
156 }
157 int64_t split_provider_index = 0;
158 for (size_t i = 0; i < inputs.size(); ++i) {
159 IteratorContext::Params params(ctx);
160 if (inputs[i]->num_sources() < 0) {
161 return errors::FailedPrecondition(
162 "Failed to determine the number of sources for dataset of type ",
163 inputs[i]->type_string());
164 }
165 params.split_providers.clear();
166 for (int j = 0; j < inputs[i]->num_sources(); ++j) {
167 params.split_providers.push_back(
168 ctx->split_providers()[split_provider_index + j]);
169 }
170 split_provider_index += inputs[i]->num_sources();
171 result.emplace_back(std::move(params));
172 }
173 if (split_provider_index != ctx->split_providers().size()) {
174 return errors::FailedPrecondition("Attempted to feed ",
175 ctx->split_providers().size(),
176 " split providers into a dataset with ",
177 split_provider_index, " sources");
178 }
179 return result;
180 }
181
182 } // namespace data
183 } // namespace tensorflow
184