# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python dataset sparse tensor utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import sparse_ops


def any_sparse(classes):
  """Checks for sparse tensor.

  Args:
    classes: a structure of objects that identify the dataset item classes

  Returns:
    `True` if `classes` contains a sparse tensor type and `False` otherwise.
  """
  # Identity comparison against the class object itself: `classes` holds
  # classes (as produced by `get_classes`), not tensor instances.
  return any(c is sparse_tensor.SparseTensor for c in nest.flatten(classes))


def as_dense_shapes(shapes, classes):
  """Converts sparse tensor shapes to their physical shapes.

  Args:
    shapes: a structure of shapes to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure matching the nested structure of `shapes`, containing
    `tensor_shape.unknown_shape()` at positions where `classes` contains
    `tf.SparseTensor` and matching contents of `shapes` otherwise
  """
  # The flattened `shapes` and `classes` are paired positionally with `zip`,
  # so both structures are expected to flatten to the same layout.
  return nest.pack_sequence_as(shapes, [
      tensor_shape.unknown_shape() if c is sparse_tensor.SparseTensor else shape
      for shape, c in zip(nest.flatten(shapes), nest.flatten(classes))
  ])


def as_dense_types(types, classes):
  """Converts sparse tensor types to `dtypes.variant`.

  Args:
    types: a structure of types to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure matching the nested structure of `types`, containing
    `dtypes.variant` at positions where `classes` contains `tf.SparseTensor` and
    matching contents of `types` otherwise
  """
  # Paired positionally with `zip`; `types` and `classes` are expected to
  # flatten to the same layout.
  return nest.pack_sequence_as(types, [
      dtypes.variant if c is sparse_tensor.SparseTensor else ty
      for ty, c in zip(nest.flatten(types), nest.flatten(classes))
  ])


def deserialize_sparse_tensors(tensors, types, shapes, classes):
  """Deserializes sparse tensors.

  Args:
    tensors: a structure of tensors to deserialize.
    types: a structure that holds information about types of `tensors`
    shapes: a structure that holds information about shapes of `tensors`
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure matching the nested structure of `types`, holding `tensors`
    with any serialized sparse tensors replaced by their deserialized version.
  """
  # All four structures are flattened and paired positionally with `zip`, so
  # they are expected to share the same layout. The result is packed using
  # `types` (not `tensors`) as the structural template.
  return nest.pack_sequence_as(types, [
      # `shape.ndims` is forwarded as the rank of the deserialized tensor.
      # NOTE(review): assumes each element of `shapes` exposes `.ndims`
      # (e.g. a `TensorShape`) -- confirm against callers.
      sparse_ops.deserialize_sparse(tensor, dtype=ty, rank=shape.ndims)
      if c is sparse_tensor.SparseTensor else tensor
      for (tensor, ty, shape, c) in zip(
          nest.flatten(tensors), nest.flatten(types), nest.flatten(shapes),
          nest.flatten(classes))
  ])


def get_classes(tensors):
  """Gets classes for a structure of tensors.

  Args:
    tensors: the tensor structure to get classes for.

  Returns:
    a structure matching the nested structure of `tensors`, containing
    `tf.SparseTensor` at positions where `tensors` contains a sparse tensor and
    `tf.Tensor` otherwise
  """
  return nest.pack_sequence_as(tensors, [
      sparse_tensor.SparseTensor
      if isinstance(tensor, sparse_tensor.SparseTensor) else ops.Tensor
      for tensor in nest.flatten(tensors)
  ])


def serialize_many_sparse_tensors(tensors):
  """Serializes many sparse tensors into a batch.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    `tensors` with any sparse tensors replaced by the serialized batch.
  """
  # Detection uses `sparse_tensor.is_sparse` rather than the `isinstance`
  # check used by `serialize_sparse_tensors` below.
  # NOTE(review): `is_sparse` may accept more sparse value types than
  # `isinstance(..., SparseTensor)` -- confirm this difference is intended.
  return nest.pack_sequence_as(tensors, [
      sparse_ops.serialize_many_sparse(tensor, out_type=dtypes.variant)
      if sparse_tensor.is_sparse(tensor) else tensor
      for tensor in nest.flatten(tensors)
  ])


def serialize_sparse_tensors(tensors):
  """Serializes sparse tensors.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    `tensors` with any sparse tensors replaced by their serialized version.
  """
  return nest.pack_sequence_as(tensors, [
      sparse_ops.serialize_sparse(tensor, out_type=dtypes.variant)
      if isinstance(tensor, sparse_tensor.SparseTensor) else tensor
      for tensor in nest.flatten(tensors)
  ])