1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/split_utils.h"
16
17 #include <functional>
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/core/platform/errors.h"
22
23 namespace tensorflow {
24 namespace data {
namespace {
// Checkpoint key names used by the Save/Restore methods below.
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
constexpr char kIndex[] = "index";
}  // namespace
31
// Provides the splits 0, 1, ..., n-1 in order, starting from 0.
IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {}
33
GetNext(Tensor * split,bool * end_of_splits)34 Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
35 mutex_lock l(mu_);
36 if (i_ >= n_) {
37 *end_of_splits = true;
38 return OkStatus();
39 }
40 *end_of_splits = false;
41 *split = Tensor(DT_INT64, TensorShape{});
42 split->scalar<int64_t>()() = i_++;
43 return OkStatus();
44 }
45
// Rewinds iteration back to index 0.
Status IndexSplitProvider::Reset() {
  mutex_lock l(mu_);
  i_ = 0;
  return OkStatus();
}
51
// Checkpoints the current index under the key produced by
// `full_name(kIndex)`.
Status IndexSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  return writer->WriteScalar(full_name(kIndex), i_);
}
58
// Restores the current index from the checkpoint key written by Save().
Status IndexSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  return reader->ReadScalar(full_name(kIndex), &i_);
}
65
// Wraps `split_provider`, exposing only the splits belonging to shard
// `shard_index` out of `num_shards` round-robin shards. The shard's first
// split is at offset `shard_index`, so that many splits are skipped first.
ShardingSplitProvider::ShardingSplitProvider(
    int64_t num_shards, int64_t shard_index,
    std::shared_ptr<SplitProvider> split_provider)
    : num_shards_(num_shards),
      shard_index_(shard_index),
      split_provider_(split_provider),
      num_to_skip_(shard_index_) {}
