# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python dataset sparse tensor utility functions."""
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import sparse_ops


def any_sparse(classes):
  """Reports whether any entry of `classes` is the sparse tensor class.

  Args:
    classes: a structure of objects that identify the dataset item classes

  Returns:
    `True` if at least one flattened element of `classes` is the
    `sparse_tensor.SparseTensor` class object, `False` otherwise.
  """
  for item_class in nest.flatten(classes):
    # Identity test: only the SparseTensor class object itself counts.
    if item_class is sparse_tensor.SparseTensor:
      return True
  return False


def as_dense_shapes(shapes, classes):
  """Replaces the shapes of sparse components with an unknown shape.

  Args:
    shapes: a structure of shapes to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure packed like `shapes` in which every position whose class is
    `sparse_tensor.SparseTensor` holds `tensor_shape.unknown_shape()` and
    every other position holds the corresponding entry of `shapes`.
  """
  flat_shapes = nest.flatten(shapes)
  flat_classes = nest.flatten(classes)
  # `zip` pairs entries positionally and stops at the shorter sequence.
  dense_shapes = [
      shape if item_class is not sparse_tensor.SparseTensor
      else tensor_shape.unknown_shape()
      for shape, item_class in zip(flat_shapes, flat_classes)
  ]
  return nest.pack_sequence_as(shapes, dense_shapes)


def as_dense_types(types, classes):
  """Replaces the types of sparse components with `dtypes.variant`.

  Args:
    types: a structure of types to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure packed like `types` in which every position whose class is
    `sparse_tensor.SparseTensor` holds `dtypes.variant` and every other
    position holds the corresponding entry of `types`.
  """
  flat_types = nest.flatten(types)
  flat_classes = nest.flatten(classes)
  # `zip` pairs entries positionally and stops at the shorter sequence.
  dense_types = [
      dtype if item_class is not sparse_tensor.SparseTensor
      else dtypes.variant
      for dtype, item_class in zip(flat_types, flat_classes)
  ]
  return nest.pack_sequence_as(types, dense_types)


def deserialize_sparse_tensors(tensors, types, shapes, classes):
  """Turns serialized sparse components back into sparse tensors.

  Args:
    tensors: a structure of tensors to deserialize.
    types: a structure that holds information about types of `tensors`
    shapes: a structure that holds information about shapes of `tensors`
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure packed like `types` in which each component whose class is
    `sparse_tensor.SparseTensor` has been passed through
    `sparse_ops.deserialize_sparse`; all other components are returned as is.
  """

  def _restore(tensor, dtype, shape, item_class):
    """Deserializes one component if its class marks it as sparse."""
    if item_class is sparse_tensor.SparseTensor:
      # The rank handed to the op comes from the declared shape's `ndims`.
      return sparse_ops.deserialize_sparse(
          tensor, dtype=dtype, rank=shape.ndims)
    return tensor

  flat_components = zip(
      nest.flatten(tensors), nest.flatten(types), nest.flatten(shapes),
      nest.flatten(classes))
  restored = [
      _restore(tensor, dtype, shape, item_class)
      for tensor, dtype, shape, item_class in flat_components
  ]
  # The result is packed using the nesting of `types`, not of `tensors`.
  return nest.pack_sequence_as(types, restored)


def get_classes(tensors):
  """Derives the class structure describing a structure of tensors.

  Args:
    tensors: the tensor structure to get classes for.

  Returns:
    a structure packed like `tensors` that holds
    `sparse_tensor.SparseTensor` wherever the component is an instance of
    that class and `ops.Tensor` everywhere else.
  """
  flat_classes = [
      ops.Tensor if not isinstance(component, sparse_tensor.SparseTensor)
      else sparse_tensor.SparseTensor
      for component in nest.flatten(tensors)
  ]
  return nest.pack_sequence_as(tensors, flat_classes)


def serialize_many_sparse_tensors(tensors):
  """Serializes every sparse component with `serialize_many_sparse`.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure packed like `tensors` in which each component accepted by
    `sparse_tensor.is_sparse` is replaced by the variant-typed result of
    `sparse_ops.serialize_many_sparse`; other components are unchanged.
  """
  # NOTE(review): sparseness is detected with `sparse_tensor.is_sparse` here,
  # while `serialize_sparse_tensors` uses an `isinstance` check -- confirm the
  # difference is intentional.
  flat_serialized = [
      component if not sparse_tensor.is_sparse(component)
      else sparse_ops.serialize_many_sparse(component, out_type=dtypes.variant)
      for component in nest.flatten(tensors)
  ]
  return nest.pack_sequence_as(tensors, flat_serialized)


def serialize_sparse_tensors(tensors):
  """Serializes every sparse component with `serialize_sparse`.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure packed like `tensors` in which each
    `sparse_tensor.SparseTensor` instance is replaced by the variant-typed
    result of `sparse_ops.serialize_sparse`; other components are unchanged.
  """
  flat_serialized = [
      component if not isinstance(component, sparse_tensor.SparseTensor)
      else sparse_ops.serialize_sparse(component, out_type=dtypes.variant)
      for component in nest.flatten(tensors)
  ]
  return nest.pack_sequence_as(tensors, flat_serialized)