• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/common_runtime/function.h"
16 #include "tensorflow/core/framework/partial_tensor_shape.h"
17 #include "tensorflow/core/framework/tensor.h"
18 #include "tensorflow/core/kernels/data/captured_function.h"
19 #include "tensorflow/core/kernels/data/dataset.h"
20 #include "tensorflow/core/lib/random/random.h"
21 
22 namespace tensorflow {
23 
24 namespace {
25 
26 // See documentation in ../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28 
29 class MapDatasetOp : public UnaryDatasetOpKernel {
30  public:
MapDatasetOp(OpKernelConstruction * ctx)31   explicit MapDatasetOp(OpKernelConstruction* ctx)
32       : UnaryDatasetOpKernel(ctx),
33         graph_def_version_(ctx->graph_def_version()) {
34     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
35     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
36     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
37   }
38 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)39   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
40                    DatasetBase** output) override {
41     OpInputList inputs;
42     OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
43     std::vector<Tensor> other_arguments;
44     other_arguments.reserve(inputs.size());
45     for (const Tensor& t : inputs) {
46       other_arguments.push_back(t);
47     }
48 
49     std::unique_ptr<CapturedFunction> captured_func;
50     OP_REQUIRES_OK(ctx, CapturedFunction::Create(
51                             func_, std::move(other_arguments), &captured_func));
52 
53     *output = new Dataset(ctx, input, func_, std::move(captured_func),
54                           output_types_, output_shapes_);
55   }
56 
57  private:
58   class Dataset : public GraphDatasetBase {
59    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const NameAttrList & func,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)60     Dataset(OpKernelContext* ctx, const DatasetBase* input,
61             const NameAttrList& func,
62             std::unique_ptr<CapturedFunction> captured_func,
63             const DataTypeVector& output_types,
64             const std::vector<PartialTensorShape>& output_shapes)
65         : GraphDatasetBase(ctx),
66           input_(input),
67           func_(func),
68           captured_func_(std::move(captured_func)),
69           output_types_(output_types),
70           output_shapes_(output_shapes) {
71       input_->Ref();
72     }
73 
~Dataset()74     ~Dataset() override { input_->Unref(); }
75 
MakeIterator(const string & prefix) const76     std::unique_ptr<IteratorBase> MakeIterator(
77         const string& prefix) const override {
78       return std::unique_ptr<IteratorBase>(
79           new Iterator({this, strings::StrCat(prefix, "::Map")}));
80     }
81 
output_dtypes() const82     const DataTypeVector& output_dtypes() const override {
83       return output_types_;
84     }
output_shapes() const85     const std::vector<PartialTensorShape>& output_shapes() const override {
86       return output_shapes_;
87     }
88 
DebugString()89     string DebugString() override { return "MapDatasetOp::Dataset"; }
90 
91    protected:
AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const92     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
93                               Node** output) const override {
94       TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
95       Node* input_graph_node = nullptr;
96       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
97 
98       DataTypeVector other_arguments_types;
99       other_arguments_types.reserve(captured_func_->captured_inputs().size());
100       std::vector<Node*> other_arguments;
101       other_arguments.reserve(captured_func_->captured_inputs().size());
102       for (const Tensor& t : captured_func_->captured_inputs()) {
103         Node* node;
104         TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
105         other_arguments.emplace_back(node);
106         other_arguments_types.emplace_back(t.dtype());
107       }
108       AttrValue f;
109       b->BuildAttrValue(func_, &f);
110       AttrValue other_arguments_types_attr;
111       b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
112 
113       TF_RETURN_IF_ERROR(b->AddDataset(
114           this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
115           {std::make_pair(1, other_arguments)},         // Tensor list inputs.
116           {std::make_pair("f", f),
117            std::make_pair("Targuments", other_arguments_types_attr)},  // Attrs
118           output));
119       return Status::OK();
120     }
121 
122    private:
123     class Iterator : public DatasetIterator<Dataset> {
124      public:
Iterator(const Params & params)125       explicit Iterator(const Params& params)
126           : DatasetIterator<Dataset>(params),
127             input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
128 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)129       Status GetNextInternal(IteratorContext* ctx,
130                              std::vector<Tensor>* out_tensors,
131                              bool* end_of_sequence) override {
132         // NOTE(mrry): This method is thread-safe as long as
133         // `input_impl_` and `f` are thread-safe. However, if multiple
134         // threads enter this method, outputs may be observed in a
135         // non-deterministic order.
136 
137         std::vector<Tensor> args;
138         TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
139         if (*end_of_sequence) {
140           return Status::OK();
141         }
142 
143         // TODO(mrry): Avoid blocking a threadpool thread. We will need to
144         // stack-rip the iterators and use async kernels.
145         Status s =
146             dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
147         if (errors::IsOutOfRange(s)) {
148           // `f` may deliberately raise `errors::OutOfRange` to indicate
149           // that we should terminate the iteration early.
150           *end_of_sequence = true;
151           return Status::OK();
152         } else {
153           return s;
154         }
155       }
156 
157      protected:
SaveInternal(IteratorStateWriter * writer)158       Status SaveInternal(IteratorStateWriter* writer) override {
159         TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
160         return Status::OK();
161       }
162 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)163       Status RestoreInternal(IteratorContext* ctx,
164                              IteratorStateReader* reader) override {
165         TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
166         return Status::OK();
167       }
168 
169      private:
170       const std::unique_ptr<IteratorBase> input_impl_;
171     };
172 
173     const DatasetBase* const input_;
174     const NameAttrList func_;
175     const std::unique_ptr<CapturedFunction> captured_func_;
176     const DataTypeVector output_types_;
177     const std::vector<PartialTensorShape> output_shapes_;
178   };
179 
180   const int graph_def_version_;
181   DataTypeVector output_types_;
182   std::vector<PartialTensorShape> output_shapes_;
183   NameAttrList func_;
184 };
185 
186 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
187 
188 }  // namespace
189 
190 }  // namespace tensorflow
191