73
// Returns the next split that belongs to this shard.
//
// Splits are distributed round-robin: skip `num_to_skip_` splits from the
// wrapped provider, emit the next one, then arrange to skip
// `num_shards_ - 1` splits before the following emission.
Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  while (num_to_skip_ > 0) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    // If the wrapped provider runs out mid-skip, this shard is done too.
    if (*end_of_splits) {
      return OkStatus();
    }
    num_to_skip_--;
  }
  // Skip the other shards' splits before our next emission.
  num_to_skip_ = num_shards_ - 1;
  TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
  return OkStatus();
}
87
// Resets the wrapped provider and re-establishes the initial skip offset
// so this shard realigns to its own splits from the beginning.
Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Reset());
  num_to_skip_ = shard_index_;
  return OkStatus();
}
94
// Checkpoints the wrapped provider's state (namespaced under a
// "split_provider/" key prefix) along with the pending skip count.
Status ShardingSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Save(
      // Prefix the wrapped provider's keys to avoid colliding with ours.
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      writer));
  return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
}
106
// Restores the state written by Save(): first the wrapped provider's
// state (under the same "split_provider/" key prefix), then the pending
// skip count.
Status ShardingSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Restore(
      // Must mirror the key prefixing used in Save().
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      reader));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
  return OkStatus();
}
119
GetSingleSplitProvider(IteratorContext * ctx,const DatasetBase * dataset)120 StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
121 IteratorContext* ctx, const DatasetBase* dataset) {
122 if (ctx->split_providers().size() != 1) {
123 return errors::FailedPrecondition(
124 "Failed to get single split provider for dataset ",
125 dataset->DebugString(), ". Found ", ctx->split_providers().size(),
126 " split providers");
127 }
128 return ctx->split_providers()[0];
129 }
130
GetSplitProviders(const DatasetBase * dataset)131 StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
132 const DatasetBase* dataset) {
133 std::vector<std::unique_ptr<SplitProvider>> result;
134 std::vector<const DatasetBase*> inputs;
135 TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
136 for (const auto& input : inputs) {
137 std::vector<std::unique_ptr<SplitProvider>> providers;
138 TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
139 for (auto& provider : providers) {
140 result.push_back(std::move(provider));
141 }
142 }
143 return result;
144 }
145
// Builds one IteratorContext per input of `dataset`, partitioning `ctx`'s
// split providers among the inputs according to each input's source count.
// When `ctx` has no split providers, each input just shares `ctx`.
StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
    IteratorContext* ctx, const DatasetBase* dataset) {
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  std::vector<IteratorContext> result;
  if (ctx->split_providers().empty()) {
    // No splitting in effect: every input gets a copy of the parent context.
    for (int i = 0; i < inputs.size(); ++i) {
      result.emplace_back(ctx);
    }
    return result;
  }
  // Validate up front that the total number of sources across all inputs
  // matches the number of providers; providers map to sources one-to-one.
  int64_t num_sources = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    // A negative num_sources() means the input cannot report its count.
    if (inputs[i]->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Failed to determine the number of sources for dataset of type ",
          inputs[i]->type_string());
    }
    num_sources += inputs[i]->num_sources();
  }
  if (num_sources != ctx->split_providers().size()) {
    return errors::FailedPrecondition(
        "Attempted to feed ", ctx->split_providers().size(),
        " split providers into a dataset with ", num_sources, " sources");
  }
  // Hand each input a contiguous slice of `ctx`'s split providers.
  int64_t split_provider_index = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    IteratorContext::Params params(ctx);
    params.split_providers.clear();
    for (int j = 0; j < inputs[i]->num_sources(); ++j) {
      params.split_providers.push_back(
          ctx->split_providers()[split_provider_index + j]);
    }
    split_provider_index += inputs[i]->num_sources();
    result.emplace_back(std::move(params));
  }
  return result;
}
184
// NOTE(review): this is a second definition of GetSingleSplitProvider —
// an identical copy appears earlier in this file. Two definitions of the
// same function in one translation unit will not compile; this looks like
// a merge/extraction artifact. Confirm which copy is current and remove
// the other.
// Returns the sole split provider carried by `ctx`, or a
// FailedPrecondition error when `ctx` holds zero or multiple providers.
StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
    IteratorContext* ctx, const DatasetBase* dataset) {
  if (ctx->split_providers().size() != 1) {
    return errors::FailedPrecondition(
        "Failed to get single split provider for dataset ",
        dataset->DebugString(), ". Found ", ctx->split_providers().size(),
        " split providers");
  }
  return ctx->split_providers()[0];
}
195
// NOTE(review): this is a second definition of GetSplitProviders — an
// identical copy appears earlier in this file. Two definitions of the same
// function in one translation unit will not compile; this looks like a
// merge/extraction artifact. Confirm which copy is current and remove the
// other.
// Collects the split providers of all of `dataset`'s inputs into a single
// flat vector, preserving input order.
StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
    const DatasetBase* dataset) {
  std::vector<std::unique_ptr<SplitProvider>> result;
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  for (const auto& input : inputs) {
    std::vector<std::unique_ptr<SplitProvider>> providers;
    TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
    // Move each input's providers into the combined result.
    for (auto& provider : providers) {
      result.push_back(std::move(provider));
    }
  }
  return result;
}
210
// NOTE(review): this is a second definition of CreateInputIteratorContexts;
// another copy appears earlier in this file. Two definitions of the same
// function in one translation unit will not compile; this looks like two
// revisions fused by a merge/extraction artifact (this variant validates
// the provider count after assignment rather than up front). Confirm which
// copy is current and remove the other.
// Builds one IteratorContext per input of `dataset`, partitioning `ctx`'s
// split providers among the inputs according to each input's source count.
// When `ctx` has no split providers, each input just shares `ctx`.
StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
    IteratorContext* ctx, const DatasetBase* dataset) {
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  std::vector<IteratorContext> result;
  if (ctx->split_providers().empty()) {
    // No splitting in effect: every input gets a copy of the parent context.
    for (int i = 0; i < inputs.size(); ++i) {
      result.emplace_back(ctx);
    }
    return result;
  }
  // Hand each input a contiguous slice of `ctx`'s split providers.
  int64_t split_provider_index = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    IteratorContext::Params params(ctx);
    // A negative num_sources() means the input cannot report its count.
    if (inputs[i]->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Failed to determine the number of sources for dataset of type ",
          inputs[i]->type_string());
    }
    params.split_providers.clear();
    for (int j = 0; j < inputs[i]->num_sources(); ++j) {
      params.split_providers.push_back(
          ctx->split_providers()[split_provider_index + j]);
    }
    split_provider_index += inputs[i]->num_sources();
    result.emplace_back(std::move(params));
  }
  // All providers must have been consumed: total sources must equal the
  // number of providers, since providers map to sources one-to-one.
  if (split_provider_index != ctx->split_providers().size()) {
    return errors::FailedPrecondition("Attempted to feed ",
                                      ctx->split_providers().size(),
                                      " split providers into a dataset with ",
                                      split_provider_index, " sources");
  }
  return result;
}
246
247 } // namespace data
248 } // namespace tensorflow
249