• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Indexed slices."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import warnings
24
25import numpy as np
26
27from tensorflow.python import tf2
28from tensorflow.python.eager import context
29from tensorflow.python.framework import composite_tensor
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import tensor_conversion_registry
32from tensorflow.python.framework import tensor_like
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import type_spec
35from tensorflow.python.util.lazy_loader import LazyLoader
36from tensorflow.python.util.tf_export import tf_export
37
38
# The modules below are loaded lazily because importing them eagerly here would
# create circular dependencies.
#
# NOTE: these can become regular imports once all code refers to the symbols
# defined in this module directly (e.g. "indexed_slices.IndexedSlices") rather
# than through the backwards-compatible aliases in ops.py (e.g.
# "ops.IndexedSlices").
math_ops = LazyLoader("math_ops", globals(),
                      "tensorflow.python.ops.math_ops")
ops = LazyLoader("ops", globals(),
                 "tensorflow.python.framework.ops")
tensor_spec = LazyLoader("tensor_spec", globals(),
                         "tensorflow.python.framework.tensor_spec")
tensor_util = LazyLoader("tensor_util", globals(),
                         "tensorflow.python.framework.tensor_util")

# Module-level alias for the private `_TensorLike` base class; it is used as a
# base of `IndexedSlices` and in `isinstance` checks in this module.
_TensorLike = tensor_like._TensorLike  # pylint: disable=protected-access
60
61
@tf_export("IndexedSlices")
class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a thin wrapper around two `Tensor` objects (plus an optional
  `dense_shape` tensor):

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` typically represents a subset of a larger tensor `dense`
  of shape `[LARGE0, D1, ..., Dn]` where `LARGE0 >> D0`. The entries of
  `indices` are positions along the first dimension of that larger tensor from
  which the slices were taken.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  `IndexedSlices` is used principally in the definition of gradients for
  operations that have sparse gradients (e.g. `tf.gather`).

  Contrast this representation with `tf.SparseTensor`, which uses
  multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices` from its component tensors.

    Args:
      values: The slice values.
      indices: The first-dimension indices of the slices in `values`.
      dense_shape: (Optional.) The shape of the corresponding dense tensor.
    """
    self._values, self._indices, self._dense_shape = (
        values, indices, dense_shape)

  @property
  def values(self):
    """The `Tensor` holding the slice values."""
    return self._values

  @property
  def indices(self):
    """The 1-D `Tensor` holding the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """The 1-D `Tensor` holding the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Gets the `tf.TensorShape` representing the shape of the dense tensor.

    The shape is fully unknown when this object carries no `dense_shape`.

    Returns:
      A `tf.TensorShape` object.
    """
    if self._dense_shape is not None:
      return tensor_util.constant_value_as_shape(self._dense_shape)
    return tensor_shape.TensorShape(None)

  @property
  def name(self):
    """The name of this `IndexedSlices` (taken from `values`)."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of the elements of `values`."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    if self._dense_shape is None:
      dense_shape_suffix = ""
    else:
      dense_shape_suffix = ", dense_shape=%s" % self._dense_shape
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values, dense_shape_suffix)

  def __neg__(self):
    return IndexedSlices(values=-self.values,
                         indices=self.indices,
                         dense_shape=self.dense_shape)

  @property
  def _type_spec(self):
    values_shape = self._values.shape
    # One index per slice, so the leading dimensions of `indices` and `values`
    # must agree.
    indices_shape = self._indices.shape.merge_with(values_shape[:1])
    dense_shape = tensor_shape.TensorShape([None]).concatenate(
        values_shape[1:])
    dense_shape_dtype = None
    if self._dense_shape is not None:
      dense_shape_dtype = self._dense_shape.dtype
      # Merge in whatever `constant_value_as_shape` can infer from the
      # `dense_shape` tensor.
      dense_shape = dense_shape.merge_with(
          tensor_util.constant_value_as_shape(self._dense_shape))
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def _shape_invariant_to_type_spec(self, shape):
    # Per the tf.while_loop docs, the shape invariant of an IndexedSlices loop
    # variable describes its `values` tensor, so the three component tensors
    # have shapes (shape, [shape[0]], [shape.ndims]).
    indices_shape = shape[:1]
    dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:])
    dense_shape_dtype = (
        None if self._dense_shape is None else self._dense_shape.dtype)
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def consumers(self):
    """Returns the result of the inherited `_consumers()` helper."""
    return self._consumers()
189
# Plain-Python counterpart of `IndexedSlices`. `IndexedSlicesSpec` builds one
# from NumPy components when TF2 behavior is disabled; `dense_shape` is None
# when there is no dense-shape component.
IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", "values indices dense_shape")
192
193
@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`."""

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  value_type = property(lambda self: IndexedSlices)

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`.  One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
        One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
        no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`. Forced to rank 1.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    self._dense_shape_dtype = (
        None if dense_shape_dtype is None
        else dtypes.as_dtype(dense_shape_dtype))
    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)

  def _serialize(self):
    return (self._shape, self._values_dtype, self._indices_dtype,
            self._dense_shape_dtype, self._indices_shape)

  @property
  def _component_specs(self):
    # `values` has one row per index, followed by the non-leading dense dims.
    values_spec = tensor_spec.TensorSpec(
        self._indices_shape.concatenate(self._shape[1:]), self._values_dtype)
    indices_spec = tensor_spec.TensorSpec(self._indices_shape,
                                          self._indices_dtype)
    if self._dense_shape_dtype is None:
      return (values_spec, indices_spec)
    dense_shape_spec = tensor_spec.TensorSpec([self._shape.ndims],
                                              self._dense_shape_dtype)
    return (values_spec, indices_spec, dense_shape_spec)

  def _to_components(self, value):
    has_dense_shape = value.dense_shape is not None
    components = (value.values, value.indices)
    if has_dense_shape:
      return components + (value.dense_shape,)
    return components

  def _from_components(self, tensor_list):
    all_numpy = all(isinstance(t, np.ndarray) for t in tensor_list)
    if not all_numpy or tf2.enabled():
      return IndexedSlices(*tensor_list)
    # NumPy components with TF2 behavior disabled: build a plain value.
    if len(tensor_list) == 2:
      values, indices = tensor_list
      return IndexedSlicesValue(values, indices, None)
    return IndexedSlicesValue(*tensor_list)
260
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  Tensor-like inputs such as `IndexedSlices` or `SparseTensor` are returned
  unmodified; anything else is converted to a `Tensor` with
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # Public entry point: same as the internal helper, never requesting refs.
  return internal_convert_to_tensor_or_indexed_slices(
      value, dtype=dtype, name=name, as_ref=False)
284
285
def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  Tensor-like inputs (e.g. `IndexedSlices`, `SparseTensor`) are returned
  unmodified after a dtype compatibility check. Everything else goes through
  `convert_to_tensor()`. An `EagerTensor` seen while not executing eagerly
  also goes through `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  eager_tensor_in_graph = (isinstance(value, ops.EagerTensor) and
                           not context.executing_eagerly())
  if eager_tensor_in_graph or not isinstance(value, _TensorLike):
    return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
  if dtype:
    requested_dtype = dtypes.as_dtype(dtype)
    if not requested_dtype.is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (requested_dtype.name, value.dtype.name, str(value)))
  return value
320
321
def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified, and `None` entries are passed through as `None`.

  Args:
    values: A sequence of `None`, `IndexedSlices`, `SparseTensor`, or objects
      that can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  # `collections.Sequence` was a deprecated alias that Python 3.10 removed;
  # the ABC lives in `collections.abc` (Python 3.3+). Fall back to the old
  # location so Python 2 keeps working.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_indexed_slices(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret
361
362
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified, and `None` entries are passed through as `None`.

  Args:
    values: A sequence of `None`, `IndexedSlices`, `SparseTensor`, or objects
      that can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_indexed_slices(
      values=values, dtype=dtype, name=name, as_ref=False)
388
389
390# Warn the user if we convert a sparse representation to dense with at
391# least this number of elements.
392_LARGE_SPARSE_NUM_ELEMENTS = 100000000
393
394
def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a dense Tensor.

  The dense result is computed with `math_ops.unsorted_segment_sum` over
  `value.values` and `value.indices`, using `value.dense_shape[0]` as the
  number of segments. When building a graph, a warning is emitted if the dense
  shape is not statically known or has at least `_LARGE_SPARSE_NUM_ELEMENTS`
  elements.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An `IndexedSlices` object. It must have a `dense_shape`.
    dtype: (Optional.) The required `DType`; must be compatible with
      `value.dtype`.
    name: (Optional.) A name to use for the returned Tensor.
    as_ref: Unused; present because this function is registered as a tensor
      conversion function.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is incompatible with `value.dtype`, or if `value`
      has no `dense_shape`.
  """
  _ = as_ref
  if dtype and not dtype.is_compatible_with(value.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
        (dtype.name, value.dtype.name))
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(value))
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      # Multiply in Python integers. `np.prod` accumulates in a fixed-width
      # integer type (e.g. int32 on platforms whose default int is 32-bit),
      # which can silently wrap around for large shapes and suppress the
      # warning below.
      num_elements = 1
      for dim in np.ravel(dense_shape_value):
        num_elements *= int(dim)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
    else:
      warnings.warn(
          "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
          "This may consume a large amount of memory.")
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)
438
439
# Register `_indexed_slices_to_tensor` as the tensor conversion function for
# `IndexedSlices` (presumably consulted by `convert_to_tensor`; see
# tensor_conversion_registry).
tensor_conversion_registry.register_tensor_conversion_function(
    IndexedSlices,
    _indexed_slices_to_tensor)
442