• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sparse tensors."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import numpy as np
24
25from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
26from tensorflow.python import _pywrap_utils
27from tensorflow.python import tf2
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_like
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.ops import gen_sparse_ops
38from tensorflow.python.util.tf_export import tf_export
39
# pylint: disable=protected-access
# Short module-level names for private framework members used further down:
# the operator-override helper, the default-session evaluation helper, and
# the `_TensorLike` base class.
_override_helper = ops._override_helper
_eval_using_default_session = ops._eval_using_default_session
_TensorLike = tensor_like._TensorLike
# pylint: enable=protected-access
45
46
@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
  """A sparse tensor stored as three dense component tensors.

  TensorFlow describes a sparse tensor with three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  This class bundles them into one
  Python object.  If you hold the three components separately, wrap them in
  a `SparseTensor` before passing them to the sparse ops.

  With `N` the number of stored values and `ndims` the number of dimensions,
  `SparseTensor(indices, values, dense_shape)` is made of:

  * `indices`: a 2-D int64 tensor of shape `[N, ndims]`.  Each row holds the
    zero-based coordinates of one element that contains a nonzero value.
    For example, `indices=[[1,3], [2,4]]` states that the elements at
    [1,3] and [2,4] have nonzero values.

  * `values`: a 1-D tensor of any type and shape `[N]` giving the value for
    each row of `indices`.  For example, with `indices=[[1,3], [2,4]]`,
    `values=[18, 3.6]` means that element [1,3] is 18 and element [2,4] is
    3.6.

  * `dense_shape`: a 1-D int64 tensor of shape `[ndims]` listing the number
    of elements in each dimension of the represented dense tensor.  For
    example, `dense_shape=[3,6]` is a two-dimensional 3x6 tensor,
    `dense_shape=[2,3,4]` is a three-dimensional 2x3x4 tensor, and
    `dense_shape=[9]` is a one-dimensional tensor with 9 elements.

  The equivalent dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention `indices` is sorted in row-major order (equivalently, in
  lexicographic order of the tuples `indices[i]`).  The constructor does not
  enforce this, but most ops assume it.  When a sparse tensor `st` is not
  ordered correctly, `tf.sparse.reorder(st)` returns a fixed version.

  Example: the sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0],
   [0, 0, 2, 0],
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a sparse tensor or sparse tensor value.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A new `SparseTensor` constructed from the argument's `indices`,
      `values` and `dense_shape`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if is_sparse(sparse_tensor_value):
      return SparseTensor(
          indices=sparse_tensor_value.indices,
          values=sparse_tensor_value.values,
          dense_shape=sparse_tensor_value.dense_shape)
    raise TypeError("Neither a SparseTensor nor SparseTensorValue: %s." %
                    sparse_tensor_value)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    The three arguments are converted to tensors and their static shapes are
    checked against each other.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.
    """
    components = [indices, values, dense_shape]
    with ops.name_scope(None, "SparseTensor", components):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")
      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape

    # Static rank checks: indices is a matrix, the other two are vectors.
    num_rows, num_cols = indices.shape.with_rank(2).dims
    (num_values,) = values.shape.with_rank(1).dims
    (num_dims,) = dense_shape.shape.with_rank(1).dims

    # One row of indices per entry of values.
    num_rows.merge_with(num_values)
    # One column of indices per entry of dense_shape.
    num_cols.merge_with(num_dims)

  def get_shape(self):
    """Returns the static shape of the represented dense tensor.

    Returns:
      A `TensorShape` object.
    """
    static_shape = tensor_util.constant_value_as_shape(self._dense_shape)
    return static_shape

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def op(self):
    """The `Operation` whose output is the `values` tensor."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of the elements, taken from the `values` tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 holding the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """The static shape of the represented dense tensor.

    Returns:
      A `TensorShape` object.
    """
    static_shape = tensor_util.constant_value_as_shape(self._dense_shape)
    return static_shape

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    return self._indices.graph

  def __str__(self):
    parts = (self._indices, self._values, self._dense_shape)
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % parts

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Running this method executes every preceding operation needed to produce
    the three component tensors.

    *N.B.* Before calling `SparseTensor.eval()`, the graph must have been
    launched in a session, and either a default session must be available or
    `session` must be passed explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    fetches = [self.indices, self.values, self.dense_shape]
    indices_val, values_val, dense_shape_val = _eval_using_default_session(
        fetches, feed_dict, self.graph, session)
    return SparseTensorValue(indices_val, values_val, dense_shape_val)

  @staticmethod
  def _override_operator(operator, func):
    """Installs `func` as the implementation of `operator` on this class."""
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    """A `SparseTensorSpec` built from this tensor's shape and dtype."""
    return SparseTensorSpec(self.shape, self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    """Turns a `tf.while_loop` shape invariant into a `SparseTensorSpec`.

    Args:
      shape: The shape invariant; it must describe a vector.

    Returns:
      A `SparseTensorSpec` with this tensor's dtype and an unknown shape of
      the rank given by the invariant's single dimension.

    Raises:
      ValueError: If `shape` has a known rank other than 1.
    """
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: The shape
    # invariant here is the shape of the SparseTensor.dense_shape property. It
    # must be the shape of a vector.
    if shape.ndims not in (None, 1):
      raise ValueError("Expected a shape with 1 dimension")
    dense_rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(dense_rank), self.dtype)

  def consumers(self):
    """Returns the result of the inherited `_consumers()` helper."""
    return self._consumers()
254
255
# Namedtuple with the same three fields as `SparseTensor`; this is what
# `SparseTensor.eval()` returns.
SparseTensorValue = collections.namedtuple(
    "SparseTensorValue", ["indices", "values", "dense_shape"])
