1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/common_runtime/function.h" 16 #include "tensorflow/core/framework/partial_tensor_shape.h" 17 #include "tensorflow/core/framework/tensor.h" 18 #include "tensorflow/core/kernels/data/captured_function.h" 19 #include "tensorflow/core/kernels/data/dataset.h" 20 #include "tensorflow/core/lib/random/random.h" 21 22 namespace tensorflow { 23 24 namespace { 25 26 // See documentation in ../ops/dataset_ops.cc for a high-level 27 // description of the following op. 28 29 class MapDatasetOp : public UnaryDatasetOpKernel { 30 public: MapDatasetOp(OpKernelConstruction * ctx)31 explicit MapDatasetOp(OpKernelConstruction* ctx) 32 : UnaryDatasetOpKernel(ctx), 33 graph_def_version_(ctx->graph_def_version()) { 34 OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); 35 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 36 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 37 } 38 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)39 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 40 DatasetBase** output) override { 41 OpInputList inputs; 42 OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); 43 std::vector<Tensor> other_arguments; 44 other_arguments.reserve(inputs.size()); 45 for (const Tensor& t : inputs) { 46 other_arguments.push_back(t); 47 } 48 49 std::unique_ptr<CapturedFunction> captured_func; 50 OP_REQUIRES_OK(ctx, CapturedFunction::Create( 51 func_, std::move(other_arguments), &captured_func)); 52 53 *output = new Dataset(ctx, input, func_, std::move(captured_func), 54 output_types_, output_shapes_); 55 } 56 57 private: 58 class Dataset : public GraphDatasetBase { 59 public: Dataset(OpKernelContext * ctx,const DatasetBase * input,const NameAttrList & func,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)60 Dataset(OpKernelContext* ctx, const DatasetBase* input, 61 const NameAttrList& func, 62 std::unique_ptr<CapturedFunction> captured_func, 63 const DataTypeVector& output_types, 64 const std::vector<PartialTensorShape>& output_shapes) 65 : GraphDatasetBase(ctx), 66 input_(input), 67 func_(func), 68 captured_func_(std::move(captured_func)), 69 output_types_(output_types), 70 output_shapes_(output_shapes) { 71 input_->Ref(); 72 } 73 ~Dataset()74 ~Dataset() override { input_->Unref(); } 75 MakeIterator(const string & prefix) const76 std::unique_ptr<IteratorBase> MakeIterator( 77 const string& prefix) const override { 78 return std::unique_ptr<IteratorBase>( 79 new Iterator({this, strings::StrCat(prefix, "::Map")})); 80 } 81 output_dtypes() const82 const DataTypeVector& output_dtypes() const override { 83 return output_types_; 84 } output_shapes() const85 const std::vector<PartialTensorShape>& output_shapes() const override { 86 return output_shapes_; 87 } 88 DebugString()89 string DebugString() override { return "MapDatasetOp::Dataset"; } 90 91 protected: AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const92 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 93 Node** output) const override { 94 TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); 95 Node* input_graph_node = nullptr; 96 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); 97 98 DataTypeVector other_arguments_types; 99 other_arguments_types.reserve(captured_func_->captured_inputs().size()); 100 std::vector<Node*> other_arguments; 101 other_arguments.reserve(captured_func_->captured_inputs().size()); 102 for (const Tensor& t : captured_func_->captured_inputs()) { 103 Node* node; 104 TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); 105 other_arguments.emplace_back(node); 106 other_arguments_types.emplace_back(t.dtype()); 107 } 108 AttrValue f; 109 b->BuildAttrValue(func_, &f); 110 AttrValue other_arguments_types_attr; 111 b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); 112 113 TF_RETURN_IF_ERROR(b->AddDataset( 114 this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs. 115 {std::make_pair(1, other_arguments)}, // Tensor list inputs. 116 {std::make_pair("f", f), 117 std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs 118 output)); 119 return Status::OK(); 120 } 121 122 private: 123 class Iterator : public DatasetIterator<Dataset> { 124 public: Iterator(const Params & params)125 explicit Iterator(const Params& params) 126 : DatasetIterator<Dataset>(params), 127 input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} 128 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)129 Status GetNextInternal(IteratorContext* ctx, 130 std::vector<Tensor>* out_tensors, 131 bool* end_of_sequence) override { 132 // NOTE(mrry): This method is thread-safe as long as 133 // `input_impl_` and `f` are thread-safe. However, if multiple 134 // threads enter this method, outputs may be observed in a 135 // non-deterministic order. 136 137 std::vector<Tensor> args; 138 TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence)); 139 if (*end_of_sequence) { 140 return Status::OK(); 141 } 142 143 // TODO(mrry): Avoid blocking a threadpool thread. We will need to 144 // stack-rip the iterators and use async kernels. 145 Status s = 146 dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); 147 if (errors::IsOutOfRange(s)) { 148 // `f` may deliberately raise `errors::OutOfRange` to indicate 149 // that we should terminate the iteration early. 150 *end_of_sequence = true; 151 return Status::OK(); 152 } else { 153 return s; 154 } 155 } 156 157 protected: SaveInternal(IteratorStateWriter * writer)158 Status SaveInternal(IteratorStateWriter* writer) override { 159 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 160 return Status::OK(); 161 } 162 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)163 Status RestoreInternal(IteratorContext* ctx, 164 IteratorStateReader* reader) override { 165 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 166 return Status::OK(); 167 } 168 169 private: 170 const std::unique_ptr<IteratorBase> input_impl_; 171 }; 172 173 const DatasetBase* const input_; 174 const NameAttrList func_; 175 const std::unique_ptr<CapturedFunction> captured_func_; 176 const DataTypeVector output_types_; 177 const std::vector<PartialTensorShape> output_shapes_; 178 }; 179 180 const int graph_def_version_; 181 DataTypeVector output_types_; 182 std::vector<PartialTensorShape> output_shapes_; 183 NameAttrList func_; 184 }; 185 186 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); 187 188 } // namespace 189 190 } // namespace tensorflow 191