• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python dataset sparse tensor utility functions."""
16from tensorflow.python.data.util import nest
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import sparse_tensor
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import sparse_ops
22
23
def any_sparse(classes):
  """Reports whether a class structure mentions `SparseTensor`.

  Args:
    classes: a (possibly nested) structure of objects that identify the
      dataset item classes.

  Returns:
    `True` as soon as one flattened element of `classes` is the
    `sparse_tensor.SparseTensor` class itself (compared by identity, not by
    subclass check), and `False` if no element is.
  """
  for item_class in nest.flatten(classes):
    if item_class is sparse_tensor.SparseTensor:
      return True
  return False
34
35
def as_dense_shapes(shapes, classes):
  """Replaces the shape of every sparse component with an unknown shape.

  Args:
    shapes: a structure of shapes to convert.
    classes: a structure of objects that identify the dataset item classes;
      its flattened elements are paired positionally with the flattened
      elements of `shapes`.

  Returns:
    a structure matching the nested structure of `shapes`, holding
    `tensor_shape.unknown_shape()` wherever the paired class is
    `tf.sparse.SparseTensor` and the original shape everywhere else.
  """
  converted = []
  for item_shape, item_class in zip(nest.flatten(shapes),
                                    nest.flatten(classes)):
    if item_class is sparse_tensor.SparseTensor:
      # The logical shape of a sparse component does not describe its
      # physical representation, so a fully unknown shape is reported.
      converted.append(tensor_shape.unknown_shape())
    else:
      converted.append(item_shape)
  return nest.pack_sequence_as(shapes, converted)
53
54
def as_dense_types(types, classes):
  """Replaces the dtype of every sparse component with `dtypes.variant`.

  Args:
    types: a structure of types to convert.
    classes: a structure of objects that identify the dataset item classes;
      its flattened elements are paired positionally with the flattened
      elements of `types`.

  Returns:
    a structure matching the nested structure of `types`, holding
    `dtypes.variant` wherever the paired class is `tf.sparse.SparseTensor`
    and the original type everywhere else.
  """
  flat_types = nest.flatten(types)
  flat_classes = nest.flatten(classes)
  converted = []
  for item_type, item_class in zip(flat_types, flat_classes):
    is_sparse_slot = item_class is sparse_tensor.SparseTensor
    converted.append(dtypes.variant if is_sparse_slot else item_type)
  return nest.pack_sequence_as(types, converted)
72
73
def deserialize_sparse_tensors(tensors, types, shapes, classes):
  """Turns serialized sparse components back into sparse tensors.

  Args:
    tensors: a structure of tensors to deserialize.
    types: a structure that holds information about types of `tensors`; it
      also serves as the structural template of the result.
    shapes: a structure that holds information about shapes of `tensors`;
      only the rank (`ndims`) at sparse positions is consulted.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    the flattened elements of `tensors`, packed into the nested structure of
    `types`, where every element whose class is `tf.sparse.SparseTensor` has
    been replaced by the output of `sparse_ops.deserialize_sparse`.
  """
  flat_components = zip(
      nest.flatten(tensors), nest.flatten(types), nest.flatten(shapes),
      nest.flatten(classes))
  restored = []
  for component, component_type, component_shape, component_class in (
      flat_components):
    if component_class is not sparse_tensor.SparseTensor:
      # Dense components are forwarded untouched.
      restored.append(component)
      continue
    restored.append(
        sparse_ops.deserialize_sparse(
            component, dtype=component_type, rank=component_shape.ndims))
  return nest.pack_sequence_as(types, restored)
95
96
def get_classes(tensors):
  """Maps every element of a tensor structure to its class.

  Args:
    tensors: the tensor structure to get classes for.

  Returns:
    a structure matching the nested structure of `tensors`, containing
    `sparse_tensor.SparseTensor` for each element that is an instance of
    `SparseTensor` and `ops.Tensor` for every other element.
  """

  def _class_of(element):
    if isinstance(element, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor
    return ops.Tensor

  element_classes = list(map(_class_of, nest.flatten(tensors)))
  return nest.pack_sequence_as(tensors, element_classes)
113
114
def serialize_many_sparse_tensors(tensors):
  """Serializes each sparse element of a structure as a batch.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    `tensors` with every element for which `sparse_tensor.is_sparse` is true
    replaced by `sparse_ops.serialize_many_sparse(..., out_type=
    dtypes.variant)`; all other elements are returned unchanged.
  """

  def _serialize(element):
    # NOTE(review): this uses `is_sparse`, while `serialize_sparse_tensors`
    # uses `isinstance(..., SparseTensor)`; presumably `is_sparse` also
    # accepts sparse tensor *values* -- confirm the difference is intended.
    if not sparse_tensor.is_sparse(element):
      return element
    return sparse_ops.serialize_many_sparse(element, out_type=dtypes.variant)

  serialized = [_serialize(element) for element in nest.flatten(tensors)]
  return nest.pack_sequence_as(tensors, serialized)
131
132
def serialize_sparse_tensors(tensors):
  """Serializes each sparse element of a structure individually.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    `tensors` with every `sparse_tensor.SparseTensor` instance replaced by
    `sparse_ops.serialize_sparse(..., out_type=dtypes.variant)`; all other
    elements are returned unchanged.
  """
  serialized = []
  for element in nest.flatten(tensors):
    if isinstance(element, sparse_tensor.SparseTensor):
      element = sparse_ops.serialize_sparse(
          element, out_type=dtypes.variant)
    serialized.append(element)
  return nest.pack_sequence_as(tensors, serialized)
149