# Exported under the v1 API name only.
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
260
261
@tf_export("SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.SparseTensor`.

  A spec records the dense shape (possibly partially or fully unknown) and
  the dtype of the values of the sparse tensors it describes.
  """

  __slots__ = ["_shape", "_dtype"]

  value_type = property(lambda self: SparseTensor)

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow
        any dense shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    """Returns the `(shape, dtype)` pair that defines this spec."""
    return (self._shape, self._dtype)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    """`TensorSpec`s for the `indices`, `values`, `dense_shape` components."""
    rank = self._shape.ndims
    # The number of stored values is never part of the spec.
    num_values = None
    return [
        tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
        tensor_spec.TensorSpec([num_values], self._dtype),
        tensor_spec.TensorSpec([rank], dtypes.int64)]

  def _to_components(self, value):
    """Splits `value` into `[indices, values, dense_shape]`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`; the latter is first
        converted to a `SparseTensor`.

    Returns:
      A list with the three component tensors.
    """
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    return [value.indices, value.values, value.dense_shape]

  def _from_components(self, tensor_list):
    """Rebuilds a sparse value from `[indices, values, dense_shape]`.

    Args:
      tensor_list: The three components, in constructor order.

    Returns:
      A `SparseTensorValue` when every component is a NumPy array and TF2
      behavior is not enabled; otherwise a `SparseTensor`.
    """
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      return SparseTensorValue(*tensor_list)
    else:
      return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    """A single variant `TensorSpec` of unknown shape for the boxed value."""
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a one-element list holding a variant tensor.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A list containing the output of `serialize_sparse` for `value`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    value = SparseTensor.from_value(value)
    return [gen_sparse_ops.serialize_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` so that it can later be unbatched along dimension 0.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A list containing the output of `serialize_many_sparse` for `value`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
      ValueError: If the spec's shape merged with the value's static dense
        shape has rank 0.
    """
    # Normalize first, as `_to_tensor_list` and `_to_components` do, so that
    # a `SparseTensorValue` (whose fields need not be tensors) is accepted by
    # the static-shape helper below.
    value = SparseTensor.from_value(value)
    dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._shape.merge_with(dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1")
    return [gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes a variant tensor produced by the `_to_*tensor_list` methods.

    Args:
      tensor_list: A one-element list holding the boxed variant tensor.

    Returns:
      A `SparseTensor` whose static shape information is refined with
      whatever this spec knows about the dense shape.
    """
    tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
    indices, values, dense_shape = tensor_list
    rank = self._shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTensorSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if self._shape.is_fully_defined():
      # Every dimension is known: use a constant instead of the op output.
      dense_shape = ops.convert_to_tensor(
          self._shape, dtype=dtypes.int64, name="shape")
    elif (self._shape.rank is not None and
          any(dim.value is not None for dim in self._shape.dims)):
      # Some dimensions are known: splice constants in for those and keep
      # the deserialized entries for the rest.
      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
      pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
      for i, dim in enumerate(self._shape.dims):
        if dim.value is not None:
          pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
      dense_shape = array_ops.stack(pieces)
    else:
      # Nothing but (at most) the rank is known.
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    """Returns a spec with a leading dimension of size `batch_size` added."""
    return SparseTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    """Returns a spec with the leading dimension removed.

    Raises:
      ValueError: If this spec's shape has rank 0.
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return SparseTensorSpec(self._shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    """Returns this spec's dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns this spec's shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Returns the `SparseTensor` class."""
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Builds the spec that describes `value`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A spec made from the shape and dtype of `value`.  A
      `SparseTensorValue` whose `values` field is a NumPy array is described
      directly from its `dense_shape` and the array dtype; any other
      `SparseTensorValue` is first converted to a `SparseTensor`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if isinstance(value, SparseTensorValue):
      if isinstance(value.values, np.ndarray):
        return cls(value.dense_shape, value.values.dtype)
      return cls.from_value(SparseTensor.from_value(value))
    raise TypeError("Expected SparseTensor or SparseTensorValue")
396
397
# TODO(b/133606651) Remove the SparseTensor registration once CompositeTensor
# defines a _type_spec field, since registration will then be automatic.
# Do *not* remove the SparseTensorValue registration.
_register_converter = type_spec.register_type_spec_from_value_converter
_register_converter(SparseTensor, SparseTensorSpec.from_value)
_register_converter(SparseTensorValue, SparseTensorSpec.from_value)
del _register_converter
405
406
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Sparse inputs are returned as a `SparseTensor` (a `SparseTensorValue` is
  converted first); everything else goes through the regular `Tensor`
  conversion.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  requested = None if dtype is None else dtypes.as_dtype(dtype)
  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  # Dense path: anything that is not sparse by now.
  if not isinstance(value, SparseTensor):
    return ops.convert_to_tensor(value, dtype=requested, name=name)

  # Sparse path: the value is returned untouched, so only check the dtype.
  if requested and not requested.is_compatible_with(value.dtype):
    raise RuntimeError("Sparse dtype: requested = %s, actual = %s" %
                       (requested.name, value.dtype.name))
  return value
434
435
def is_sparse(x):
  """Tells whether `x` is a sparse tensor object.

  An object counts as sparse when it is an instance of `tf.SparseTensor` or
  of `tf.compat.v1.SparseTensorValue`.

  Args:
    x: A python object to check.

  Returns:
    `True` iff `x` is a `tf.SparseTensor` or `tf.compat.v1.SparseTensorValue`.
  """
  return isinstance(x, SparseTensor) or isinstance(x, SparseTensorValue)
449