# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank))
            if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(array_ops.gather(params[0], ids, name=name),
                       ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with ops.colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(tensor_shape.dimension_value(
            params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(tensor_shape.dimension_value(
              params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

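        # Under the "div" strategy the first `extras` partitions hold
        # `ids_per_partition + 1` ids each and the remaining partitions hold
        # `ids_per_partition` ids each. Taking the maximum of the two integer
        # divisions below yields the correct partition for ids in either
        # region of the id space.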
        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = array_ops.gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`. It is a generalization of
  `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
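
  For example, a small illustrative sketch (the embedding values below are
  made up):

  ```python
  params = tf.constant([[0.0, 0.0],   # embedding for id 0
                        [1.0, 1.0],   # embedding for id 1
                        [2.0, 2.0]])  # embedding for id 2
  ids = tf.constant([2, 0, 2])
  result = tf.nn.embedding_lookup(params, ids)
  # `result` has shape [3, 2] and rows params[2], params[0], params[2].
  ```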

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup", v1=[])
def embedding_lookup_v2(
    params,
    ids,
    max_norm=None,
    name=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`. It is a generalization of
  `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  The `partition_strategy` is always `"div"` currently. This means that we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the 'div' `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name,
                          max_norm=max_norm)


@tf_export(v1=["nn.embedding_lookup_sparse"])
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```
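
    The same example, written as a small end-to-end sketch (the `params`
    values and the `tf.SparseTensor` construction below are illustrative):

    ```python
    params = tf.ones([10, 20])
    sp_ids = tf.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
        values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
        dense_shape=[3, 4])
    sp_weights = tf.SparseTensor(
        indices=sp_ids.indices,
        values=[2.0, 0.5, 1.0, 3.0],
        dense_shape=sp_ids.dense_shape)
    combined = tf.nn.embedding_lookup_sparse(
        params, sp_ids, sp_weights, combiner="mean")
    # `combined` has shape [3, 20].
    ```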

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

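    # Deduplicate the ids so that each embedding row is gathered only once;
    # `idx` maps every original entry of `sp_ids.values` back to its position
    # in the deduplicated `ids`.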
    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
      embeddings = math_ops.cast(embeddings, dtypes.float32)
    if not ignore_weights:
      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      embeddings = array_ops.gather(embeddings, idx)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
    else:
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings


@tf_export("nn.embedding_lookup_sparse", v1=[])
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the `"div"` `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(
      params, sp_ids, sp_weights, "div", name, combiner, max_norm)


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding tensors in `embedding_weights` must all be the same
  shape except for the first dimension. The first dimension is allowed to vary
  as the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Note: when doing embedding lookup on `embedding_weights`, "div" partition
  strategy will be used. Support for other partition strategies will be added
  later.
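
  For example, a small illustrative sketch (the ids, vocabulary size and
  embedding values below are made up):

  ```python
  embedding_weights = tf.ones([4, 3])  # vocabulary of 4 ids, dimension 3
  sparse_ids = tf.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=tf.constant([1, -1, 3], dtype=tf.int64),
      dense_shape=[3, 2])
  result = tf.nn.safe_embedding_lookup_sparse(embedding_weights, sparse_ids)
  # The invalid id -1 is pruned before the lookup, and because row 2 has no
  # features and `default_id` is None, result[2] is the 0-vector.
  ```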

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
      created by partitioning along dimension 0. The total unpartitioned shape
      should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size
      and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.
    name: A name for this operation (optional).

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner='mean',
                                 default_id=None,
                                 name=None,
                                 partition_strategy='div',
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding tensors in `embedding_weights` must all be the same
  shape except for the first dimension. The first dimension is allowed to vary
  as the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, 'embedding_lookup',
                      embedding_weights + [sparse_ids,
                                           sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None
        else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(
          sparse_ids.indices,
          sparse_weights.values, sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != 'sum':
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
                                                                 default_id or
                                                                 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(is_row_empty,
                               array_ops.zeros_like(result),
                               result,
                               name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
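    # Give shape inference a hint: the restored leading dims number
    # `original_rank - 1` (unknown sizes, and unknown rank if the original
    # rank is not static), followed by the embedding dims taken from
    # `result`'s static shape.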
    final_result.set_shape(tensor_shape.unknown_shape(
        (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
            result.get_shape()[1:]))
    return final_result


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights