# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _colocate_with(param):
  if ops.inside_function() and hasattr(param, "handle"):
    # The `ops.colocate_with` will hard-code a device string if `param.device`
    # is known, which will then break serving. We capture it here so that it
    # produces a tensor without a device.
    return ops.colocate_with(ops.get_default_graph().capture(param.handle))
  else:
    return ops.colocate_with(param)


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.BaseResourceVariable)
        for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with _colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with _colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))
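
      # Worked example for the "div" branch below (illustrative numbers only):
      # with num_total_ids = 13 and np = 5, ids_per_partition = 2 and
      # extras = 3, so the partitions hold ids [0-2], [3-5], [6-8], [9-10],
      # [11-12]. For instance, id 9 is assigned to partition
      # max(9 // 3, (9 - 3) // 2) = 3 at offset (9 - 3) % 2 = 0.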

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with _colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.device_v2(None):
          with _colocate_with(params[p]):
            result = array_ops.gather(params[p], pids)
            if transform_fn:
              # If transform_fn is provided, the clip_by_norm precedes
              # the transform and hence must be co-located. See below
              # for the counterpart if transform_fn is not provided.
              result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with _colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partition variables are not supported
  and the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a `RaggedTensor` with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a `RaggedTensor`, depending on the input, with the same type
    as the tensors in `params`.
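
    For example, with a list of three shards under the default `"mod"`
    strategy (an illustrative sketch; the values are made up):

    ```python
    params = [
        [[1, 2], [7, 8]],   # shard 0 holds ids 0 and 3
        [[3, 4], [9, 10]],  # shard 1 holds ids 1 and 4
        [[5, 6]],           # shard 2 holds id 2
    ]
    tf.compat.v1.nn.embedding_lookup(params, [0, 3, 4])
    # -> [[1, 2], [7, 8], [9, 10]]
    ```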

  Raises:
    ValueError: If `params` is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means
  we assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

    For instance, if `params` is a 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or a list of matrices:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    The output will be a 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name, max_norm=max_norm)


@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```
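
    A minimal runnable sketch of the same pattern (illustrative, made-up
    values; passing `None` for `sp_weights` treats every weight as 1):

    ```python
    params = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    sp_ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=tf.constant([0, 2, 1], dtype=tf.int64),
        dense_shape=[2, 2])
    tf.compat.v1.nn.embedding_lookup_sparse(params, sp_ids, None,
                                            combiner="sum")
    # Row 0 sums params[0] and params[2] -> [6.0, 8.0]; row 1 is params[1].
    ```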

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]

    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      embeddings = array_ops.gather(embeddings, idx)

      original_dtype = embeddings.dtype
      if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
        # Cast low-precision embeddings to float32 during the computation to
        # avoid numerical issues.
        embeddings = math_ops.cast(embeddings, dtypes.float32)
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
      ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights
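
      # For example (illustrative): a segment whose weights are [2.0, 0.5]
      # is combined as 2.0*e1 + 0.5*e2 for "sum", divided by 2.5 for "mean",
      # and divided by sqrt(2.0**2 + 0.5**2) for "sqrtn".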

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt,
                                         name=name)
      else:
        assert False, "Unrecognized combiner"
      if embeddings.dtype != original_dtype:
        embeddings = math_ops.cast(embeddings, original_dtype)
    else:
      if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings


@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means
  we assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.
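
  A minimal usage sketch (illustrative, made-up values; weighted "mean"
  combining):

  ```python
  params = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  sp_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=tf.constant([1, 2, 0], dtype=tf.int64),
      dense_shape=[2, 2])
  sp_weights = tf.sparse.SparseTensor(
      indices=sp_ids.indices,
      values=[2.0, 0.5, 1.0],
      dense_shape=sp_ids.dense_shape)
  tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, combiner="mean")
  # Row 0 -> (2.0 * params[1] + 0.5 * params[2]) / 2.5; row 1 -> params[0].
  ```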

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name,
                                 combiner, max_norm)


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as
  the vocabulary size is not necessarily a multiple of num of shards.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding
  vector for `default_id` is returned, or the 0-vector if `default_id` is not
  supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
  between the elements of `embedding_weights` according to the "div" partition
  strategy, which means we assign ids to partitions in a contiguous manner.
  For instance, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned
  one more id.
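
  A minimal sketch of the pruning behaviour (illustrative, made-up values;
  row 1's invalid id is dropped, and the empty row 2 falls back to the
  0-vector because `default_id` is not set):

  ```python
  weights = tf.constant([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
  sparse_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [1, 0], [1, 1]],
      values=tf.constant([1, -1, 2], dtype=tf.int64),
      dense_shape=[3, 2])
  tf.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
  # -> [[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
  ```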

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors following "div"
      partition strategy.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.
    name: A name for this operation (optional).

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0,

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as
  the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding
  vector for `default_id` is returned, or the 0-vector if `default_id` is not
  supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0,

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
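    # For example (illustrative): ids of dense shape [d0, d1, d2] are
    # reshaped to [d0 * d1, d2], so each leading position becomes one segment.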
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result


def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM].
    ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the
      ids to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values
      must be in the range '[0, embedding_weights.shape[0])'.
    partition_strategy: A string specifying the partitioning strategy.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional).

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].
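
    For example (an illustrative sketch with made-up values; `RaggedTensor`
    ids passed to `tf.nn.embedding_lookup` are routed through this helper):

    ```python
    weights = tf.constant([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    ids = tf.ragged.constant([[1, 2], [0]])
    tf.nn.embedding_lookup(weights, ids)
    # -> [[[1.0, 1.0], [2.0, 2.0]], [[0.0, 0.0]]] as a RaggedTensor
    ```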

  Raises:
    ValueError: If `embedding_weights` is empty or `ragged_ids` is not a
      `RaggedTensor`.
  """
  if embedding_weights is None:
    raise ValueError("The embedding weights must be specified.")
  if isinstance(embedding_weights, (list, tuple)) and not embedding_weights:
    raise ValueError("The embedding weights should not be empty.")
  if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64:
    raise ValueError("The values contained by the inputs have type " +
                     str(ragged_ids.dtype) +
                     " and cannot be processed. All values"
                     " should be indices, either of type `int32` or `int64`.")

  with ops.name_scope(name, "embedding_lookup_ragged") as name:
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_lookup,
        params=embedding_weights,
        ids=ragged_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm)

  return looked_up_ragged


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights