/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/split_utils.h"

#include <functional>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {
namespace {
// Checkpoint keys under which split-provider state is saved and restored.
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
constexpr char kIndex[] = "index";
}  // namespace

// IndexSplitProvider hands out the integer splits 0, 1, ..., n - 1 in order.
IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {}

Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  if (i_ >= n_) {
    *end_of_splits = true;
    return Status::OK();
  }
  *end_of_splits = false;
  // Each split is a scalar int64 tensor holding the next index.
  *split = Tensor(DT_INT64, TensorShape{});
  split->scalar<int64>()() = i_++;
  return Status::OK();
}

Status IndexSplitProvider::Reset() {
  mutex_lock l(mu_);
  i_ = 0;
  return Status::OK();
}

Status IndexSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  return writer->WriteScalar(full_name(kIndex), i_);
}

Status IndexSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  return reader->ReadScalar(full_name(kIndex), &i_);
}

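// Example (illustrative only, not part of the original file): driving an
// IndexSplitProvider by hand. It yields scalar int64 tensors 0, 1, 2 and
// then signals end_of_splits.
//
//   IndexSplitProvider provider(/*n=*/3);
//   Tensor split;
//   bool end_of_splits = false;
//   while (true) {
//     TF_CHECK_OK(provider.GetNext(&split, &end_of_splits));
//     if (end_of_splits) break;
//     // split.scalar<int64>()() is 0, then 1, then 2.
//   }
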
// ShardingSplitProvider filters an underlying split provider down to one
// shard: of every `num_shards` consecutive splits, it passes through the one
// at offset `shard_index` and discards the rest.
ShardingSplitProvider::ShardingSplitProvider(
    int64_t num_shards, int64_t shard_index,
    std::shared_ptr<SplitProvider> split_provider)
    : num_shards_(num_shards),
      shard_index_(shard_index),
      split_provider_(split_provider),
      num_to_skip_(shard_index_) {}

Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  // Discard the splits that belong to other shards.
  while (num_to_skip_ > 0) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    if (*end_of_splits) {
      return Status::OK();
    }
    num_to_skip_--;
  }
  num_to_skip_ = num_shards_ - 1;
  TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
  return Status::OK();
}

Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Reset());
  num_to_skip_ = shard_index_;
  return Status::OK();
}

Status ShardingSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  // The wrapped provider's state is nested under a "split_provider/" prefix
  // so its keys cannot collide with ours.
  TF_RETURN_IF_ERROR(split_provider_->Save(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      writer));
  return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
}

Status ShardingSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Restore(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      reader));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
  return Status::OK();
}

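// Example (illustrative only): splitting an index space across two workers.
// Wrapping IndexSplitProvider(5) with num_shards=2 and shard_index=1 yields
// splits 1 and 3; splits 0, 2, and 4 belong to shard 0.
//
//   auto base = std::make_shared<IndexSplitProvider>(5);
//   ShardingSplitProvider shard(/*num_shards=*/2, /*shard_index=*/1, base);
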
// Returns the lone split provider from `ctx`, or a FailedPrecondition error
// if `ctx` does not hold exactly one.
StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
    IteratorContext* ctx, const DatasetBase* dataset) {
  if (ctx->split_providers().size() != 1) {
    return errors::FailedPrecondition(
        "Failed to get single split provider for dataset ",
        dataset->DebugString(), ". Found ", ctx->split_providers().size(),
        " split providers");
  }
  return ctx->split_providers()[0];
}

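// Example (illustrative only): a source dataset iterator claiming its single
// split provider during initialization, assuming the TF_ASSIGN_OR_RETURN
// StatusOr-unwrapping macro is available in the translation unit.
//
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<SplitProvider> provider,
//                       GetSingleSplitProvider(ctx, dataset));
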
// Builds split providers for all of `dataset`'s inputs, flattened into a
// single vector in input order.
StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
    const DatasetBase* dataset) {
  std::vector<std::unique_ptr<SplitProvider>> result;
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  for (const auto& input : inputs) {
    std::vector<std::unique_ptr<SplitProvider>> providers;
    TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
    for (auto& provider : providers) {
      result.push_back(std::move(provider));
    }
  }
  return result;
}

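// Example (illustrative only): ordering of the flattened result. For a
// hypothetical dataset zipping input A with one source and input B with two,
// the returned vector holds {A's provider, B's first, B's second}, the same
// order in which CreateInputIteratorContexts below consumes them.
//
//   TF_ASSIGN_OR_RETURN(auto providers, GetSplitProviders(zip_dataset));
//   // providers.size() == 3, flattened in input order.
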
// Creates one IteratorContext per input of `dataset`, assigning each input
// the contiguous slice of `ctx`'s split providers that feeds its sources.
StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
    IteratorContext* ctx, const DatasetBase* dataset) {
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  std::vector<IteratorContext> result;
  if (ctx->split_providers().empty()) {
    // Without split providers, every input just shares the parent context.
    for (size_t i = 0; i < inputs.size(); ++i) {
      result.emplace_back(ctx);
    }
    return result;
  }
  int64_t split_provider_index = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    IteratorContext::Params params(ctx);
    if (inputs[i]->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Failed to determine the number of sources for dataset of type ",
          inputs[i]->type_string());
    }
    params.split_providers.clear();
    for (int j = 0; j < inputs[i]->num_sources(); ++j) {
      params.split_providers.push_back(
          ctx->split_providers()[split_provider_index + j]);
    }
    split_provider_index += inputs[i]->num_sources();
    result.emplace_back(std::move(params));
  }
  if (split_provider_index != ctx->split_providers().size()) {
    return errors::FailedPrecondition("Attempted to feed ",
                                      ctx->split_providers().size(),
                                      " split providers into a dataset with ",
                                      split_provider_index, " sources");
  }
  return result;
}
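
// Example (illustrative only): a dataset with multiple inputs creating its
// input iterators, giving each input only its own slice of split providers.
//
//   TF_ASSIGN_OR_RETURN(std::vector<IteratorContext> input_contexts,
//                       CreateInputIteratorContexts(ctx, dataset));
//   // Input i's iterator should then be created against
//   // &input_contexts[i], so it only sees the providers for its sources.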

}  // namespace data
}  // namespace tensorflow