# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
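  # The gathered `params` has shape shape(ids) + shape(param)[1:], so the axes
  # beyond the rank of `ids` index the individual embeddings; clipping over
  # those axes limits each embedding's l2-norm independently.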
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _colocate_with(param):
  if ops.inside_function() and hasattr(param, "handle"):
    # The `ops.colocate_with` will hard-code a device string if `param.device`
    # is known, which will then break serving. We capture it here so that it
    # produces a tensor without a device.
    return ops.colocate_with(ops.get_default_graph().capture(param.handle))
  else:
    return ops.colocate_with(param)


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.
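
  For example, a `transform_fn` of `lambda x: tf.math.l2_normalize(x, axis=-1)`
  would l2-normalize each retrieved embedding; this is only an illustrative
  choice, not one made by this module.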

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.BaseResourceVariable)
        for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with _colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with _colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with _colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
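        # As a worked example of the arithmetic above: with 13 ids split
        # "div"-style across np = 5 partitions, ids_per_partition = 2 and
        # extras = 3, so partitions 0..2 hold 3 ids each and partitions 3..4
        # hold 2. For flat id 7:
        #   p_assignments = max(7 // 3, (7 - 3) // 2) = 2
        #   new_ids       = 7 % 3 = 1    (since partition 2 < extras)
        # i.e. id 7 is row 1 of partition 2, matching
        # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].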
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.device_v2(None):
          with _colocate_with(params[p]):
            result = array_ops.gather(params[p], pids)
            if transform_fn:
              # If transform_fn is provided, the clip_by_norm precedes
              # the transform and hence must be co-located. See below
              # for the counterpart if transform_fn is not provided.
              result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with _colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partition variables are not supported and
  the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

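  For example, a sketch with two shards under the default `"mod"` strategy
  (illustrative values; the TF2 `tf.compat.v1` namespace is assumed):

  ```python
  shard0 = tf.constant([[1.0, 1.0], [3.0, 3.0]])  # rows for ids 0 and 2
  shard1 = tf.constant([[2.0, 2.0], [4.0, 4.0]])  # rows for ids 1 and 3
  tf.compat.v1.nn.embedding_lookup([shard0, shard1], tf.constant([0, 3]))
  # -> [[1., 1.], [4., 4.]]
  ```
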
  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a `RaggedTensor` with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a `RaggedTensor`, depending on the input, with the same type
    as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition strategy.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

    For instance, if `params` is a 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or a list of matrices:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    The output will be a 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```
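
    Or, as a runnable sketch of the same lookup (eager mode assumed):

    ```python
    params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    tf.nn.embedding_lookup(params, tf.constant([0, 3, 4]))
    # -> [[1, 2], [7, 8], [9, 10]]
    ```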

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name, max_norm=max_norm)


@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

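    A construction sketch consistent with the example above (illustrative
    values; eager mode and the `tf.compat.v1` namespace are assumed):

      ```python
      params = tf.random.normal([10, 20])
      sp_ids = tf.sparse.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
          values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
          dense_shape=[3, 4])
      sp_weights = tf.sparse.SparseTensor(
          indices=sp_ids.indices,
          values=[2.0, 0.5, 1.0, 3.0],
          dense_shape=sp_ids.dense_shape)
      tf.compat.v1.nn.embedding_lookup_sparse(
          params, sp_ids, sp_weights, combiner="mean")  # shape [3, 20]
      ```
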
  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]

    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      embeddings = array_ops.gather(embeddings, idx)

      original_dtype = embeddings.dtype
      if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
        # Cast low-precision embeddings to float32 during the computation to
        # avoid numerical issues.
        embeddings = math_ops.cast(embeddings, dtypes.float32)
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
      ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name)
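        # e.g. a row with weights [2.0, 0.5] is divided by
        # sqrt(2.0**2 + 0.5**2) = sqrt(4.25), matching the "sqrtn" definition
        # in the docstring above.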
      else:
        assert False, "Unrecognized combiner"
      if embeddings.dtype != original_dtype:
        embeddings = math_ops.cast(embeddings, original_dtype)
    else:
      if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings


@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name,
                                 combiner, max_norm)


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of the number of shards.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
  between the elements of `embedding_weights` according to the "div" partition
  strategy, which means we assign ids to partitions in a contiguous manner. For
  instance, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one
  more id.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors following "div"
      partition strategy.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.
    name: A name for this operation (optional).

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    `default_id` is 0.

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

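    A construction sketch for the empty-row case (illustrative values; eager
    mode assumed; with `default_id=None` the empty row maps to the 0-vector):

      ```python
      weights = tf.random.normal([10, 20])
      sparse_ids = tf.sparse.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
          values=tf.constant([1, 3, -1, 1], dtype=tf.int64),
          dense_shape=[3, 4])
      out = tf.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
      # out has shape [3, 20]; row 1 held only the invalid id -1, so
      # out[1, :] is all zeros.
      ```
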
  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of the number of shards.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy. Currently
      `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    `default_id` is 0.

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result


def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

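  For example, this path is reached via the public `tf.nn.embedding_lookup`
  when `ids` is a `RaggedTensor` (a sketch with illustrative values, eager
  mode assumed):

  ```python
  weights = tf.random.normal([10, 4])
  ragged_ids = tf.ragged.constant([[3, 1], [5]], dtype=tf.int64)
  tf.nn.embedding_lookup(weights, ragged_ids)  # RaggedTensor, shape [2, None, 4]
  ```
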
  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM]
    ragged_ids: A `RaggedTensor` with type `int32` or `int64` containing the ids
      to be looked up in `embedding_weights` of shape [r0, ..rN]. Values must be
      in the range `[0, embedding_weights.shape[0]]`.
    partition_strategy: A string specifying the partitioning strategy.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional)

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].

  Raises:
    ValueError: If `embedding_weights` is empty or `ragged_ids` is not a
      `RaggedTensor`.
  """
  if embedding_weights is None:
    raise ValueError("The embedding weights must be specified.")
  if isinstance(embedding_weights, (list, tuple)) and not embedding_weights:
    raise ValueError("The embedding weights should not be empty.")
  if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64:
    raise ValueError("The values contained by the inputs have type " +
                     str(ragged_ids.dtype) +
                     " and cannot be processed. All values"
                     " should be indices, either of type `int32` or `int64`.")

  with ops.name_scope(name, "embedding_lookup_ragged") as name:
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_lookup,
        params=embedding_weights,
        ids=ragged_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm)

    return looked_up_ragged


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune non-positive weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights