# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument.
  If provided, the function is applied to each partitioned tensor of
  retrieved embeddings, colocated with the embeddings. This function will be
  called with a single `Tensor` argument of the same type as the `params`
  tensor and should return a `Tensor`. The shape of the argument will be the
  same as `params` except for the size of the first dimension. The first
  dimension of the result's shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with ops.colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
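        # For example, with num_total_ids = 13 and np = 5 partitions, the "div"
        # strategy yields ids_per_partition = 2 and extras = 3, so the
        # partitions hold [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12].
        # For flat id 7, p_assignments = max(7 // 3, (7 - 3) // 2) = 2 and,
        # since 2 < extras, new_ids = 7 % 3 = 1, i.e. offset 1 in partition 2.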
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = array_ops.gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
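      # For example, if ids has static shape [batch, k] and every params[p]
      # has shape [vocab_p, d], this pins the static result shape to
      # [batch, k, d] even when the reshape above used dynamic shapes.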
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partition variables are not supported
  and the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a 'RaggedTensor' with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in `params`.
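
    For instance (mirroring the example in `embedding_lookup_v2`), if
    `params` is a single 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    the output is the 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```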

  Raises:
    ValueError: If `params` is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means
  we assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

    For instance, if `params` is a 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or a list of matrices:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    The output will be a 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name, max_norm=max_norm)


@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None`
      to indicate all weights should be taken to be 1. If specified,
      `sp_weights` must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

    `shape(combined params) = [p0, p1, ..., pm]`

    and

    `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

    `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]

    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
      embeddings = math_ops.cast(embeddings, dtypes.float32)
    if not ignore_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      embeddings = array_ops.gather(embeddings, idx)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
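      # For example, for rank-2 embeddings of shape [n, embed_dim], weights of
      # static shape [n] were reshaped above to a dynamic [n, 1] so they
      # broadcast against the embeddings; the set_shape below restores the
      # static [n, 1] shape when the embedding rank is known.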
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.divide(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.divide(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
    else:
      if compat.forward_compatible(2020, 5, 14):
        if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
          segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      else:
        if segment_ids.dtype != dtypes.int32:
          segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings


@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means
  we assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None`
      to indicate all weights should be taken to be 1. If specified,
      `sp_weights` must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

    `shape(combined params) = [p0, p1, ..., pm]`

    and

    `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

    `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name,
                                 combiner, max_norm)


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as
  the vocabulary size is not necessarily a multiple of num of shards.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding
  vector for `default_id` is returned, or the 0-vector if `default_id` is not
  supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
  between the elements of `embedding_weights` according to the "div" partition
  strategy, which means we assign ids to partitions in a contiguous manner.
  For instance, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned
  one more id.
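
  A minimal usage sketch (eager mode; the embedding values and ids below are
  purely illustrative): row 1 has no features and row 2 only has an invalid
  id, so both fall back to the 0-vector because `default_id` is `None`.

  ```python
  embedding_weights = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
  sparse_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [2, 0]], values=[0, 2, -1],
      dense_shape=[3, 2])
  # Returns [[2., 2.], [0., 0.], [0., 0.]] with the default "mean" combiner.
  tf.nn.safe_embedding_lookup_sparse(embedding_weights, sparse_ids)
  ```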

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors following "div"
      partition strategy.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.
    name: A name for this operation (optional).

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

    `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

    `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

    `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0, then with `combiner`="mean" the output will be a
    3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as
  the vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
  may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding
  vector for `default_id` is returned, or the 0-vector if `default_id` is not
  supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

    `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

    `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

    `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0, then with `combiner`="mean" the output will be a
    3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
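    # For example, sparse ids with dense_shape [d_0, d_1, d_2] become a rank-2
    # SparseTensor with dense_shape [d_0 * d_1, d_2], so each row handled by
    # embedding_lookup_sparse below corresponds to one (d_0, d_1) position.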
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result


def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM]
    ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the
      ids to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values
      must be in the range `[0, embedding_weights.shape[0])`.
    partition_strategy: A string specifying the partitioning strategy.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional)

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].

  Raises:
    ValueError: If `embedding_weights` is empty or `ragged_ids` is not a
      `RaggedTensor`.
  """
  if embedding_weights is None:
    raise ValueError("The embedding weights must be specified.")
  if isinstance(embedding_weights, (list, tuple)) and not embedding_weights:
    raise ValueError("The embedding weights should not be empty.")
  if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64:
    raise ValueError("The values contained by the inputs have type " +
                     str(ragged_ids.dtype) +
                     " and cannot be processed. All values"
                     " should be indices, either of type `int32` or `int64`.")

  with ops.name_scope(name, "embedding_lookup_ragged") as name:
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_lookup,
        params=embedding_weights,
        ids=ragged_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm)

    return looked_up_ragged


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights