1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for indexed datasets.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21 22from tensorflow.python.data.ops import dataset_ops 23from tensorflow.python.data.util import nest 24from tensorflow.python.data.util import sparse 25from tensorflow.python.data.util import structure 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 29 30 31class MaterializedIndexedDataset(object): 32 """MaterializedIndexedDataset is highly experimental! 33 """ 34 35 def __init__(self, materialized_resource, materializer, output_classes, 36 output_types, output_shapes): 37 self._materialized_resource = materialized_resource 38 self._materializer = materializer 39 self._output_classes = output_classes 40 self._output_types = output_types 41 self._output_shapes = output_shapes 42 43 @property 44 def initializer(self): 45 if self._materializer is not None: 46 return self._materializer 47 raise ValueError("MaterializedDataset does not have a materializer") 48 49 def get(self, index): 50 """Get retrieves a value (or set of values) from the IndexedDataset. 51 52 Args: 53 index: A uint64 scalar or vector tensor with the indices to retrieve. 54 55 Returns: 56 A tensor containing the values corresponding to `index`. 57 """ 58 # TODO(saeta): nest.pack_sequence_as(...) 59 return ged_ops.experimental_indexed_dataset_get( 60 self._materialized_resource, 61 index, 62 output_types=nest.flatten( 63 sparse.as_dense_types(self._output_types, self._output_classes)), 64 output_shapes=nest.flatten( 65 sparse.as_dense_types(self._output_shapes, self._output_classes))) 66 67 68# TODO(saeta): Add a `DatasetV1` wrapper if this is exposed via the public API. 69class IndexedDataset(dataset_ops.Dataset): 70 """IndexedDataset is highly experimental! 71 """ 72 73 def __init__(self): 74 pass 75 76 def materialize(self, shared_name=None, container=None): 77 """Materialize creates a MaterializedIndexedDataset. 78 79 IndexedDatasets can be combined through operations such as TBD. Therefore, 80 they are only materialized when absolutely required. 81 82 Args: 83 shared_name: a string for the shared name to use for the resource. 84 container: a string for the container to store the resource. 85 86 Returns: 87 A MaterializedIndexedDataset. 88 """ 89 if container is None: 90 container = "" 91 if shared_name is None: 92 shared_name = "" 93 materialized_resource = ( 94 ged_ops.experimental_materialized_index_dataset_handle( 95 container=container, 96 shared_name=shared_name, 97 **dataset_ops.flat_structure(self))) 98 99 with ops.colocate_with(materialized_resource): 100 materializer = ged_ops.experimental_indexed_dataset_materialize( 101 self._as_variant_tensor(), materialized_resource) 102 return MaterializedIndexedDataset(materialized_resource, materializer, 103 self.output_classes, self.output_types, 104 self.output_shapes) 105 106 @abc.abstractmethod 107 def _as_variant_tensor(self): 108 """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset. 109 110 Returns: 111 A scalar `tf.Tensor` of `tf.variant` type, which represents this 112 IndexedDataset. 113 """ 114 raise NotImplementedError("IndexedDataset._as_variant_tensor") 115 116 117# TODO(saeta): Add a `DatasetV1` wrapper if this is exposed via the public API. 118class IdentityIndexedDataset(IndexedDataset): 119 """IdentityIndexedDataset is a trivial indexed dataset used for testing. 120 """ 121 122 def __init__(self, size): 123 super(IdentityIndexedDataset, self).__init__() 124 # TODO(saeta): Verify _size is a scalar! 125 self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size") 126 127 @property 128 def _element_structure(self): 129 return structure.TensorStructure(dtypes.uint64, []) 130 131 def _as_variant_tensor(self): 132 return ged_ops.experimental_identity_indexed_dataset(self._size) 133 134 def _inputs(self): 135 return [] 136