1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/map_dataset_op.h"
16
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/kernels/data/dataset_utils.h"
22 #include "tensorflow/core/kernels/data/name_utils.h"
23 #include "tensorflow/core/lib/random/random.h"
24
25 namespace tensorflow {
26 namespace data {
27
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30
31 /* static */ constexpr const char* const MapDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const MapDatasetOp::kInputDataset;
33 /* static */ constexpr const char* const MapDatasetOp::kOtherArguments;
34 /* static */ constexpr const char* const MapDatasetOp::kFunc;
35 /* static */ constexpr const char* const MapDatasetOp::kTarguments;
36 /* static */ constexpr const char* const MapDatasetOp::kOutputTypes;
37 /* static */ constexpr const char* const MapDatasetOp::kOutputShapes;
38 /* static */ constexpr const char* const MapDatasetOp::kUseInterOpParallelism;
39 /* static */ constexpr const char* const MapDatasetOp::kPreserveCardinality;
40
41 class MapDatasetOp::Dataset : public DatasetBase {
42 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,bool preserve_cardinality)43 Dataset(OpKernelContext* ctx, const DatasetBase* input,
44 std::unique_ptr<CapturedFunction> captured_func,
45 const DataTypeVector& output_types,
46 const std::vector<PartialTensorShape>& output_shapes,
47 bool preserve_cardinality)
48 : DatasetBase(DatasetContext(ctx)),
49 input_(input),
50 preserve_cardinality_(preserve_cardinality),
51 captured_func_(std::move(captured_func)),
52 output_types_(output_types),
53 output_shapes_(output_shapes) {
54 input_->Ref();
55 }
56
~Dataset()57 ~Dataset() override { input_->Unref(); }
58
MakeIteratorInternal(const string & prefix) const59 std::unique_ptr<IteratorBase> MakeIteratorInternal(
60 const string& prefix) const override {
61 return absl::make_unique<Iterator>(Iterator::Params{
62 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
63 }
64
output_dtypes() const65 const DataTypeVector& output_dtypes() const override { return output_types_; }
66
output_shapes() const67 const std::vector<PartialTensorShape>& output_shapes() const override {
68 return output_shapes_;
69 }
70
DebugString() const71 string DebugString() const override {
72 return name_utils::DatasetDebugString(kDatasetType);
73 }
74
Cardinality() const75 int64 Cardinality() const override {
76 if (preserve_cardinality_) {
77 return input_->Cardinality();
78 } else {
79 return kUnknownCardinality;
80 }
81 }
82
InputDatasets(std::vector<const DatasetBase * > * inputs) const83 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
84 inputs->push_back(input_);
85 return Status::OK();
86 }
87
CheckExternalState() const88 Status CheckExternalState() const override {
89 TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
90 return input_->CheckExternalState();
91 }
92
93 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const94 Status AsGraphDefInternal(SerializationContext* ctx,
95 DatasetGraphDefBuilder* b,
96 Node** output) const override {
97 Node* input_graph_node = nullptr;
98 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
99
100 std::vector<Node*> other_arguments;
101 DataTypeVector other_arguments_types;
102 TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
103 &other_arguments_types));
104
105 // Attr: f
106 AttrValue f_attr;
107 b->BuildAttrValue(captured_func_->func(), &f_attr);
108
109 // Attr: Targuments
110 AttrValue other_arguments_types_attr;
111 b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
112
113 // Attr: use_inter_op_parallelism
114 AttrValue use_inter_op_parallelism_attr;
115 b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
116 &use_inter_op_parallelism_attr);
117
118 // Attr: preserve_cardinality
119 AttrValue preserve_cardinality_attr;
120 b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
121
122 TF_RETURN_IF_ERROR(b->AddDataset(
123 this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs.
124 {std::make_pair(1, other_arguments)}, // Tensor list inputs.
125 {std::make_pair(kFunc, f_attr),
126 std::make_pair(kTarguments, other_arguments_types_attr),
127 std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
128 std::make_pair(kPreserveCardinality,
129 preserve_cardinality_attr)}, // Attrs
130 output));
131 return Status::OK();
132 }
133
134 private:
135 class Iterator : public DatasetIterator<Dataset> {
136 public:
Iterator(const Params & params)137 explicit Iterator(const Params& params)
138 : DatasetIterator<Dataset>(params) {}
139
Initialize(IteratorContext * ctx)140 Status Initialize(IteratorContext* ctx) override {
141 TF_RETURN_IF_ERROR(
142 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
143 return dataset()->captured_func_->Instantiate(
144 ctx, &instantiated_captured_func_);
145 }
146
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)147 Status GetNextInternal(IteratorContext* ctx,
148 std::vector<Tensor>* out_tensors,
149 bool* end_of_sequence) override {
150 // NOTE(mrry): This method is thread-safe as long as
151 // `input_impl_` and `f` are thread-safe. However, if multiple
152 // threads enter this method, outputs may be observed in a
153 // non-deterministic order.
154
155 std::vector<Tensor> args;
156 TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
157 if (*end_of_sequence) {
158 return Status::OK();
159 }
160
161 Status s = instantiated_captured_func_->Run(ctx, std::move(args),
162 out_tensors, model_node());
163 if (errors::IsOutOfRange(s)) {
164 if (dataset()->preserve_cardinality_) {
165 // To guarantee that the transformation preserves the cardinality of
166 // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
167 // former may be interpreted by a caller as the end of sequence.
168 return errors::InvalidArgument(
169 "Function invocation produced OutOfRangeError: ",
170 s.error_message());
171 } else {
172 // `f` may deliberately raise `errors::OutOfRange` to indicate
173 // that we should terminate the iteration early.
174 *end_of_sequence = true;
175 return Status::OK();
176 }
177 } else {
178 return s;
179 }
180 }
181
182 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const183 std::shared_ptr<model::Node> CreateNode(
184 IteratorContext* ctx, model::Node::Args args) const override {
185 return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
186 }
187
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)188 Status SaveInternal(SerializationContext* ctx,
189 IteratorStateWriter* writer) override {
190 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
191 dataset()->captured_func_->CheckExternalState()));
192 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
193 return Status::OK();
194 }
195
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)196 Status RestoreInternal(IteratorContext* ctx,
197 IteratorStateReader* reader) override {
198 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
199 return Status::OK();
200 }
201
202 private:
203 std::unique_ptr<IteratorBase> input_impl_;
204 std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
205 };
206
207 const DatasetBase* const input_;
208 const bool preserve_cardinality_;
209 const std::unique_ptr<CapturedFunction> captured_func_;
210 const DataTypeVector output_types_;
211 const std::vector<PartialTensorShape> output_shapes_;
212 };
213
MapDatasetOp(OpKernelConstruction * ctx)214 MapDatasetOp::MapDatasetOp(OpKernelConstruction* ctx)
215 : UnaryDatasetOpKernel(ctx) {
216 FunctionMetadata::Params params;
217 OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
218 ¶ms.use_inter_op_parallelism));
219 OP_REQUIRES_OK(ctx,
220 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
221 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
222 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
223 OP_REQUIRES_OK(ctx,
224 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
225 }
226
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)227 void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
228 DatasetBase** output) {
229 std::unique_ptr<CapturedFunction> captured_func;
230 OP_REQUIRES_OK(ctx,
231 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
232 &captured_func));
233
234 *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
235 output_shapes_, preserve_cardinality_);
236 }
237
238 namespace {
239
240 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
241 REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
242 .Device(DEVICE_GPU)
243 .HostMemory("input_dataset")
244 .HostMemory("handle"),
245 MapDatasetOp);
246 REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
247
248 } // namespace
249 } // namespace data
250 } // namespace tensorflow
251