• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sparse tensors."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import numpy as np
24
25from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
26from tensorflow.python import tf2
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import gen_sparse_ops
36from tensorflow.python.types import internal
37from tensorflow.python.util import _pywrap_utils
38from tensorflow.python.util.tf_export import tf_export
39
# Module-level aliases for two protected helpers from `ops`; they are used by
# `SparseTensor.eval` and `SparseTensor._override_operator` below.
# pylint: disable=protected-access
_eval_using_default_session, _override_helper = (
    ops._eval_using_default_session,
    ops._override_helper,
)
# pylint: enable=protected-access
44
45
@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies the
    indices of the elements in the sparse tensor that contain nonzero values
    (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]` specifies
    that the elements with indexes of [1,3] and [2,4] have nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given `indices=[[1,3],
    [2,4]]`, the parameter `values=[18, 3.6]` specifies that element [1,3] of
    the sparse tensor has a value of 18, and element [2,4] of the tensor has a
    value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies the
    dense_shape of the sparse tensor. Takes a list indicating the number of
    elements in each dimension. For example, `dense_shape=[3,6]` specifies a
    two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse.reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a `SparseTensor` or `SparseTensorValue`.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A new `SparseTensor` built from the input's `indices`, `values` and
      `dense_shape`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if not is_sparse(sparse_tensor_value):
      # The operand is wrapped in a 1-tuple on purpose. A bare tuple argument
      # (e.g. a plain `(indices, values, shape)` tuple) would otherwise be
      # unpacked by `%` as several format arguments. That would replace this
      # message with an unrelated string-formatting TypeError.
      raise TypeError("Neither a SparseTensor nor SparseTensorValue: %s." %
                      (sparse_tensor_value,))
    return SparseTensor(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Raises:
      ValueError: If the static shapes of the inputs are inconsistent:
        `indices` is not rank 2, `values` or `dense_shape` is not rank 1, the
        number of rows of `indices` is incompatible with the length of
        `values`, or the number of columns of `indices` is incompatible with
        the length of `dense_shape`.
    """
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")

      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
      # Static `TensorShape` derived from `dense_shape`. It is what `shape`,
      # `get_shape()` and `_type_spec` report.
      dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)

    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape
    self._dense_shape_default = dense_shape_default

    # `with_rank` raises ValueError if a component has the wrong static rank.
    indices_shape = indices.shape.with_rank(2)
    values_shape = values.shape.with_rank(1)
    dense_shape_shape = dense_shape.shape.with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with dense_shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  def with_values(self, new_values):
    """Returns a copy of `self` with `values` replaced by `new_values`.

    This method produces a new `SparseTensor` that has the same nonzero
    `indices` and same `dense_shape`, but updated values.

    Args:
      new_values: The values of the new `SparseTensor`. Needs to have the same
        shape as the current `.values` `Tensor`. May have a different type than
        the current `values`.

    Returns:
      A `SparseTensor` with identical indices and shape but updated values.

    Example usage:

    >>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
    >>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40]))  # 4 nonzero values
    <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[10,  0, 20,  0],
           [30,  0,  0, 40]], dtype=int32)>

    """
    return SparseTensor(self._indices, new_values, self._dense_shape)

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    return self._indices.graph

  def __str__(self):
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Installs `func` as the implementation of `operator` on SparseTensor."""
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    return SparseTensorSpec(self.shape, self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: The shape
    # invariant here is the shape of the SparseTensor.dense_shape property. It
    # must be the shape of a vector.
    if shape.ndims is not None and shape.ndims != 1:
      raise ValueError("Expected a shape with 1 dimension")
    rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(rank), self.dtype)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    `_consumers` is not defined in this class. It is presumably inherited from
    `composite_tensor.CompositeTensor` (confirm there).
    """
    return self._consumers()
286
287
# Plain-Python container for the three components of a sparse tensor.
# `SparseTensor.eval` returns one, and `SparseTensorSpec._from_components`
# builds one when every component is a numpy array and TF2 behavior is off.
SparseTensorValue = collections.namedtuple(
    "SparseTensorValue",
    ("indices", "values", "dense_shape"),
)
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
292
293
@tf_export("SparseTensorSpec")
@type_spec.register("tf.SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.sparse.SparseTensor`."""

  __slots__ = ["_shape", "_dtype"]

  value_type = property(lambda self: SparseTensor)

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.sparse.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow any dense
        shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    """Returns the `(shape, dtype)` pair that identifies this spec."""
    return self._shape, self._dtype

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    """Specs for `[indices, values, dense_shape]`, in that order."""
    rank = self._shape.ndims
    # The number of stored (non-zero) values is never fixed by the spec.
    indices_spec = tensor_spec.TensorSpec([None, rank], dtypes.int64)
    values_spec = tensor_spec.TensorSpec([None], self._dtype)
    dense_shape_spec = tensor_spec.TensorSpec([rank], dtypes.int64)
    return [indices_spec, values_spec, dense_shape_spec]

  def _to_components(self, value):
    """Splits `value` into `[indices, values, dense_shape]`."""
    sparse = (SparseTensor.from_value(value)
              if isinstance(value, SparseTensorValue) else value)
    return [sparse.indices, sparse.values, sparse.dense_shape]

  def _from_components(self, tensor_list):
    """Rebuilds a sparse value from `[indices, values, dense_shape]`.

    Returns a `SparseTensorValue` only when every component is a numpy array
    and TF2 behavior is disabled; otherwise returns a `SparseTensor`.
    """
    all_numpy = all(isinstance(t, np.ndarray) for t in tensor_list)
    if all_numpy and not tf2.enabled():
      return SparseTensorValue(*tensor_list)
    return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a single variant tensor via `serialize_sparse`."""
    sparse = SparseTensor.from_value(value)
    boxed = gen_sparse_ops.serialize_sparse(
        sparse.indices, sparse.values, sparse.dense_shape,
        out_type=dtypes.variant)
    return [boxed]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` via `serialize_many_sparse` so it can be unbatched.

    Raises:
      ValueError: If the merged static shape has rank 0.
    """
    static_dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    merged_shape = self._shape.merge_with(static_dense_shape)
    if merged_shape.ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1")
    boxed = gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)
    return [boxed]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes the variant tensor in `tensor_list[0]` into a `SparseTensor`."""
    indices, values, dense_shape = gen_sparse_ops.deserialize_sparse(
        tensor_list[0], self._dtype)
    rank = self._shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTypeSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if self._shape.is_fully_defined():
      dense_shape = ops.convert_to_tensor(
          self._shape, dtype=dtypes.int64, name="shape")
    elif (self._shape.rank is not None and
          any(dim.value is not None for dim in self._shape.dims)):
      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
      # Keep the runtime entry for unknown dims and substitute a constant for
      # every dim the spec knows statically.
      dynamic_dims = array_ops.unstack(dense_shape, num=self._shape.rank)
      pieces = [
          piece if dim.value is None else
          constant_op.constant(dim.value, dense_shape.dtype)
          for piece, dim in zip(dynamic_dims, self._shape.dims)
      ]
      dense_shape = array_ops.stack(pieces)
    else:
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    """Returns the spec of a batch of `batch_size` values of this spec."""
    batched_shape = tensor_shape.TensorShape([batch_size]).concatenate(
        self._shape)
    return SparseTensorSpec(batched_shape, self._dtype)

  def _unbatch(self):
    """Returns the spec of one element of a batch described by this spec."""
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    element_shape = self._shape[1:]
    return SparseTensorSpec(element_shape, self._dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Returns a spec describing `value`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if not isinstance(value, SparseTensorValue):
      raise TypeError("Expected SparseTensor or SparseTensorValue")
    if isinstance(value.values, np.ndarray):
      return cls(value.dense_shape, value.values.dtype)
    # Non-numpy components: convert to a SparseTensor first.
    return cls.from_value(SparseTensor.from_value(value))
429
430
# Let `type_spec` derive a `SparseTensorSpec` from either sparse value type.
# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic).  Do *not* delete the SparseTensorValue registration.
type_spec.register_type_spec_from_value_converter(
    SparseTensor,
    SparseTensorSpec.from_value,
)
type_spec.register_type_spec_from_value_converter(
    SparseTensorValue,
    SparseTensorSpec.from_value,
)
438
439
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  A `SparseTensorValue` is first turned into a `SparseTensor`. Sparse inputs
  are returned as-is after a dtype compatibility check. Every other input is
  passed to `ops.convert_to_tensor`.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created. Not used for
      sparse inputs.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If `value` is sparse and its dtype is incompatible with
      `dtype`.
  """
  requested = None if dtype is None else dtypes.as_dtype(dtype)

  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  if not isinstance(value, SparseTensor):
    return ops.convert_to_tensor(value, dtype=requested, name=name)

  if requested and not requested.is_compatible_with(value.dtype):
    raise RuntimeError("Sparse dtype: requested = %s, actual = %s" %
                       (requested.name, value.dtype.name))
  return value
467
468
def is_sparse(x):
  """Returns whether `x` is a sparse tensor object.

  Both `tf.sparse.SparseTensor` and `tf.compat.v1.SparseTensorValue`
  instances (including instances of subclasses) count as sparse.

  Args:
    x: Any Python object.

  Returns:
    `True` if `x` is an instance of `tf.sparse.SparseTensor` or
    `tf.compat.v1.SparseTensorValue`, otherwise `False`.
  """
  sparse_types = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_types)
483