# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unique element dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops


def unique():
  """Creates a `Dataset` from another `Dataset`, discarding duplicates.

  Use this transformation to produce a dataset that contains one instance of
  each unique element in the input. For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])

  # Using `unique()` will drop the duplicate elements.
  dataset = dataset.apply(tf.contrib.data.unique())  # ==> { 1, 37, 2 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

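  # `unique()` follows the standard `Dataset.apply` pattern: rather than
  # transforming a dataset directly, it returns a function that takes the
  # input dataset and wraps it in a `UniqueDataset`.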
  def _apply_fn(dataset):
    return UniqueDataset(dataset)

  return _apply_fn


class UniqueDataset(dataset_ops.Dataset):
52  """A `Dataset` contains the unique elements from its input."""

  def __init__(self, input_dataset):
    """See `unique()` for details."""
    super(UniqueDataset, self).__init__()
    self._input_dataset = input_dataset
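    # The unique transformation only supports datasets whose elements are a
    # single `tf.int32`, `tf.int64`, or `tf.string` component, so reject other
    # structures and dtypes up front with a clear error.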
    if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
                                          dtypes.string):
      raise TypeError(
          "`tf.contrib.data.unique()` only supports inputs with a single "
          "`tf.int32`, `tf.int64`, or `tf.string` component.")

  def _as_variant_tensor(self):
    return gen_dataset_ops.unique_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types
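
# Example (illustrative, not part of the original module): a minimal usage
# sketch, assuming the TF 1.x graph-mode iterator APIs and the `tf.contrib`
# namespace through which this transformation is exposed.
#
#   import tensorflow as tf
#
#   dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
#   dataset = dataset.apply(tf.contrib.data.unique())
#   next_element = dataset.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#     try:
#       while True:
#         print(sess.run(next_element))  # Prints 1, then 37, then 2.
#     except tf.errors.OutOfRangeError:
#       pass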