• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Counter Dataset."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.data.python.ops import scan_ops
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25
26
def Counter(start=0, step=1, dtype=dtypes.int64):
  """Creates an unbounded `Dataset` that counts from `start` by `step`.

  The first element is `start`; each following element is the previous one
  plus `step`. The driving dataset is repeated indefinitely, so the result
  never ends on its own -- bound it downstream (e.g. with `take()`).

  For example:

  ```python
  Counter() == [0, 1, 2, ...)
  Counter(2) == [2, 3, ...)
  Counter(2, 5) == [2, 7, 12, ...)
  Counter(0, -1) == [0, -1, -2, ...)
  Counter(10, -1) == [10, 9, ...)
  ```

  Args:
    start: (Optional.) The first value produced; converted to a tensor of
      type `dtype`. Defaults to 0.
    step: (Optional.) The amount added between consecutive values; converted
      to a tensor of type `dtype`. Defaults to 1.
    dtype: (Optional.) The data type of the produced values. Defaults to
      `dtypes.int64`.

  Returns:
    A `Dataset` whose elements are `dtype` values (scalars when `start` and
    `step` are scalars).
  """
  with ops.name_scope("counter"):
    # Convert both operands first so the scan state and its increment share
    # the same `dtype`.
    start_tensor = ops.convert_to_tensor(start, dtype=dtype, name="start")
    step_tensor = ops.convert_to_tensor(step, dtype=dtype, name="step")

    def _emit_and_advance(current, unused_element):
      # Returns (next_state, output): the current count is emitted, and
      # `current + step` is carried forward. The ticker element is ignored.
      return current + step_tensor, current

    # An endless stream of placeholder elements whose only job is to drive
    # the scan one step per element.
    ticker = dataset_ops.Dataset.from_tensors(0).repeat(None)
    return ticker.apply(scan_ops.scan(start_tensor, _emit_and_advance))
53