• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/flat_map_dataset_op.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_runner.h"
23 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/data/name_utils.h"
26 #include "tensorflow/core/data/serialization_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/partial_tensor_shape.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/lib/random/random.h"
32 
33 namespace tensorflow {
34 namespace data {
35 
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Out-of-line definitions for the static constexpr string constants declared
// in flat_map_dataset_op.h. These are required to give the constants a
// definition at namespace scope (needed for ODR-use before C++17 inline
// variables).
/* static */ constexpr const char* const FlatMapDatasetOp::kDatasetType;
/* static */ constexpr const char* const FlatMapDatasetOp::kInputDataset;
/* static */ constexpr const char* const FlatMapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kFunc;
/* static */ constexpr const char* const FlatMapDatasetOp::kTarguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputShapes;
46 
// Keys used by Iterator::SaveInternal/RestoreInternal below to serialize the
// iterator's state into a checkpoint. Changing any of these strings breaks
// compatibility with previously written checkpoints.
constexpr char kElementIndex[] = "element_index";
constexpr char kInputsSize[] = "inputs_size";
constexpr char kInputs[] = "inputs";
constexpr char kCurrentElementIteratorUninitialized[] =
    "current_element_iterator_uninitialized";
constexpr char kExhausted[] = "exhausted";
53 
// Dataset whose iterator applies `captured_func_` to each element of `input_`
// to produce a dataset, and then emits the elements of those produced datasets
// back-to-back as one flattened stream.
class FlatMapDatasetOp::Dataset : public DatasetBase {
 public:
  // Takes a reference on `input` for the lifetime of this dataset; the
  // reference is released in the destructor.
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        output_types_(output_types),
        output_shapes_(output_shapes) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  // The dataset holds external state if either the captured function or the
  // input dataset does; both checks must pass for serialization to proceed.
  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  // Serializes this dataset as a graph node: one dataset input, the captured
  // function's arguments as a tensor-list input, and the function/type attrs.
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},         // Tensor list inputs.
        {std::make_pair(kFunc, f),
         std::make_pair(kTarguments, other_arguments_types_attr)},  // Attrs
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    // Drains the iterator over the current mapped element; when it is
    // exhausted, pulls the next element from `input_impl_`, maps it, and
    // continues. `input_impl_` is reset once the input is exhausted, which
    // makes subsequent calls report end-of-sequence immediately.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      do {
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (current_element_iterator_) {
          // We are currently processing a mapped element, so try to get the
          // next subelement.
          bool end_of_element;
          TF_RETURN_IF_ERROR(current_element_iterator_->GetNext(
              ctx, out_tensors, &end_of_element));
          if (!end_of_element) {
            // Produce the subelement as output.
            *end_of_sequence = false;
            return Status::OK();
          }

          // We have reached the end of the current element, so maybe move on
          // to the next element.
          current_element_iterator_.reset();
        }

        // Get the next element from the input dataset.
        // `inputs_` is kept as a member so it can be checkpointed and the
        // current element iterator rebuilt on restore.
        inputs_.clear();
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, &inputs_, end_of_sequence));
        if (*end_of_sequence) {
          input_impl_.reset();
          return Status::OK();
        }

        TF_RETURN_IF_ERROR(
            BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/true));
      } while (true);
    }

    // Skips up to `num_to_skip` flattened elements, spanning mapped-element
    // boundaries as needed. Mirrors the structure of GetNextInternal but uses
    // the current element iterator's Skip() instead of GetNext().
    Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                        bool* end_of_sequence, int* num_skipped) override {
      mutex_lock l(mu_);
      *num_skipped = 0;
      while (*num_skipped < num_to_skip) {
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (!current_element_iterator_) {
          // Get the next element from the input dataset.
          inputs_.clear();
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &inputs_, end_of_sequence));
          if (*end_of_sequence) {
            input_impl_.reset();
            *end_of_sequence = true;
            return Status::OK();
          }
          TF_RETURN_IF_ERROR(
              BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
        }
        bool end_of_element;
        int last_num_skipped;
        TF_RETURN_IF_ERROR(current_element_iterator_->Skip(
            ctx, num_to_skip - *num_skipped, &end_of_element,
            &last_num_skipped));
        *num_skipped += last_num_skipped;
        if (end_of_element) {
          // We have reached the end of the current element, so maybe move on
          // to the next element.
          current_element_iterator_.reset();
        }
      }
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeInterleaveManyNode(std::move(args));
    }

    // Checkpoint layout (see the k* key constants above):
    //   - kExhausted is written iff the input iterator is exhausted;
    //   - otherwise the input iterator state, kElementIndex, and either
    //     (kInputsSize, kInputs[i]..., current element iterator state) or
    //     kCurrentElementIteratorUninitialized are written.
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(mu_);
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kElementIndex), element_index_));
        if (current_element_iterator_) {
          // Save the raw input element so the mapped iterator can be rebuilt
          // (by re-running the captured function) on restore.
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kInputsSize), inputs_.size()));
          for (int i = 0; i < inputs_.size(); i++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat(kInputs, "[", i, "]")), inputs_[i]));
          }
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_));
        } else {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(kCurrentElementIteratorUninitialized), ""));
        }
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      // Reset to a clean state before applying the checkpoint.
      input_impl_.reset();
      element_index_ = 0;
      current_element_iterator_.reset();
      inputs_.clear();
      if (!reader->Contains(full_name(kExhausted))) {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        {
          int64_t temp;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kElementIndex), &temp));
          element_index_ = temp;
        }
        if (!reader->Contains(
                full_name(kCurrentElementIteratorUninitialized))) {
          // Re-read the saved input element into `inputs_` so the current
          // element iterator can be reconstructed.
          size_t inputs_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name(kInputsSize), &temp));
            inputs_size = static_cast<size_t>(temp);
          }
          inputs_.reserve(inputs_size);
          for (int i = 0; i < inputs_size; i++) {
            inputs_.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat(kInputs, "[", i, "]")),
                &inputs_.back()));
          }

          // BuildCurrentElementIteratorLocked() post-increments
          // `element_index_`; decrement first so the rebuilt iterator gets the
          // same index that was saved.
          element_index_--;
          TF_RETURN_IF_ERROR(
              BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
          TF_RETURN_IF_ERROR(
              RestoreInput(ctx, reader, current_element_iterator_));
        }
      }
      return Status::OK();
    }

   private:
    // Applies the captured function to `inputs_` to create
    // `current_element_iterator_`, consuming (post-incrementing)
    // `element_index_`. Only GetNext() attaches the iterator to the
    // performance model node; Skip()/restore pass nullptr.
    Status BuildCurrentElementIteratorLocked(IteratorContext* ctx,
                                             bool is_get_next)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (is_get_next) {
        return MakeIteratorFromInputElement(
            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
            prefix(), &current_element_iterator_, model_node());
      } else {
        // NOTE: We intentionally ignore resource modeling outside GetNext().
        return MakeIteratorFromInputElement(
            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
            prefix(), &current_element_iterator_,
            /*node=*/nullptr);
      }
    }

    mutex mu_;
    // Index of the next input element to map; used to name nested iterators.
    size_t element_index_ TF_GUARDED_BY(mu_) = 0;
    // Iterator over the input dataset; null once the input is exhausted.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    // Iterator over the dataset produced from the current input element;
    // null between elements.
    std::unique_ptr<IteratorBase> current_element_iterator_ TF_GUARDED_BY(mu_);
    // The raw input element backing `current_element_iterator_`; checkpointed
    // so the mapped iterator can be rebuilt on restore.
    std::vector<Tensor> inputs_ TF_GUARDED_BY(mu_);
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};
318 
// Reads the op's function and output-signature attrs at kernel-construction
// time; any failure marks the kernel construction as failed via
// OP_REQUIRES_OK.
FlatMapDatasetOp::FlatMapDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
326 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)327 void FlatMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
328                                    DatasetBase** output) {
329   std::unique_ptr<CapturedFunction> captured_func;
330   OP_REQUIRES_OK(ctx,
331                  CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
332                                           &captured_func));
333   *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
334                         output_shapes_);
335 }
336 
namespace {

// Register the CPU kernel for the FlatMapDataset op.
REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU),
                        FlatMapDatasetOp);
// Exempt this op's inputs from the input-colocation constraint (see
// input_colocation_exemption_registry.h).
REGISTER_INPUT_COLOCATION_EXEMPTION("FlatMapDataset");

}  // namespace
344 }  // namespace data
345 }  // namespace tensorflow
346