# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse tensors."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python import _pywrap_utils
from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.util.tf_export import tf_export

# Module-level aliases for private helpers that live in other modules, so the
# protected-access lint suppression is confined to these three lines.
# pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
_eval_using_default_session = ops._eval_using_default_session
_override_helper = ops._override_helper
# pylint: enable=protected-access


@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies the
    indices of the elements in the sparse tensor that contain nonzero values
    (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]` specifies
    that the elements with indexes of [1,3] and [2,4] have nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given `indices=[[1,3],
    [2,4]]`, the parameter `values=[18, 3.6]` specifies that element [1,3] of
    the sparse tensor has a value of 18, and element [2,4] of the tensor has a
    value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies the
    dense_shape of the sparse tensor. Takes a list indicating the number of
    elements in each dimension. For example, `dense_shape=[3,6]` specifies a
    two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse.reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a `SparseTensor` or `SparseTensorValue`.

    Args:
      sparse_tensor_value: An object accepted by `is_sparse`, i.e. a
        `SparseTensor` or a `SparseTensorValue`.

    Returns:
      A new `SparseTensor` constructed from the argument's `indices`, `values`
      and `dense_shape`. Note that `SparseTensor` is instantiated directly
      rather than `cls`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if not is_sparse(sparse_tensor_value):
      raise TypeError("Neither a SparseTensor nor SparseTensorValue: %s." %
                      sparse_tensor_value)
    return SparseTensor(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Each argument is passed through `ops.convert_to_tensor` (with `int64`
    requested for `indices` and `dense_shape`), and the static shapes of the
    results are then checked for mutual consistency.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.
    """
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")
      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape

    # Static rank checks: indices must be rank 2, values and dense_shape
    # rank 1. NOTE(review): `with_rank` / `merge_with` presumably raise
    # ValueError on a mismatch -- confirm against tensor_shape.
    indices_shape = indices.shape.with_rank(2)
    values_shape = values.shape.with_rank(1)
    dense_shape_shape = dense_shape.shape.with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].merge_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].merge_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Computes the same value as the `shape` property.

    Returns:
      A `TensorShape` object.
    """
    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with dense_shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    The shape is derived from the `dense_shape` tensor via
    `tensor_util.constant_value_as_shape`.

    Returns:
      A `TensorShape` object.
    """
    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    # Only the indices tensor is consulted here.
    return self._indices.graph

  def __str__(self):
    """Returns a string showing the three component tensors."""
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    # All three components are evaluated together in a single call.
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Forwards to `ops._override_helper` with `SparseTensor` as the class.

    Args:
      operator: Passed through unchanged as the operator argument of
        `_override_helper`.
      func: Passed through unchanged as the function argument of
        `_override_helper`.
    """
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    """A `SparseTensorSpec` built from this tensor's `shape` and `dtype`."""
    return SparseTensorSpec(self.shape, self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    """Converts a `tf.while_loop` shape invariant into a `SparseTensorSpec`.

    Args:
      shape: A shape object exposing `ndims` and indexing; it describes the
        shape of the `dense_shape` vector, not of the dense tensor itself.

    Returns:
      A `SparseTensorSpec` with this tensor's dtype whose shape has rank
      `shape[0]` and no known dimension sizes.

    Raises:
      ValueError: If `shape` has a known rank other than 1.
    """
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: The shape
    # invariant here is the shape of the SparseTensor.dense_shape property. It
    # must be the shape of a vector."
    if shape.ndims is not None and shape.ndims != 1:
      raise ValueError("Expected a shape with 1 dimension")
    rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(rank), self.dtype)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    `_consumers` is not defined in this class; it is presumably inherited from
    `composite_tensor.CompositeTensor` -- confirm.
    """
    return self._consumers()


# Plain-Python container for the three components of a sparse tensor. It is
# what `SparseTensor.eval` returns. It is exported under the v1 API name and
# registered with `_pywrap_utils` under the name "SparseTensorValue".
SparseTensorValue = collections.namedtuple("SparseTensorValue",
                                           ["indices", "values", "dense_shape"])
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)


@tf_export("SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.SparseTensor`."""

  __slots__ = ["_shape", "_dtype"]

  # The Python type described by this spec; always `SparseTensor`.
  value_type = property(lambda self: SparseTensor)

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow
        any dense shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    # Both arguments are normalized, so callers may pass anything accepted by
    # `tensor_shape.as_shape` / `dtypes.as_dtype`.
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    """Returns the `(shape, dtype)` tuple that defines this spec."""
    return (self._shape, self._dtype)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    """`TensorSpec`s for the `indices`, `values` and `dense_shape` components.

    Returns:
      A list of three `TensorSpec`s, in the order indices, values,
      dense_shape. The number of values is always left unknown (`None`); the
      rank comes from this spec's shape.
    """
    rank = self._shape.ndims
    num_values = None
    return [
        tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
        tensor_spec.TensorSpec([num_values], self._dtype),
        tensor_spec.TensorSpec([rank], dtypes.int64)]

  def _to_components(self, value):
    """Splits `value` into `[indices, values, dense_shape]`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`; the latter is first
        converted with `SparseTensor.from_value`.

    Returns:
      A list holding the `indices`, `values` and `dense_shape` attributes of
      the (possibly converted) value.
    """
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    return [value.indices, value.values, value.dense_shape]

  def _from_components(self, tensor_list):
    """Reassembles a sparse value from its three components.

    Args:
      tensor_list: The components in the order indices, values, dense_shape.

    Returns:
      A `SparseTensorValue` when every component is a `np.ndarray` and
      `tf2.enabled()` is false; otherwise a `SparseTensor`.
    """
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      return SparseTensorValue(*tensor_list)
    else:
      return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    """A single variant `TensorSpec` with unknown shape."""
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a one-element list holding a variant tensor.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A list with the single output of `gen_sparse_ops.serialize_sparse`
      called with `out_type=dtypes.variant`.
    """
    value = SparseTensor.from_value(value)
    return [gen_sparse_ops.serialize_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` with `serialize_many_sparse` so it can be unbatched.

    Args:
      value: An object exposing `indices`, `values` and `dense_shape`. Unlike
        `_to_tensor_list`, no `SparseTensor.from_value` conversion is applied
        here.

    Returns:
      A list with the single output of `gen_sparse_ops.serialize_many_sparse`
      called with `out_type=dtypes.variant`.

    Raises:
      ValueError: If this spec's shape merged with the value's static dense
        shape has rank 0.
    """
    dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._shape.merge_with(dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1")
    return [gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes a serialized sparse tensor and re-applies static shape info.

    Args:
      tensor_list: A list whose first element is the serialized tensor; only
        `tensor_list[0]` is read.

    Returns:
      A `SparseTensor` whose `dense_shape` carries as much of this spec's
      static shape as is known.
    """
    tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
    indices, values, dense_shape = tensor_list
    rank = self._shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTensorSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if self._shape.is_fully_defined():
      # Every dimension is known: replace the deserialized dense_shape with a
      # tensor built directly from the spec's shape.
      dense_shape = ops.convert_to_tensor(
          self._shape, dtype=dtypes.int64, name="shape")
    elif (self._shape.rank is not None and
          any(dim.value is not None for dim in self._shape.dims)):
      # Partially known: keep the deserialized entries for unknown dimensions
      # and substitute constants for the known ones.
      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
      pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
      for i, dim in enumerate(self._shape.dims):
        if dim.value is not None:
          pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
      dense_shape = array_ops.stack(pieces)
    else:
      # No dimension size is known; at most the rank can be recorded.
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    """Returns a spec with `batch_size` prepended to this spec's shape."""
    return SparseTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    """Returns a spec with the leading dimension of the shape removed.

    Raises:
      ValueError: If this spec's shape has rank 0.
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return SparseTensorSpec(self._shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    """Returns this spec's dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns this spec's shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Returns the `SparseTensor` class."""
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Builds a spec describing `value`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      An instance of `cls`. For a `SparseTensor` it is built from the tensor's
      `shape` and `dtype`. For a `SparseTensorValue` whose `values` is a
      `np.ndarray` it is built from `dense_shape` and the array's dtype. Any
      other `SparseTensorValue` is first converted to a `SparseTensor`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if isinstance(value, SparseTensorValue):
      if isinstance(value.values, np.ndarray):
        return cls(value.dense_shape, value.values.dtype)
      else:
        return cls.from_value(SparseTensor.from_value(value))
    else:
      raise TypeError("Expected SparseTensor or SparseTensorValue")


# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic).  Do *not* delete the SparseTensorValue registration.
# Both sparse value classes share the same spec factory.
type_spec.register_type_spec_from_value_converter(
    SparseTensor, SparseTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
    SparseTensorValue, SparseTensorSpec.from_value)


@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Returns `value` as a `SparseTensor` if it is sparse, else as a `Tensor`.

  Args:
    value: A `SparseTensor`, a `SparseTensorValue`, or any object whose type
      has a registered `Tensor` conversion function.
    dtype: Optional element type for the result. When omitted, the type is
      inferred from `value`.
    name: Optional name, used only when a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If `value` is sparse and its dtype is incompatible with the
      requested `dtype`.
  """
  requested = None if dtype is None else dtypes.as_dtype(dtype)

  # A SparseTensorValue is promoted to a SparseTensor first.
  candidate = value
  if isinstance(candidate, SparseTensorValue):
    candidate = SparseTensor.from_value(candidate)

  # Dense path: hand anything that is not sparse to the generic converter.
  if not isinstance(candidate, SparseTensor):
    return ops.convert_to_tensor(candidate, dtype=requested, name=name)

  # Sparse path: the tensor is returned as-is once its dtype is validated.
  if requested and not requested.is_compatible_with(candidate.dtype):
    raise RuntimeError("Sparse dtype: requested = %s, actual = %s" %
                       (requested.name, candidate.dtype.name))
  return candidate


def is_sparse(x):
  """Tells whether `x` is one of the sparse tensor representations.

  Args:
    x: A python object to check.

  Returns:
    `True` iff `x` is a `tf.SparseTensor` or `tf.compat.v1.SparseTensorValue`.
  """
  sparse_types = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_types)