• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Experimental shuffle ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import nest
22from tensorflow.python.data.util import sparse
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import random_seed
27from tensorflow.python.ops import gen_dataset_ops
28
29
class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
  """A `Dataset` that fuses `shuffle` and `repeat`."""

  def __init__(self,
               input_dataset,
               buffer_size,
               count=None,
               seed=None):
    """Creates a dataset that shuffles and repeats `input_dataset`.

    Args:
      input_dataset: The `Dataset` whose elements are shuffled and repeated.
      buffer_size: A `tf.int64` scalar, the number of elements buffered for
        shuffling.
      count: (Optional.) A `tf.int64` scalar, the number of repetitions.
        `None` is mapped to -1, which the kernel treats as "repeat forever".
      seed: (Optional.) A `tf.int64` scalar used to seed the shuffle order.
    """
    super(_ShuffleAndRepeatDataset, self).__init__()
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    # -1 signals the kernel to repeat the input indefinitely.
    self._count = (
        constant_op.constant(-1, dtype=dtypes.int64, name="count")
        if count is None
        else ops.convert_to_tensor(count, dtype=dtypes.int64, name="count"))

    # Derive the op-level seed pair; a missing component defaults to 0.
    seed, seed2 = random_seed.get_seed(seed)
    self._seed = (
        constant_op.constant(0, dtype=dtypes.int64, name="seed")
        if seed is None
        else ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed"))
    self._seed2 = (
        constant_op.constant(0, dtype=dtypes.int64, name="seed2")
        if seed2 is None
        else ops.convert_to_tensor(seed2, dtype=dtypes.int64, name="seed2"))

  def _as_variant_tensor(self):
    # Flatten the (possibly nested) structure into the dense types/shapes
    # expected by the kernel, then build the fused dataset op.
    flat_types = nest.flatten(
        sparse.as_dense_types(self.output_types, self.output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self.output_shapes, self.output_classes))
    # pylint: disable=protected-access
    return gen_dataset_ops.shuffle_and_repeat_dataset(
        self._input_dataset._as_variant_tensor(),
        buffer_size=self._buffer_size,
        count=self._count,
        seed=self._seed,
        seed2=self._seed2,
        output_types=flat_types,
        output_shapes=flat_shapes)
    # pylint: enable=protected-access

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types
86
87
def shuffle_and_repeat(buffer_size, count=None, seed=None):
  """Shuffles and repeats a Dataset returning a new permutation for each epoch.

  `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))`

  is equivalent to

  `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`

  The difference is that the latter dataset is not serializable. So,
  if you need to checkpoint an input pipeline with reshuffling you must use
  this implementation.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
      maximum number elements that will be buffered when prefetching.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset be repeated
      indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      @{tf.set_random_seed} for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    # Close over the user-supplied arguments; the fused dataset is only
    # constructed when the transformation is applied.
    return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)

  return _apply_fn
121