• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/map_dataset_op.h"
16 
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/kernels/data/dataset_utils.h"
22 #include "tensorflow/core/kernels/data/name_utils.h"
23 #include "tensorflow/core/lib/random/random.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30 
31 /* static */ constexpr const char* const MapDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const MapDatasetOp::kInputDataset;
33 /* static */ constexpr const char* const MapDatasetOp::kOtherArguments;
34 /* static */ constexpr const char* const MapDatasetOp::kFunc;
35 /* static */ constexpr const char* const MapDatasetOp::kTarguments;
36 /* static */ constexpr const char* const MapDatasetOp::kOutputTypes;
37 /* static */ constexpr const char* const MapDatasetOp::kOutputShapes;
38 /* static */ constexpr const char* const MapDatasetOp::kUseInterOpParallelism;
39 /* static */ constexpr const char* const MapDatasetOp::kPreserveCardinality;
40 
41 class MapDatasetOp::Dataset : public DatasetBase {
42  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,bool preserve_cardinality)43   Dataset(OpKernelContext* ctx, const DatasetBase* input,
44           std::unique_ptr<CapturedFunction> captured_func,
45           const DataTypeVector& output_types,
46           const std::vector<PartialTensorShape>& output_shapes,
47           bool preserve_cardinality)
48       : DatasetBase(DatasetContext(ctx)),
49         input_(input),
50         preserve_cardinality_(preserve_cardinality),
51         captured_func_(std::move(captured_func)),
52         output_types_(output_types),
53         output_shapes_(output_shapes) {
54     input_->Ref();
55   }
56 
~Dataset()57   ~Dataset() override { input_->Unref(); }
58 
MakeIteratorInternal(const string & prefix) const59   std::unique_ptr<IteratorBase> MakeIteratorInternal(
60       const string& prefix) const override {
61     return absl::make_unique<Iterator>(Iterator::Params{
62         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
63   }
64 
output_dtypes() const65   const DataTypeVector& output_dtypes() const override { return output_types_; }
66 
output_shapes() const67   const std::vector<PartialTensorShape>& output_shapes() const override {
68     return output_shapes_;
69   }
70 
DebugString() const71   string DebugString() const override {
72     return name_utils::DatasetDebugString(kDatasetType);
73   }
74 
Cardinality() const75   int64 Cardinality() const override {
76     if (preserve_cardinality_) {
77       return input_->Cardinality();
78     } else {
79       return kUnknownCardinality;
80     }
81   }
82 
InputDatasets(std::vector<const DatasetBase * > * inputs) const83   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
84     inputs->push_back(input_);
85     return Status::OK();
86   }
87 
CheckExternalState() const88   Status CheckExternalState() const override {
89     TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
90     return input_->CheckExternalState();
91   }
92 
93  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const94   Status AsGraphDefInternal(SerializationContext* ctx,
95                             DatasetGraphDefBuilder* b,
96                             Node** output) const override {
97     Node* input_graph_node = nullptr;
98     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
99 
100     std::vector<Node*> other_arguments;
101     DataTypeVector other_arguments_types;
102     TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
103                                                   &other_arguments_types));
104 
105     // Attr: f
106     AttrValue f_attr;
107     b->BuildAttrValue(captured_func_->func(), &f_attr);
108 
109     // Attr: Targuments
110     AttrValue other_arguments_types_attr;
111     b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
112 
113     // Attr: use_inter_op_parallelism
114     AttrValue use_inter_op_parallelism_attr;
115     b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
116                       &use_inter_op_parallelism_attr);
117 
118     // Attr: preserve_cardinality
119     AttrValue preserve_cardinality_attr;
120     b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
121 
122     TF_RETURN_IF_ERROR(b->AddDataset(
123         this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
124         {std::make_pair(1, other_arguments)},         // Tensor list inputs.
125         {std::make_pair(kFunc, f_attr),
126          std::make_pair(kTarguments, other_arguments_types_attr),
127          std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
128          std::make_pair(kPreserveCardinality,
129                         preserve_cardinality_attr)},  // Attrs
130         output));
131     return Status::OK();
132   }
133 
134  private:
135   class Iterator : public DatasetIterator<Dataset> {
136    public:
Iterator(const Params & params)137     explicit Iterator(const Params& params)
138         : DatasetIterator<Dataset>(params) {}
139 
Initialize(IteratorContext * ctx)140     Status Initialize(IteratorContext* ctx) override {
141       TF_RETURN_IF_ERROR(
142           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
143       return dataset()->captured_func_->Instantiate(
144           ctx, &instantiated_captured_func_);
145     }
146 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)147     Status GetNextInternal(IteratorContext* ctx,
148                            std::vector<Tensor>* out_tensors,
149                            bool* end_of_sequence) override {
150       // NOTE(mrry): This method is thread-safe as long as
151       // `input_impl_` and `f` are thread-safe. However, if multiple
152       // threads enter this method, outputs may be observed in a
153       // non-deterministic order.
154 
155       std::vector<Tensor> args;
156       TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
157       if (*end_of_sequence) {
158         return Status::OK();
159       }
160 
161       Status s = instantiated_captured_func_->Run(ctx, std::move(args),
162                                                   out_tensors, model_node());
163       if (errors::IsOutOfRange(s)) {
164         if (dataset()->preserve_cardinality_) {
165           // To guarantee that the transformation preserves the cardinality of
166           // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
167           // former may be interpreted by a caller as the end of sequence.
168           return errors::InvalidArgument(
169               "Function invocation produced OutOfRangeError: ",
170               s.error_message());
171         } else {
172           // `f` may deliberately raise `errors::OutOfRange` to indicate
173           // that we should terminate the iteration early.
174           *end_of_sequence = true;
175           return Status::OK();
176         }
177       } else {
178         return s;
179       }
180     }
181 
182    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const183     std::shared_ptr<model::Node> CreateNode(
184         IteratorContext* ctx, model::Node::Args args) const override {
185       return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
186     }
187 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)188     Status SaveInternal(SerializationContext* ctx,
189                         IteratorStateWriter* writer) override {
190       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
191           dataset()->captured_func_->CheckExternalState()));
192       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
193       return Status::OK();
194     }
195 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)196     Status RestoreInternal(IteratorContext* ctx,
197                            IteratorStateReader* reader) override {
198       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
199       return Status::OK();
200     }
201 
202    private:
203     std::unique_ptr<IteratorBase> input_impl_;
204     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
205   };
206 
207   const DatasetBase* const input_;
208   const bool preserve_cardinality_;
209   const std::unique_ptr<CapturedFunction> captured_func_;
210   const DataTypeVector output_types_;
211   const std::vector<PartialTensorShape> output_shapes_;
212 };
213 
MapDatasetOp(OpKernelConstruction * ctx)214 MapDatasetOp::MapDatasetOp(OpKernelConstruction* ctx)
215     : UnaryDatasetOpKernel(ctx) {
216   FunctionMetadata::Params params;
217   OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
218                                    &params.use_inter_op_parallelism));
219   OP_REQUIRES_OK(ctx,
220                  FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
221   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
222   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
223   OP_REQUIRES_OK(ctx,
224                  ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
225 }
226 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)227 void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
228                                DatasetBase** output) {
229   std::unique_ptr<CapturedFunction> captured_func;
230   OP_REQUIRES_OK(ctx,
231                  CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
232                                           &captured_func));
233 
234   *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
235                         output_shapes_, preserve_cardinality_);
236 }
237 
238 namespace {
239 
240 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
241 REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
242                             .Device(DEVICE_GPU)
243                             .HostMemory("input_dataset")
244                             .HostMemory("handle"),
245                         MapDatasetOp);
246 REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
247 
248 }  // namespace
249 }  // namespace data
250 }  // namespace tensorflow
251