• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python wrappers for indexed datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import sparse
25from tensorflow.python.data.util import structure
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
29
30
class MaterializedIndexedDataset(object):
  """Handle to an indexed dataset that has been bound to a resource.

  MaterializedIndexedDataset is highly experimental!

  An instance bundles the resource handle used for lookups, the (optional) op
  that populates that resource, and the element classes/types/shapes needed
  to describe the outputs of a lookup.
  """

  def __init__(self, materialized_resource, materializer, output_classes,
               output_types, output_shapes):
    """Stores the resource handle, its materializer and element metadata.

    Args:
      materialized_resource: The resource handle that `get` reads from.
      materializer: The object returned by `initializer`, or `None` if this
        dataset has no materializer.
      output_classes: Classes of the element components; used to decide how
        types and shapes are converted for the lookup op.
      output_types: Types of the element components.
      output_shapes: Shapes of the element components.
    """
    self._materialized_resource = materialized_resource
    self._materializer = materializer
    self._output_classes = output_classes
    self._output_types = output_types
    self._output_shapes = output_shapes

  @property
  def initializer(self):
    """Returns the materializer supplied at construction time.

    Raises:
      ValueError: If no materializer was supplied (it is `None`).
    """
    if self._materializer is not None:
      return self._materializer
    raise ValueError("MaterializedDataset does not have a materializer")

  def get(self, index):
    """Get retrieves a value (or set of values) from the IndexedDataset.

    Args:
      index: A uint64 scalar or vector tensor with the indices to retrieve.

    Returns:
      A tensor containing the values corresponding to `index`.
    """
    component_classes = self._output_classes
    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, component_classes))
    # Shapes are converted with the shape helper rather than
    # `as_dense_types`, which is meant for dtypes and would not yield a valid
    # shape for sparse components.
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, component_classes))
    # TODO(saeta): nest.pack_sequence_as(...)
    return ged_ops.experimental_indexed_dataset_get(
        self._materialized_resource,
        index,
        output_types=flat_types,
        output_shapes=flat_shapes)
66
67
68# TODO(saeta): Add a `DatasetV1` wrapper if this is exposed via the public API.
class IndexedDataset(dataset_ops.Dataset):
  """Base class for datasets that can be materialized for indexed lookups.

  IndexedDataset is highly experimental!
  """

  def __init__(self):
    """Performs no initialization of its own."""

  def materialize(self, shared_name=None, container=None):
    """Builds a MaterializedIndexedDataset for this dataset.

    A resource handle is created first; the op that materializes this dataset
    into that handle is then created colocated with the handle.

    IndexedDatasets can be combined through operations such as TBD. Therefore,
    they are only materialized when absolutely required.

    Args:
      shared_name: a string for the shared name to use for the resource.
        `None` is treated as the empty string.
      container: a string for the container to store the resource. `None` is
        treated as the empty string.

    Returns:
      A MaterializedIndexedDataset.
    """
    container = "" if container is None else container
    shared_name = "" if shared_name is None else shared_name

    flat_kwargs = dataset_ops.flat_structure(self)
    handle = ged_ops.experimental_materialized_index_dataset_handle(
        container=container, shared_name=shared_name, **flat_kwargs)

    # Keep the materializing op on the same device as the resource it fills.
    with ops.colocate_with(handle):
      variant = self._as_variant_tensor()
      materialize_op = ged_ops.experimental_indexed_dataset_materialize(
          variant, handle)

    return MaterializedIndexedDataset(
        handle, materialize_op, self.output_classes, self.output_types,
        self.output_shapes)

  @abc.abstractmethod
  def _as_variant_tensor(self):
    """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.

    Subclasses must override this; the base implementation always raises.

    Returns:
      A scalar `tf.Tensor` of `tf.variant` type, which represents this
      IndexedDataset.

    Raises:
      NotImplementedError: Always, when not overridden.
    """
    raise NotImplementedError("IndexedDataset._as_variant_tensor")
115
116
117# TODO(saeta): Add a `DatasetV1` wrapper if this is exposed via the public API.
class IdentityIndexedDataset(IndexedDataset):
  """A trivial indexed dataset used for testing.

  Elements are described as scalar `uint64` tensors.
  """

  def __init__(self, size):
    """Converts `size` to a `uint64` tensor and stores it.

    Args:
      size: A value convertible to a `uint64` tensor. Presumably the number of
        elements in the dataset; it is expected to be a scalar, but this is
        not checked here.
    """
    super(IdentityIndexedDataset, self).__init__()
    # TODO(saeta): Verify _size is a scalar!
    size_tensor = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
    self._size = size_tensor

  @property
  def _element_structure(self):
    """Structure of a single element: a `uint64` tensor with shape `[]`."""
    scalar_shape = []
    return structure.TensorStructure(dtypes.uint64, scalar_shape)

  def _as_variant_tensor(self):
    """Returns the identity-indexed-dataset op output built from the size."""
    size = self._size
    return ged_ops.experimental_identity_indexed_dataset(size)

  def _inputs(self):
    """Returns an empty list: this dataset has no input datasets."""
    return list()
136