1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Indexed slices.""" 16 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import warnings 24 25import numpy as np 26 27from tensorflow.python import tf2 28from tensorflow.python.eager import context 29from tensorflow.python.framework import composite_tensor 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import tensor_conversion_registry 32from tensorflow.python.framework import tensor_like 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import type_spec 35from tensorflow.python.util.lazy_loader import LazyLoader 36from tensorflow.python.util.tf_export import tf_export 37 38 39# Use LazyLoader to avoid circular dependencies. 40# 41# Note: these can all be changed to regular imports once all code has been 42# updated to refer the symbols defined in this module directly, rather than 43# using the backwards-compatible aliases in ops.py. (E.g., 44# "indexed_slices.IndexedSlices" rather than "ops.IndexedSlices".) 
math_ops = LazyLoader(
    "math_ops", globals(),
    "tensorflow.python.ops.math_ops")
ops = LazyLoader(
    "ops", globals(), "tensorflow.python.framework.ops")
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
tensor_util = LazyLoader(
    "tensor_util", globals(),
    "tensorflow.python.framework.tensor_util")

# pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
# pylint: enable=protected-access

# The abstract container classes (`Sequence`, ...) live in `collections.abc`
# since Python 3.3; the deprecated aliases on `collections` itself were
# removed in Python 3.10. Fall back to `collections` only on interpreters
# that have no `collections.abc` module (Python 2).
try:
  from collections import abc as _collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:
  _collections_abc = collections


@tf_export("IndexedSlices")
class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. `tf.gather`).

  Contrast this representation with
  `tf.SparseTensor`,
  which uses multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`.

    The arguments are stored as-is; no conversion or validation is done here.

    Args:
      values: A `Tensor` holding the slice values.
      indices: A 1-D integer `Tensor` holding the row index of each slice.
      dense_shape: (Optional.) A 1-D `Tensor` with the shape of the
        corresponding dense tensor, or `None` if it is not known.
    """
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Gets the `tf.TensorShape` representing the shape of the dense tensor.

    Returns:
      A `tf.TensorShape` object. It is fully unknown when this `IndexedSlices`
      has no `dense_shape`.
    """
    if self._dense_shape is None:
      return tensor_shape.TensorShape(None)

    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values,
        (", dense_shape=%s" %
         self._dense_shape) if self._dense_shape is not None else "")

  def __neg__(self):
    return IndexedSlices(-self.values, self.indices, self.dense_shape)

  @property
  def _type_spec(self):
    # The number of slices is shared by `indices` and the leading dimension of
    # `values`; the dense shape has an unknown leading dimension and the
    # trailing dimensions of `values`, refined by `dense_shape` when present.
    indices_shape = self._indices.shape.merge_with(self._values.shape[:1])
    dense_shape = tensor_shape.TensorShape([None]).concatenate(
        self._values.shape[1:])
    if self._dense_shape is not None:
      dense_shape_dtype = self._dense_shape.dtype
      dense_shape = dense_shape.merge_with(
          tensor_util.constant_value_as_shape(self._dense_shape))
    else:
      dense_shape_dtype = None
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def _shape_invariant_to_type_spec(self, shape):
    # From tf.while_loop docs: "If a loop variable is an IndexedSlices, the
    # shape invariant must be a shape invariant of the values tensor of the
    # IndexedSlices. It means the shapes of the three tensors of the
    # IndexedSlices are (shape, [shape[0]], [shape.ndims])."
    indices_shape = shape[:1]
    dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:])
    if self._dense_shape is None:
      dense_shape_dtype = None
    else:
      dense_shape_dtype = self._dense_shape.dtype
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def consumers(self):
    """Returns the result of `self._consumers()`."""
    # NOTE(review): `_consumers` is not defined in this class; presumably it
    # is inherited from `composite_tensor.CompositeTensor` — confirm.
    return self._consumers()


# Plain (non-Tensor) counterpart of `IndexedSlices`. It is returned by
# `IndexedSlicesSpec._from_components` when every component is a numpy array
# and TF2 behavior is disabled; `dense_shape` is `None` when absent.
IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", ["values", "indices", "dense_shape"])


@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`."""

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  value_type = property(lambda self: IndexedSlices)

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`.  One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
        One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
        no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    if dense_shape_dtype is None:
      self._dense_shape_dtype = None
    else:
      self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype)
    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)

  def _serialize(self):
    return (self._shape, self._values_dtype, self._indices_dtype,
            self._dense_shape_dtype, self._indices_shape)

  @property
  def _component_specs(self):
    # `values` has one row per index, followed by the trailing dense dims.
    value_shape = self._indices_shape.concatenate(self._shape[1:])
    specs = [
        tensor_spec.TensorSpec(value_shape, self._values_dtype),
        tensor_spec.TensorSpec(self._indices_shape, self._indices_dtype)]
    if self._dense_shape_dtype is not None:
      specs.append(
          tensor_spec.TensorSpec([self._shape.ndims], self._dense_shape_dtype))
    return tuple(specs)

  def _to_components(self, value):
    if value.dense_shape is None:
      return (value.values, value.indices)
    else:
      return (value.values, value.indices, value.dense_shape)

  def _from_components(self, tensor_list):
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      if len(tensor_list) == 2:
        return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
      else:
        return IndexedSlicesValue(*tensor_list)
    else:
      return IndexedSlices(*tensor_list)


@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_indexed_slices(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # An EagerTensor used while building a graph must go through
  # `convert_to_tensor` rather than being returned as-is.
  if isinstance(value, ops.EagerTensor) and not context.executing_eagerly():
    return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
  elif isinstance(value, _TensorLike):
    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)


def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified. `None` elements are passed through unchanged.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  # `collections.Sequence` no longer exists on Python 3.10+; use the ABC
  # module alias defined at module level instead.
  if not isinstance(values, _collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_indexed_slices(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified. `None` elements are passed through unchanged.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_indexed_slices(
      values=values, dtype=dtype, name=name, as_ref=False)


# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000


def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned. Anything accepted by
      `dtypes.as_dtype` is allowed.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested. Ignored.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is incompatible with the dtype of the
      IndexedSlices, or if the IndexedSlices has no `dense_shape`.
  """
  _ = as_ref
  # Normalize like `internal_convert_to_tensor_or_indexed_slices` does, so a
  # non-`DType` dtype spec does not fail on `.is_compatible_with` / `.name`.
  if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
        (dtypes.as_dtype(dtype).name, value.dtype.name))
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(value))
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      # Accumulate in int64: with the platform default integer (int32 on some
      # platforms) large shapes can overflow and silently skip the warning.
      num_elements = np.prod(dense_shape_value, dtype=np.int64)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
    else:
      warnings.warn(
          "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
          "This may consume a large amount of memory.")
  # Rows of `values` are scattered into a tensor with `dense_shape[0]` rows.
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)


tensor_conversion_registry.register_tensor_conversion_function(
    IndexedSlices, _indexed_slices_to_tensor)