1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/take_dataset_op.h"
16
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/kernels/data/name_utils.h"
20
21 namespace tensorflow {
22 namespace data {
23
// Out-of-line definitions of the static constexpr string members of
// TakeDatasetOp (their initializers live in the class declaration, presumably
// take_dataset_op.h). Before C++17 these definitions are required if the
// members are odr-used; under C++17 they are redundant but harmless.
/* static */ constexpr const char* const TakeDatasetOp::kDatasetType;
/* static */ constexpr const char* const TakeDatasetOp::kInputDataset;
/* static */ constexpr const char* const TakeDatasetOp::kCount;
/* static */ constexpr const char* const TakeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const TakeDatasetOp::kOutputShapes;
29
// Checkpoint keys written/read by FiniteIterator::SaveInternal and
// FiniteIterator::RestoreInternal.
constexpr char kCurIndex[] = "i";
constexpr char kInputImplEmpty[] = "input_impl_empty";
// Iterator prefix names selected in TakeDataset::MakeIteratorInternal.
constexpr char kEmptyTake[] = "EmptyTake";
constexpr char kFiniteTake[] = "FiniteTake";
34
TakeDataset(OpKernelContext * ctx,int64 count,const DatasetBase * input)35 TakeDataset::TakeDataset(OpKernelContext* ctx, int64 count,
36 const DatasetBase* input)
37 : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
38 input_->Ref();
39 }
40
TakeDataset(DatasetContext::Params params,int64 count,const DatasetBase * input)41 TakeDataset::TakeDataset(DatasetContext::Params params, int64 count,
42 const DatasetBase* input)
43 : DatasetBase(DatasetContext(std::move(params))),
44 count_(count),
45 input_(input) {
46 input_->Ref();
47 }
48
~TakeDataset()49 TakeDataset::~TakeDataset() { input_->Unref(); }
50
output_dtypes() const51 const DataTypeVector& TakeDataset::output_dtypes() const {
52 return input_->output_dtypes();
53 }
54
output_shapes() const55 const std::vector<PartialTensorShape>& TakeDataset::output_shapes() const {
56 return input_->output_shapes();
57 }
58
DebugString() const59 string TakeDataset::DebugString() const {
60 return name_utils::DatasetDebugString(TakeDatasetOp::kDatasetType);
61 }
62
Cardinality() const63 int64 TakeDataset::Cardinality() const {
64 int64 n = input_->Cardinality();
65 if (n == kUnknownCardinality) {
66 return kUnknownCardinality;
67 }
68 if (n == kInfiniteCardinality) {
69 return count_;
70 } else if (count_ == kInfiniteCardinality) {
71 return n;
72 }
73
74 return std::min(n, count_);
75 }
76
InputDatasets(std::vector<const DatasetBase * > * inputs) const77 Status TakeDataset::InputDatasets(
78 std::vector<const DatasetBase*>* inputs) const {
79 inputs->push_back(input_);
80 return Status::OK();
81 }
82
CheckExternalState() const83 Status TakeDataset::CheckExternalState() const {
84 return input_->CheckExternalState();
85 }
86
87 class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
88 public:
EmptyIterator(const Params & params)89 explicit EmptyIterator(const Params& params)
90 : DatasetIterator<TakeDataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)91 Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
92 bool* end_of_sequence) override {
93 *end_of_sequence = true;
94 return Status::OK();
95 }
96
97 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const98 std::shared_ptr<model::Node> CreateNode(
99 IteratorContext* ctx, model::Node::Args args) const override {
100 return model::MakeKnownRatioNode(std::move(args),
101 /*ratio=*/1);
102 }
103
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)104 Status SaveInternal(SerializationContext* ctx,
105 IteratorStateWriter* writer) override {
106 return Status::OK();
107 }
108
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)109 Status RestoreInternal(IteratorContext* ctx,
110 IteratorStateReader* reader) override {
111 return Status::OK();
112 }
113 };
114
115 class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
116 public:
FiniteIterator(const Params & params)117 explicit FiniteIterator(const Params& params)
118 : DatasetIterator<TakeDataset>(params), i_(0) {}
119
Initialize(IteratorContext * ctx)120 Status Initialize(IteratorContext* ctx) override {
121 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
122 }
123
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)124 Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
125 bool* end_of_sequence) override {
126 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
127 if (!input_impl_) {
128 *end_of_sequence = true;
129 return Status::OK();
130 }
131 while (dataset()->count_ < 0 || i_ < dataset()->count_) {
132 TF_RETURN_IF_ERROR(
133 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
134 if (!*end_of_sequence) {
135 ++i_;
136 return Status::OK();
137 }
138 break;
139 }
140 *end_of_sequence = true;
141 input_impl_.reset();
142 return Status::OK();
143 }
144
145 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const146 std::shared_ptr<model::Node> CreateNode(
147 IteratorContext* ctx, model::Node::Args args) const override {
148 return model::MakeKnownRatioNode(std::move(args),
149 /*ratio=*/1);
150 }
151
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)152 Status SaveInternal(SerializationContext* ctx,
153 IteratorStateWriter* writer) override {
154 mutex_lock l(mu_);
155 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
156 if (input_impl_) {
157 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
158 } else {
159 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
160 }
161 return Status::OK();
162 }
163
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)164 Status RestoreInternal(IteratorContext* ctx,
165 IteratorStateReader* reader) override {
166 mutex_lock l(mu_);
167 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
168 if (!reader->Contains(full_name(kInputImplEmpty))) {
169 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
170 } else {
171 input_impl_.reset();
172 }
173 return Status::OK();
174 }
175
176 private:
177 mutex mu_;
178 int64 i_ TF_GUARDED_BY(mu_);
179 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
180 };
181
182 // See documentation in ../../ops/dataset_ops.cc for a high-level
183 // description of the following op.
MakeIteratorInternal(const string & prefix) const184 std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
185 const string& prefix) const {
186 if (count_ == 0) {
187 return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
188 this, name_utils::IteratorPrefix(kEmptyTake, prefix)});
189 } else {
190 return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
191 this, name_utils::IteratorPrefix(kFiniteTake, prefix)});
192 }
193 }
194
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const195 Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
196 DatasetGraphDefBuilder* b,
197 Node** output) const {
198 Node* input_graph_node = nullptr;
199 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
200 Node* count = nullptr;
201 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
202 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
203 return Status::OK();
204 }
205
TakeDatasetOp(OpKernelConstruction * ctx)206 TakeDatasetOp::TakeDatasetOp(OpKernelConstruction* ctx)
207 : UnaryDatasetOpKernel(ctx) {}
208
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)209 void TakeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
210 DatasetBase** output) {
211 // Create a new TakeDatasetOp::Dataset, and return it as the output.
212 int64 count;
213 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
214 *output = new TakeDataset(ctx, count, input);
215 }
216
namespace {
// Registers TakeDatasetOp as the CPU kernel for the "TakeDataset" op.
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
}  // namespace
220 } // namespace data
221 } // namespace tensorflow
222