# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
    "embedding_lookup_sparse_with_distributed_aggregation"
]


def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner=None,
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding tensors in `embedding_weights` must all be the same
  shape except for the first dimension. The first dimension is allowed to vary
  as the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A list of `P` float tensors or values representing
      partitioned embedding tensors. Alternatively, a `PartitionedVariable`,
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not None, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)
  embedding_weights = [
      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
  ]

  contrib_tensor_util.assert_same_float_dtype(embedding_weights +
                                              [sparse_weights])

  with ops.name_scope(name, "embedding_lookup",
                      embedding_weights + [sparse_ids,
                                           sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0]))
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim.value is None
        else original_rank_dim.value)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(
          sparse_ids.indices,
          sparse_weights.values, sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
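    # Rows left empty after pruning would otherwise be dropped by
    # embedding_lookup_sparse; when no default_id is given, the filled rows
    # are replaced with zero vectors after the lookup below.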
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
                                                                 default_id or
                                                                 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(is_row_empty,
                               array_ops.zeros_like(result),
                               result,
                               name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(tensor_shape.unknown_shape(
        (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
    return final_result


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
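

# Illustrative usage sketch for safe_embedding_lookup_sparse (added for
# clarity, not part of the original module; assumes `import tensorflow as tf`
# and hypothetical vocab_size / embed_dim / num_shards values):
#
#   embedding_weights = tf.get_variable(
#       "embeddings", shape=[vocab_size, embed_dim],
#       partitioner=tf.fixed_size_partitioner(num_shards))
#   sparse_ids = tf.SparseTensor(
#       indices=[[0, 0], [0, 1], [1, 0]], values=[3, -1, 7],
#       dense_shape=[2, 2])  # the -1 id is pruned; row 0 keeps only id 3
#   embedded = safe_embedding_lookup_sparse(
#       embedding_weights, sparse_ids, combiner="mean")  # shape [2, embed_dim]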


def scattered_embedding_lookup(params,
                               values,
                               dimension,
                               name=None,
                               hash_key=None):
  """Looks up embeddings using parameter hashing for each value in `values`.

  The i-th embedding component of a value v in `values` is found by retrieving
  the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  Feature hashing has the pleasant effect of allowing us to compute an embedding
  without needing a pre-determined vocabulary, relieving some amount of process
  complexity. It also allows us to maintain embeddings for possibly trillions of
  features with a fixed amount of memory.

  Note that this is superior to out-of-vocabulary shared "hash buckets" in that
  the embedding is extremely likely to be unique for each token as opposed to
  being shared across probably-colliding tokens. The price is that we must
  compute a hash once for each scalar in the token's embedding as opposed to
  once per token.

  If `params` is a list, it represents a partition of the embedding parameters.
  Each tensor in the list should have the same length, except for the first ones
  which may have an additional element. For instance 10 parameters can be
  partitioned in 4 tensors with length `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints on SparseFeatureCrossOp
      (optional).

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.

  Raises:
    ValueError: if dimension is not positive or the partition size is invalid.
  """
  if dimension is None:
    raise ValueError("You must specify dimension.")
  return _sampled_scattered_embedding_lookup(
      params, values, dimension=dimension, sampled_candidates=None,
      hash_key=hash_key, name=name)
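

# Illustrative usage sketch for scattered_embedding_lookup (added for clarity,
# not part of the original module; assumes `import tensorflow as tf` and
# hypothetical names):
#
#   hash_params = tf.get_variable(
#       "hash_embedding", shape=[100000],
#       initializer=tf.truncated_normal_initializer())
#   tokens = tf.constant([["the", "quick"], ["brown", "fox"]])
#   embedded = scattered_embedding_lookup(hash_params, tokens, dimension=16)
#   # embedded has shape [2, 2, 16]; no vocabulary table is required.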


def _sampled_scattered_embedding_lookup(
    params, values, dimension=None, sampled_candidates=None, hash_key=None,
    name=None):
  """Looks up embeddings using parameter hashing for each value in `values`.

  This method looks up selected embedding dimensions if `sampled_candidates` is
  given, otherwise looks up all dimensions.

  The i-th embedding component of a value v in `values` is found by retrieving
  the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  Feature hashing has the pleasant effect of allowing us to compute an embedding
  without needing a pre-determined vocabulary, relieving some amount of process
  complexity. It also allows us to maintain embeddings for possibly trillions of
  features with a fixed amount of memory.

  Note that this is superior to out-of-vocabulary shared "hash buckets" in that
  the embedding is extremely likely to be unique for each token as opposed to
  being shared across probably-colliding tokens. The price is that we must
  compute a hash once for each scalar in the token's embedding as opposed to
  once per token.

  If `params` is a list, it represents a partition of the embedding parameters.
  Each tensor in the list should have the same length, except for the first ones
  which may have an additional element. For instance 10 parameters can be
  partitioned in 4 tensors with length `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension. The user must specify either `dimension` or
      `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of slice indices to keep along the
      final dimension with shape `[d0, ..., dn, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints on SparseFeatureCrossOp
      (optional).
    name: An optional name for this op.

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, ..., dn, N]`.

  Raises:
    ValueError: if dimension is not positive or the partition size is invalid.
  """
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "scattered_embedding_lookup",
                      params + [dimension, values]):
    # Flatten the values.
    values_shape = array_ops.shape(values)
    values = array_ops.reshape(values, [-1, 1])

    if sampled_candidates is None:
      if dimension is None:
        raise ValueError(
            "You must specify either dimension or sampled_candidates.")
      if dimension <= 0:
        raise ValueError("Dimension must be >0. Given is %d" % dimension)
      sampled_candidates = array_ops.tile(array_ops.expand_dims(
          math_ops.range(0, dimension), 0), array_ops.shape(values))
    else:
      dimension = array_ops.shape(sampled_candidates)[
          math_ops.subtract(array_ops.rank(sampled_candidates), 1)]
      sampled_candidates_shape = array_ops.shape(sampled_candidates)
      dimension_tensor = array_ops.reshape(dimension, shape=[1,])
      expected_shape = array_ops.concat([values_shape, dimension_tensor], 0)
      with ops.control_dependencies([control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(sampled_candidates_shape,
                                             expected_shape)),
          ["The shape of sampled_candidates: ", sampled_candidates_shape,
           " does not match the shape of values: ", values_shape])]):
        # Flatten sampled_candidates, same way as values are flattened.
        sampled_candidates = array_ops.reshape(sampled_candidates,
                                               [-1, dimension])

    num_partitions = len(params)
    partition_sizes = []
    for p in range(num_partitions):
      shape = params[p].get_shape()
      shape.assert_has_rank(1)
      shape.assert_is_fully_defined()
      partition_sizes.append(tensor_shape.dimension_value(shape[0]))
    num_params = sum(partition_sizes)  # Total number of parameters.

    # Assert the size of each partition.
    for p in range(num_partitions):
      expected_size = (num_params - p - 1) // num_partitions + 1
      if partition_sizes[p] != expected_size:
        raise ValueError("Tensor %d in params has size %d, expected %d." %
                         (p, partition_sizes[p], expected_size))

    # With two values v1 and v2 and 3 dimensions, we will cross
    # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
    tensors_to_cross = [sampled_candidates, values]
    ids = sparse_feature_cross_op.sparse_feature_cross(
        tensors_to_cross, hashed_output=True, num_buckets=num_params,
        hash_key=hash_key)
    ids = sparse_ops.sparse_tensor_to_dense(ids)

    # No need to validate the indices since we have checked the params
    # dimensions and we know the largest id.
    result = embedding_ops.embedding_lookup(
        params, ids, partition_strategy="div")

    return array_ops.reshape(result,
                             array_ops.concat([values_shape, [dimension]], 0))
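

# Note on _sampled_scattered_embedding_lookup (added for clarity): the crossed
# fingerprint of each (candidate index, value) pair is hashed into
# `num_buckets=num_params` buckets, so every scalar weight is read from the
# flat parameter vector(s) and no vocabulary table is ever materialized.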


def scattered_embedding_lookup_sparse(params,
                                      sparse_values,
                                      dimension,
                                      combiner=None,
                                      default_value=None,
                                      name=None,
                                      hash_key=None):
  """Looks up embeddings of a sparse feature using parameter hashing.

  See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
      Some rows may be empty.
    dimension: Embedding dimension.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_value: The value to use for an entry with no features.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints on SparseFeatureCrossOp
      (optional).

  Returns:
    Dense tensor with shape `[N, dimension]`, where `N` is the number of rows
    in `sparse_values`.

  Raises:
    TypeError: If sparse_values is not a SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sparse_values, sparse_tensor.SparseTensor):
    raise TypeError("sparse_values must be SparseTensor")

  with ops.name_scope(name, "scattered_embedding_lookup_sparse",
                      params + [sparse_values]) as scope:
    # Fill in the empty rows.
    if default_value is None:
      # Random default values to reduce the risk of collision.
      if sparse_values.dtype == dtypes.string:
        default_value = "6ZxWzWOHxZ"
      else:
        default_value = 1288896567
    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_values, default_value)

    segment_ids = sparse_values.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    values = sparse_values.values
    values, idx = array_ops.unique(values)

    embeddings = scattered_embedding_lookup(
        params, values, dimension, hash_key=hash_key)

    if combiner == "sum":
      embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
                                               name=scope)
    elif combiner == "mean":
      embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
                                                name=scope)
    elif combiner == "sqrtn":
      embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
                                                  name=scope)
    else:
      raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

    return embeddings
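

# Illustrative usage sketch for scattered_embedding_lookup_sparse (added for
# clarity, not part of the original module; assumes `import tensorflow as tf`
# and hypothetical names):
#
#   hash_params = tf.get_variable("hash_embedding", shape=[100000])
#   sparse_tokens = tf.string_split(tf.constant(["deep wide", "sparse"]))
#   embedded = scattered_embedding_lookup_sparse(
#       hash_params, sparse_tokens, dimension=16, combiner="mean")
#   # embedded has shape [2, 16]: one combined embedding per input row.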


def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None):
  """Version of embedding_lookup that avoids duplicate lookups.

  This can save communication in the case of repeated ids.
  Same interface as embedding_lookup, except that it supports multi-dimensional
  `ids`, which avoids reshaping the input/output to fit `gather`.

  Args:
    params: A list of tensors with the same shape and type, or a
      `PartitionedVariable`. Shape `[index, d1, d2, ...]`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`. Shape `[ids1, ids2, ...]`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params` and dimension of
    `[ids1, ids2, d1, d2, ...]`.

  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    ids = ops.convert_to_tensor(ids)
    shape = array_ops.shape(ids)
    ids_flat = array_ops.reshape(
        ids, math_ops.reduce_prod(shape, keepdims=True))
    unique_ids, idx = array_ops.unique(ids_flat)
    unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids,
                                                       partition_strategy)
    embeds_flat = array_ops.gather(unique_embeddings, idx)
    embed_shape = array_ops.concat(
        [shape, array_ops.shape(unique_embeddings)[1:]], 0)
    embeds = array_ops.reshape(embeds_flat, embed_shape)
    embeds.set_shape(ids.get_shape().concatenate(
        unique_embeddings.get_shape()[1:]))
    return embeds
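

# Illustrative usage sketch for embedding_lookup_unique (added for clarity, not
# part of the original module; assumes `import tensorflow as tf` and a
# hypothetical vocab_size):
#
#   params = tf.get_variable("embeddings", shape=[vocab_size, 8])
#   ids = tf.constant([[2, 5, 2], [5, 5, 0]])  # repeated ids looked up once
#   embedded = embedding_lookup_unique(params, ids)  # shape [2, 3, 8]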


def _sampled_scattered_embedding_lookup_sparse(params,
                                               sp_values,
                                               dimension=None,
                                               sampled_candidates=None,
                                               hash_key=None,
                                               with_sign_hash=False,
                                               name=None):
  """Looks up embeddings using parameter hashing for sparse values.

  This method looks up selected embedding dimensions if `sampled_candidates` is
  given, otherwise looks up all dimensions.

  The i-th embedding component of a value v in `values` is found by retrieving
  the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  This is logically equivalent to:
  * Transforming `sp_values` (which has shape `[d0, d1]`) into a one-hot
    `Tensor` of shape `[d0, N]`.
  * Multiplying with a `Tensor` `h` of shape `[N, dimension]`, where
    `h(i, j) = params[hash(i, j)]`.

  Args:
    params: A float `Tensor` with rank 1 and fully-defined shape.
    sp_values: A 2D `SparseTensor` to be embedded with shape `[d0, d1]`.
    dimension: An int `Tensor` of the final dimension. The user needs to
      provide either `dimension` or `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of column indices to keep along
      the final dimension with shape `[d0, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints on SparseFeatureCrossOp
      (optional).
    with_sign_hash: A `bool` indicating whether `h(i, j)` should be multiplied
      by `+1` or `-1`, where the value selected is determined by hashing
      `(i, j)`. This is often necessary to remove bias resulting from hash
      collisions.
    name: An optional name for this op.

  Returns:
    A `Tensor` of shape `[d0, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, N]`.

  Raises:
    TypeError: If sp_values is not `SparseTensor`.
    ValueError: If both `dimension` and `sampled_candidates` are `None`.
  """
  if not isinstance(sp_values, sparse_tensor.SparseTensor):
    raise TypeError("sp_values must be SparseTensor")

  with ops.name_scope(
      name=name,
      default_name="sampled_scattered_embedding_lookup_sparse",
      values=[sp_values, params, dimension, sampled_candidates]) as name_scope:
    segment_ids = sp_values.indices[:, 0]
    if sampled_candidates is not None:
      # Tile sampled_candidates so there is one line corresponding to each
      # element in sp_values.values.
      sampled_candidates = array_ops.gather(sampled_candidates, segment_ids)

    embeddings = _sampled_scattered_embedding_lookup(
        params, sp_values.values, dimension=dimension,
        sampled_candidates=sampled_candidates,
        hash_key=hash_key, name="values_lookup")
    if with_sign_hash:
      signs = _sampled_scattered_embedding_lookup(
          array_ops.constant([-1., 1.]), sp_values.values, dimension=dimension,
          sampled_candidates=sampled_candidates, hash_key=hash_key,
          name="signs_lookup")
      embeddings = math_ops.multiply(signs, embeddings, name="signs_hash")

    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)
    num_segments = array_ops.shape(sp_values)[0]

    return math_ops.unsorted_segment_sum(embeddings, segment_ids,
                                         num_segments=num_segments,
                                         name=name_scope)
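

# Note on _sampled_scattered_embedding_lookup_sparse (added for clarity): the
# unsorted_segment_sum above makes row r of the result the sum of the hashed
# embeddings of all values whose SparseTensor index places them in row r,
# sign-flipped per (value, dimension) pair when with_sign_hash=True.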


def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings belonging to the same param are aggregated on that device first.
  This op is intended to decrease data transmission and improve parallelism.
  See `tf.nn.embedding_lookup_sparse` for the functionality and example of
  this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor representing the combined embeddings for the sparse ids.
    For each row in the dense tensor represented by sp_ids, the op looks up
    the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

    weights = None if ignore_weights else sp_weights.values
    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)
    # Set weights to all ones if ignoring weights.
    if ignore_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)
    # Reshape weights.
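    # Append singleton dimensions so that the per-id weights broadcast against
    # the trailing embedding dimensions in the combiner divisions below.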
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
    orig_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, bcast_weights_shape)
    if embeddings.get_shape().ndims is not None:
      weights.set_shape(
          orig_weights_shape.concatenate(
              [1 for _ in range(embeddings.get_shape().ndims - 1)]))

    if combiner == "mean":
      weight_sum = math_ops.segment_sum(weights, segment_ids)
      embeddings = math_ops.div(embeddings, weight_sum)
    elif combiner == "sqrtn":
      weights_squared = math_ops.pow(weights, 2)
      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
    elif combiner != "sum":
      assert False, "Unrecognized combiner"
    return embeddings
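

# Illustrative usage sketch for
# embedding_lookup_sparse_with_distributed_aggregation (added for clarity, not
# part of the original module; assumes `import tensorflow as tf`, hypothetical
# names, and sp_ids / sp_weights built as for tf.nn.embedding_lookup_sparse):
#
#   params = tf.get_variable(
#       "embeddings", shape=[vocab_size, embed_dim],
#       partitioner=tf.fixed_size_partitioner(num_ps_tasks))
#   combined = embedding_lookup_sparse_with_distributed_aggregation(
#       params, sp_ids, sp_weights, combiner="mean")
#   # Intended to match tf.nn.embedding_lookup_sparse, but partial sums are
#   # computed on the devices holding each shard before being combined.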


def _do_gather(params, ids, name=None):
  """Deals with doing gather differently for resource variables."""
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  return array_ops.gather(params, ids, name=name)


def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape to allow broadcast.
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set weights shape after reshape.
          if ret.get_shape().ndims is not None:
            weights.set_shape(
                orig_weights_shape.concatenate(
                    [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape().dims[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape().dims[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape().dims[0].value is not None:
              dim_0_sizes.append(params[p].get_shape().dims[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), (
            flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor.
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (is_in_first_extras_partitions * (flat_ids %
                                                    (ids_per_partition + 1)) +
                   (1 - is_in_first_extras_partitions) * (
                       (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists.
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
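      # pindices[p] records which positions of the original flat id list were
      # routed to partition p; it is used below to gather the matching weights
      # and segment ids back into per-partition order.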
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p].
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                                 0))
      else:
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([
                    array_ops.shape(pindices[p]), array_ops.slice(
                        params_shape, [1], [-1])
                ], 0))
      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result with partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
                                                    partitioned_result[p].dtype)
            # Reshape partition weights.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
                                                      bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(
                  orig_weights_shape.concatenate([
                      1
                      for _ in range(partitioned_result[p].get_shape().ndims -
                                     1)
                  ]))
            partitioned_result[p] *= partitioned_weight[p]
      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains the unique segment ids of this
          # partition, and their order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # When ignoring weights, we need the indices of the elements in idx
          # and segment_ids that belong to this partition.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
          # Gather segment_ids and idx according to these indices.
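          # include_idx holds the positions (within the flattened values) whose
          # unique id was routed to partition p.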
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as in the weighted case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result,
          concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1,
          name=name)