/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/range_dataset_op.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const RangeDatasetOp::kDatasetType;
/* static */ constexpr const char* const RangeDatasetOp::kStart;
/* static */ constexpr const char* const RangeDatasetOp::kStop;
/* static */ constexpr const char* const RangeDatasetOp::kStep;
/* static */ constexpr const char* const RangeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const RangeDatasetOp::kOutputShapes;

namespace {
constexpr char kNext[] = "next";
constexpr char kHasSplitProvider[] = "has_split_provider";
constexpr char kSlash[] = "/";
constexpr char kSplitProvider[] = "split_provider";

// Class which produces the elements of `range(start, stop, step)`. Threadsafe.
class RangeCounter {
 public:
  RangeCounter(int64 start, int64 stop, int64 step)
      : start_(start), stop_(stop), step_(step), next_(start) {}

  // Returns the next value for the counter. Sets `*end_of_counter` to indicate
  // whether the end of the counter was reached.
  int64 GetNext(bool* end_of_counter) {
    mutex_lock l(mu_);
    if ((step_ > 0 && next_ >= stop_) || (step_ < 0 && next_ <= stop_)) {
      *end_of_counter = true;
      return -1;
    }
    *end_of_counter = false;
    int64 result = next_;
    next_ += step_;
    return result;
  }

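  // Returns the value that the next call to GetNext() would produce, without
  // advancing the counter.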
  int64 Peek() const {
    mutex_lock l(mu_);
    return next_;
  }

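  // Resets the counter so that iteration restarts from `start`.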
  void Reset() {
    mutex_lock l(mu_);
    next_ = start_;
  }

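  // Overwrites the next value to produce. Used when restoring the counter from
  // checkpointed iterator state.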
  void SetNext(int64 value) {
    mutex_lock l(mu_);
    next_ = value;
  }

 private:
  const int64 start_;
  const int64 stop_;
  const int64 step_;
  mutable mutex mu_;
  int64 next_ TF_GUARDED_BY(mu_);
};
}  // namespace

// Split provider where splits are individual outputs from RangeDataset.
// For example, the "splits" of range(0, 10, 2) will be {0, 2, 4, 6, 8}.
// The split tensors are scalars of type DT_INT64.
class RangeDatasetOp::RangeSplitProvider : public SplitProvider {
 public:
  RangeSplitProvider(int64 start, int64 stop, int64 step)
      : counter_(start, stop, step) {}

  Status GetNext(Tensor* split, bool* end_of_splits) override {
    int64 next = counter_.GetNext(end_of_splits);
    if (*end_of_splits) {
      return Status::OK();
    }
    *split = Tensor(DT_INT64, TensorShape{});
    split->scalar<int64>()() = next;
    return Status::OK();
  }

  Status Reset() override {
    counter_.Reset();
    return Status::OK();
  }

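  // Writes the next value to produce under `key_name_fn(kNext)` so that a
  // restored provider can resume from the same position.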
  Status Save(std::function<std::string(std::string)> key_name_fn,
              IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(key_name_fn(kNext), counter_.Peek()));
    return Status::OK();
  }

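  // Reads back the value written by Save() and repositions the counter.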
  Status Restore(std::function<std::string(std::string)> key_name_fn,
                 IteratorStateReader* reader) override {
    int64 next;
    TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next));
    counter_.SetNext(next);
    return Status::OK();
  }

 private:
  RangeCounter counter_;
};

class RangeDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step,
          DataTypeVector output_dtypes)
      : DatasetBase(DatasetContext(ctx)),
        start_(start),
        stop_(stop),
        step_(step),
        output_dtypes_(output_dtypes) {}

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    static std::vector<PartialTensorShape>* shapes =
        new std::vector<PartialTensorShape>({PartialTensorShape({})});
    return *shapes;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.set_args(start_, stop_, step_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

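  // Computes the number of elements in closed form (essentially
  // ceil((stop - start) / step)) instead of iterating the range.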
  int64 Cardinality() const override {
    if (step_ > 0) {
      return std::max(int64{0}, (stop_ - start_ - 1) / step_ + 1);
    } else {
      return std::max(int64{0}, (start_ - stop_ - 1) / -step_ + 1);
    }
  }

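  // Exposes the range as a sequence of splits, one split per element, so that
  // iteration can be driven by an external SplitProvider.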
  Status MakeSplitProvider(
      std::unique_ptr<SplitProvider>* split_provider) const override {
    *split_provider =
        absl::make_unique<RangeSplitProvider>(start_, stop_, step_);
    return Status::OK();
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    return Status::OK();
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* start = nullptr;
    Node* stop = nullptr;
    Node* step = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(start_, &start));
    TF_RETURN_IF_ERROR(b->AddScalar(stop_, &stop));
    TF_RETURN_IF_ERROR(b->AddScalar(step_, &step));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {start, stop, step}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

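    // When the iterator is driven by a split provider, values come from its
    // splits; otherwise a local RangeCounter produces them.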
    Status Initialize(IteratorContext* ctx) override {
      split_provider_ = ctx->split_provider();
      if (!split_provider_) {
        counter_ = absl::make_unique<RangeCounter>(
            dataset()->start_, dataset()->stop_, dataset()->step_);
      }
      return Status::OK();
    }

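    // Produces the next element, either from the split provider or from the
    // local counter, and emits it as a scalar of the requested output dtype.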
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      int64 value;
      if (split_provider_ != nullptr) {
        Tensor split;
        TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
        if (*end_of_sequence) {
          return Status::OK();
        }
        value = split.scalar<int64>()();
      } else {
        value = counter_->GetNext(end_of_sequence);
        if (*end_of_sequence) {
          return Status::OK();
        }
      }
      out_tensors->reserve(1);
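      // HANDLE_TYPE expands to one switch case per numeric dtype, casting the
      // int64 value to the requested output type.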
      switch (dataset()->output_dtypes()[0]) {
#define HANDLE_TYPE(type) \
  case DataTypeToEnum<type>::value: { \
    out_tensors->emplace_back(static_cast<type>(value)); \
    break; \
  }
        TF_CALL_NUMBER_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
        default:
          return errors::InvalidArgument(
              "Unsupported data type: ",
              DataTypeString(dataset()->output_dtypes()[0]));
      }
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeSourceNode(std::move(args));
    }

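    // Checkpoints either the split provider's position or the local counter's
    // next value, recording which of the two was in use.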
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      if (split_provider_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kHasSplitProvider), true));
        TF_RETURN_IF_ERROR(split_provider_->Save(
            [this](const std::string& key) {
              return SplitProviderKeyNameFn(key);
            },
            writer));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kNext), counter_->Peek()));
      }
      return Status::OK();
    }

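    // Mirrors SaveInternal(): restores the split provider if the checkpoint
    // contains the marker, otherwise repositions the local counter.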
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      if (reader->Contains(full_name(kHasSplitProvider))) {
        TF_RETURN_IF_ERROR(split_provider_->Restore(
            [this](const std::string& key) {
              return SplitProviderKeyNameFn(key);
            },
            reader));
      } else {
        int64 next;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next));
        counter_->SetNext(next);
      }
      return Status::OK();
    }

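    // Scopes split-provider checkpoint keys under this iterator's full name,
    // prefixed with "split_provider/".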
    std::string SplitProviderKeyNameFn(const std::string& key) {
      return full_name(absl::StrCat(kSplitProvider, kSlash, key));
    }

   private:
    std::unique_ptr<RangeCounter> counter_;
    std::shared_ptr<SplitProvider> split_provider_;
  };

  const int64 start_;
  const int64 stop_;
  const int64 step_;
  const DataTypeVector output_dtypes_;
};

RangeDatasetOp::RangeDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
}

void RangeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
  int64 start;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStart, &start));

  int64 stop;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStop, &stop));

  int64 step;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStep, &step));
  OP_REQUIRES(ctx, step != 0,
              errors::InvalidArgument("step must be a non-zero integer."));

  *output = new Dataset(ctx, start, stop, step, output_types_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
                        RangeDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow