1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Indexed slices.""" 16 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import warnings 24 25import numpy as np 26 27from tensorflow.python import tf2 28from tensorflow.python.eager import context 29from tensorflow.python.framework import composite_tensor 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import tensor_conversion_registry 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.framework import type_spec 34from tensorflow.python.types import internal 35from tensorflow.python.util.compat import collections_abc 36from tensorflow.python.util.lazy_loader import LazyLoader 37from tensorflow.python.util.tf_export import tf_export 38 39 40# Use LazyLoader to avoid circular dependencies. 41# 42# Note: these can all be changed to regular imports once all code has been 43# updated to refer the symbols defined in this module directly, rather than 44# using the backwards-compatible aliases in ops.py. (E.g., 45# "indexed_slices.IndexedSlices" rather than "ops.IndexedSlices".) 46math_ops = LazyLoader( 47 "math_ops", globals(), 48 "tensorflow.python.ops.math_ops") 49ops = LazyLoader( 50 "ops", globals(), "tensorflow.python.framework.ops") 51tensor_spec = LazyLoader( 52 "tensor_spec", globals(), 53 "tensorflow.python.framework.tensor_spec") 54tensor_util = LazyLoader( 55 "tensor_util", globals(), 56 "tensorflow.python.framework.tensor_util") 57 58 59# TODO(mdan): Should IndexedSlices be a "tensor"? 60@tf_export("IndexedSlices") 61class IndexedSlices(internal.NativeObject, composite_tensor.CompositeTensor): 62 """A sparse representation of a set of tensor slices at given indices. 63 64 This class is a simple wrapper for a pair of `Tensor` objects: 65 66 * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`. 67 * `indices`: A 1-D integer `Tensor` with shape `[D0]`. 68 69 An `IndexedSlices` is typically used to represent a subset of a larger 70 tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`. 71 The values in `indices` are the indices in the first dimension of 72 the slices that have been extracted from the larger tensor. 73 74 The dense tensor `dense` represented by an `IndexedSlices` `slices` has 75 76 ```python 77 dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...] 78 ``` 79 80 The `IndexedSlices` class is used principally in the definition of 81 gradients for operations that have sparse gradients 82 (e.g. `tf.gather`). 83 84 >>> v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]]) 85 >>> with tf.GradientTape() as tape: 86 ... r = tf.gather(v, [1,3]) 87 >>> index_slices = tape.gradient(r,v) 88 >>> index_slices 89 <...IndexedSlices object ...> 90 >>> index_slices.indices.numpy() 91 array([1, 3], dtype=int32) 92 >>> index_slices.values.numpy() 93 array([[1., 1., 1.], 94 [1., 1., 1.]], dtype=float32) 95 96 Contrast this representation with 97 `tf.sparse.SparseTensor`, 98 which uses multi-dimensional indices and scalar values. 99 """ 100 101 def __init__(self, values, indices, dense_shape=None): 102 """Creates an `IndexedSlices`.""" 103 self._values = values 104 self._indices = indices 105 self._dense_shape = dense_shape 106 107 @property 108 def values(self): 109 """A `Tensor` containing the values of the slices.""" 110 return self._values 111 112 @property 113 def indices(self): 114 """A 1-D `Tensor` containing the indices of the slices.""" 115 return self._indices 116 117 @property 118 def dense_shape(self): 119 """A 1-D `Tensor` containing the shape of the corresponding dense tensor.""" 120 return self._dense_shape 121 122 @property 123 def shape(self): 124 """Gets the `tf.TensorShape` representing the shape of the dense tensor. 125 126 Returns: 127 A `tf.TensorShape` object. 128 """ 129 if self._dense_shape is None: 130 return tensor_shape.TensorShape(None) 131 132 return tensor_util.constant_value_as_shape(self._dense_shape) 133 134 @property 135 def name(self): 136 """The name of this `IndexedSlices`.""" 137 return self.values.name 138 139 @property 140 def device(self): 141 """The name of the device on which `values` will be produced, or `None`.""" 142 return self.values.device 143 144 @property 145 def op(self): 146 """The `Operation` that produces `values` as an output.""" 147 return self.values.op 148 149 @property 150 def dtype(self): 151 """The `DType` of elements in this tensor.""" 152 return self.values.dtype 153 154 @property 155 def graph(self): 156 """The `Graph` that contains the values, indices, and shape tensors.""" 157 return self._values.graph 158 159 def __str__(self): 160 return "IndexedSlices(indices=%s, values=%s%s)" % ( 161 self._indices, self._values, 162 (", dense_shape=%s" % 163 (self._dense_shape,)) if self._dense_shape is not None else "") 164 165 def __neg__(self): 166 return IndexedSlices(-self.values, self.indices, self.dense_shape) 167 168 @property 169 def _type_spec(self): 170 indices_shape = self._indices.shape.merge_with(self._values.shape[:1]) 171 dense_shape = tensor_shape.TensorShape([None]).concatenate( 172 self._values.shape[1:]) 173 if self._dense_shape is not None: 174 dense_shape_dtype = self._dense_shape.dtype 175 dense_shape = dense_shape.merge_with( 176 tensor_util.constant_value_as_shape(self._dense_shape)) 177 else: 178 dense_shape_dtype = None 179 return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype, 180 dense_shape_dtype, indices_shape) 181 182 def _shape_invariant_to_type_spec(self, shape): 183 # From tf.while_loop docs: "If a loop variable is an IndexedSlices, the 184 # shape invariant must be a shape invariant of the values tensor of the 185 # IndexedSlices. It means the shapes of the three tensors of the 186 # IndexedSlices are (shape, [shape[0]], [shape.ndims])." 187 indices_shape = shape[:1] 188 dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:]) 189 if self._dense_shape is None: 190 dense_shape_dtype = None 191 else: 192 dense_shape_dtype = self._dense_shape.dtype 193 return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype, 194 dense_shape_dtype, indices_shape) 195 196 def consumers(self): 197 return self._consumers() 198 199 200IndexedSlicesValue = collections.namedtuple( 201 "IndexedSlicesValue", ["values", "indices", "dense_shape"]) 202 203 204@tf_export("IndexedSlicesSpec") 205class IndexedSlicesSpec(type_spec.TypeSpec): 206 """Type specification for a `tf.IndexedSlices`.""" 207 208 __slots__ = ["_shape", "_values_dtype", "_indices_dtype", 209 "_dense_shape_dtype", "_indices_shape"] 210 211 value_type = property(lambda self: IndexedSlices) 212 213 def __init__(self, shape=None, dtype=dtypes.float32, 214 indices_dtype=dtypes.int64, dense_shape_dtype=None, 215 indices_shape=None): 216 """Constructs a type specification for a `tf.IndexedSlices`. 217 218 Args: 219 shape: The dense shape of the `IndexedSlices`, or `None` to allow any 220 dense shape. 221 dtype: `tf.DType` of values in the `IndexedSlices`. 222 indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`. One 223 of `tf.int32` or `tf.int64`. 224 dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`. 225 One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has 226 no `dense_shape` tensor). 227 indices_shape: The shape of the `indices` component, which indicates 228 how many slices are in the `IndexedSlices`. 229 """ 230 self._shape = tensor_shape.as_shape(shape) 231 self._values_dtype = dtypes.as_dtype(dtype) 232 self._indices_dtype = dtypes.as_dtype(indices_dtype) 233 if dense_shape_dtype is None: 234 self._dense_shape_dtype = None 235 else: 236 self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype) 237 self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1) 238 239 def _serialize(self): 240 return (self._shape, self._values_dtype, self._indices_dtype, 241 self._dense_shape_dtype, self._indices_shape) 242 243 @property 244 def _component_specs(self): 245 value_shape = self._indices_shape.concatenate(self._shape[1:]) 246 specs = [ 247 tensor_spec.TensorSpec(value_shape, self._values_dtype), 248 tensor_spec.TensorSpec(self._indices_shape, self._indices_dtype)] 249 if self._dense_shape_dtype is not None: 250 specs.append( 251 tensor_spec.TensorSpec([self._shape.ndims], self._dense_shape_dtype)) 252 return tuple(specs) 253 254 def _to_components(self, value): 255 if value.dense_shape is None: 256 return (value.values, value.indices) 257 else: 258 return (value.values, value.indices, value.dense_shape) 259 260 def _from_components(self, tensor_list): 261 if (all(isinstance(t, np.ndarray) for t in tensor_list) and 262 not tf2.enabled()): 263 if len(tensor_list) == 2: 264 return IndexedSlicesValue(tensor_list[0], tensor_list[1], None) 265 else: 266 return IndexedSlicesValue(*tensor_list) 267 else: 268 return IndexedSlices(*tensor_list) 269 270 271@tf_export(v1=["convert_to_tensor_or_indexed_slices"]) 272def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): 273 """Converts the given object to a `Tensor` or an `IndexedSlices`. 274 275 If `value` is an `IndexedSlices` or `SparseTensor` it is returned 276 unmodified. Otherwise, it is converted to a `Tensor` using 277 `convert_to_tensor()`. 278 279 Args: 280 value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed 281 by `convert_to_tensor()`. 282 dtype: (Optional.) The required `DType` of the returned `Tensor` or 283 `IndexedSlices`. 284 name: (Optional.) A name to use if a new `Tensor` is created. 285 286 Returns: 287 A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`. 288 289 Raises: 290 ValueError: If `dtype` does not match the element type of `value`. 291 """ 292 return internal_convert_to_tensor_or_indexed_slices( 293 value=value, dtype=dtype, name=name, as_ref=False) 294 295 296def internal_convert_to_tensor_or_indexed_slices(value, 297 dtype=None, 298 name=None, 299 as_ref=False): 300 """Converts the given object to a `Tensor` or an `IndexedSlices`. 301 302 If `value` is an `IndexedSlices` or `SparseTensor` it is returned 303 unmodified. Otherwise, it is converted to a `Tensor` using 304 `convert_to_tensor()`. 305 306 Args: 307 value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed 308 by `convert_to_tensor()`. 309 dtype: (Optional.) The required `DType` of the returned `Tensor` or 310 `IndexedSlices`. 311 name: (Optional.) A name to use if a new `Tensor` is created. 312 as_ref: True if the caller wants the results as ref tensors. 313 314 Returns: 315 A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`. 316 317 Raises: 318 ValueError: If `dtype` does not match the element type of `value`. 319 """ 320 if isinstance(value, ops.EagerTensor) and not context.executing_eagerly(): 321 return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref) 322 # TODO(mdan): Name says tensor_or_indexed_slices. So do explicitly just that? 323 elif isinstance(value, internal.NativeObject): 324 if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): 325 raise ValueError( 326 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % 327 (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) 328 return value 329 else: 330 return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref) 331 332 333def internal_convert_n_to_tensor_or_indexed_slices(values, 334 dtype=None, 335 name=None, 336 as_ref=False): 337 """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. 338 339 Any `IndexedSlices` or `SparseTensor` objects in `values` are returned 340 unmodified. 341 342 Args: 343 values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects 344 that can be consumed by `convert_to_tensor()`. 345 dtype: (Optional.) The required `DType` of the returned `Tensor` or 346 `IndexedSlices`. 347 name: (Optional.) A name prefix to used when a new `Tensor` is created, in 348 which case element `i` will be given the name `name + '_' + i`. 349 as_ref: True if the caller wants the results as ref tensors. 350 351 Returns: 352 A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects. 353 354 Raises: 355 TypeError: If no conversion function is registered for an element in 356 `values`. 357 RuntimeError: If a registered conversion function returns an invalid 358 value. 359 """ 360 if not isinstance(values, collections_abc.Iterable): 361 raise TypeError("values must be iterable.") 362 ret = [] 363 for i, value in enumerate(values): 364 if value is None: 365 ret.append(value) 366 else: 367 n = None if name is None else "%s_%d" % (name, i) 368 ret.append( 369 internal_convert_to_tensor_or_indexed_slices( 370 value, dtype=dtype, name=n, as_ref=as_ref)) 371 return ret 372 373 374def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): 375 """Converts `values` to a list of `Output` or `IndexedSlices` objects. 376 377 Any `IndexedSlices` or `SparseTensor` objects in `values` are returned 378 unmodified. 379 380 Args: 381 values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that 382 can be consumed by `convert_to_tensor()`. 383 dtype: (Optional.) The required `DType` of the returned `Tensor` 384 `IndexedSlices`. 385 name: (Optional.) A name prefix to used when a new `Tensor` is created, in 386 which case element `i` will be given the name `name + '_' + i`. 387 388 Returns: 389 A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects. 390 391 Raises: 392 TypeError: If no conversion function is registered for an element in 393 `values`. 394 RuntimeError: If a registered conversion function returns an invalid 395 value. 396 """ 397 return internal_convert_n_to_tensor_or_indexed_slices( 398 values=values, dtype=dtype, name=name, as_ref=False) 399 400 401# Warn the user if we convert a sparse representation to dense with at 402# least this number of elements. 403_LARGE_SPARSE_NUM_ELEMENTS = 100000000 404 405 406def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False): 407 """Converts an IndexedSlices object `value` to a Tensor. 408 409 NOTE(mrry): This function is potentially expensive. 410 411 Args: 412 value: An ops.IndexedSlices object. 413 dtype: The dtype of the Tensor to be returned. 414 name: Optional name to use for the returned Tensor. 415 as_ref: True if a ref is requested. 416 417 Returns: 418 A dense Tensor representing the values in the given IndexedSlices. 419 420 Raises: 421 ValueError: If the IndexedSlices does not have the same dtype. 422 """ 423 _ = as_ref 424 if dtype and not dtype.is_compatible_with(value.dtype): 425 raise ValueError( 426 "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" % 427 (dtype.name, value.dtype.name)) 428 if value.dense_shape is None: 429 raise ValueError( 430 "Tensor conversion requested for IndexedSlices without dense_shape: %s" 431 % str(value)) 432 # TODO(mrry): Consider adding static shape information to 433 # IndexedSlices, to avoid using numpy here. 434 if not context.executing_eagerly(): 435 dense_shape_value = tensor_util.constant_value(value.dense_shape) 436 if dense_shape_value is not None: 437 num_elements = np.prod(dense_shape_value) 438 if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS: 439 warnings.warn( 440 "Converting sparse IndexedSlices to a dense Tensor with %d " 441 "elements. This may consume a large amount of memory." % 442 num_elements) 443 else: 444 if value.dense_shape.op.type != "VariableShape": 445 # VariableShape may hide static shapes behind a resource handle 446 # producing a warning that isn't that useful to users. 447 warnings.warn( 448 "Converting sparse IndexedSlices(%s) to a dense Tensor of unknown " 449 "shape. This may consume a large amount of memory." % value) 450 return math_ops.unsorted_segment_sum( 451 value.values, value.indices, value.dense_shape[0], name=name) 452 453 454tensor_conversion_registry.register_tensor_conversion_function( 455 IndexedSlices, _indexed_slices_to_tensor) 456