# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import save_context
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.types import core
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_HOOK_KEY = "TPUEmbedding_saveable"
_NAME_KEY = "_tpu_embedding_layer"


# TODO(bfontain): Cleanup and remove this once there is an implementation of
# sharded variables that can be used in the PSStrategy with optimizers.
# We implement just enough of a tf.Variable so that this could be passed to
# an optimizer.
class TPUShardedVariable(sharded_variable.ShardedVariableMixin):
  """A ShardedVariable class for TPU."""

  @property
  def _in_graph_mode(self):
    return self.variables[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self.variables[0]._unique_id  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name


def _add_key_attr(op, name):
  op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name)))  # pylint: disable=protected-access


@tf_export("tpu.experimental.embedding.TPUEmbedding")
class TPUEmbedding(tracking.AutoTrackable):
  """The TPUEmbedding mid level API.

  NOTE: When instantiated under a TPUStrategy, this class can only be created
  once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
  re-initialize the embedding engine you must re-initialize the TPU as well.
  Doing this will clear any variables from the TPU, so ensure you have
  checkpointed before you do this. If further instances of the class are
  needed, set the `initialize_tpu_embedding` argument to `False`.

  This class can be used to support training large embeddings on TPU. When
  creating an instance of this class, you must specify the complete set of
  tables and features you expect to lookup in those tables. See the
  documentation of `tf.tpu.experimental.embedding.TableConfig` and
  `tf.tpu.experimental.embedding.FeatureConfig` for more details on the
  complete set of options. We will cover the basic usage here.

  NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
  allowing different features to share the same table:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  ```

  There are two modes in which the `TPUEmbedding` class can be used, depending
  on whether the class was created under a `TPUStrategy` scope or not.

  Under `TPUStrategy`, we allow access to the methods `enqueue`, `dequeue` and
  `apply_gradients`. We will show examples below of how to use these to train
  and evaluate your model. Under CPU, you only have access to the
  `embedding_tables` property, which allows access to the embedding tables so
  that you can use them to run model evaluation/prediction on CPU.

  First let's look at the `TPUStrategy` mode. Initial setup looks like:

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the enqueue
  operation a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  NOTE: All batches passed to the layer must have the same batch size for each
  input; moreover, once you have called the layer with one batch size, all
  subsequent calls must use the same batch size. In the event that the batch
  size cannot be automatically determined by the enqueue method, you must call
  the build method with the batch size to initialize the layer.
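
  For example, a minimal sketch of building with an explicit batch size
  (`global_batch_size` here is an illustrative name, not part of the API):

  ```python
  # Derive the per replica batch size from your chosen global batch size.
  per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
  embedding.build(per_replica_batch_size)
  ```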

  To use this API on TPU you should use a custom training loop. Below is an
  example of a training and evaluation step:

  ```python
  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      activations = embedding.dequeue()
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=False)
      strategy.run(tpu_step, args=(tpu_features, ))
  ```

  NOTE: The calls to `enqueue` have `training` set to `True` when
  `embedding.apply_gradients` is used and set to `False` when
  `embedding.apply_gradients` is not present in the function. If you don't
  follow this pattern you may cause an error to be raised or the tpu may
  deadlock.

  In the above examples, we assume that the user has a dataset which returns
  a tuple where the first element of the tuple matches the structure of what
  was passed as the `feature_config` argument to the object initializer. Also
  we utilize `tf.range` to get a `tf.while_loop` in order to increase
  performance.

  When checkpointing your model, you should include your
  `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is
  a trackable object and saving it will save the embedding tables and their
  optimizer slot variables:

  ```python
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.save(...)
  ```

  On CPU, only the `embedding_tables` property is usable. This will allow you
  to restore a checkpoint to the object and have access to the table variables:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  tables = embedding.embedding_tables
  ```

  You can now use these tables in functions like `tf.nn.embedding_lookup` to
  perform your embedding lookup and pass to your model.
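
  For example, a sketch of a CPU lookup (assuming `ids` is an int32 tensor of
  ids for the table configured as `table_config_one` above; the names are
  illustrative):

  ```python
  # embedding_tables is keyed by TableConfig, so index with the config object.
  activations = tf.nn.embedding_lookup(tables[table_config_one], ids)
  model_output = model(activations)
  ```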

  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=tf.tpu.experimental.embedding.FeatureConfig(
              table=tf.tpu.experimental.embedding.TableConfig(
                  dim=...,
                  vocabulary_size=...)))
    ```

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. When not created under
        TPUStrategy may be set to None to avoid the creation of the optimizer
        slot variables, useful for optimizing memory consumption when exporting
        the model for serving where slot variables aren't needed.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old). Set to True for improved performance.

    Raises:
      ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD,
        Adam or Adagrad) or None when created under a TPUStrategy.
    """
    self._strategy = distribution_strategy_context.get_strategy()
    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
                                                  tpu_strategy.TPUStrategyV2))
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._feature_config = feature_config

    # The TPU embedding ops are slightly inconsistent with how they refer to
    # tables:
    # * The enqueue op takes a parallel list of tensors for input, one of those
    #   is the table id for the feature which matches the integer index of the
    #   table in the proto created by _create_config_proto().
    # * The recv_tpu_embedding_activations op emits lookups per table in the
    #   order from the config proto.
    # * The send_tpu_embedding_gradients expects input tensors to be per table
    #   in the same order as the config proto.
    # * Per optimizer load and retrieve ops are specified per table and take
    #   the table name rather than the table id.
    # Thus we must fix a common order to tables and ensure they have unique
    # names.

    # Set table order here to the order of the first occurrence of the table
    # in a feature provided by the user. The order of this struct must be
    # fixed to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if ((table.optimizer is not None or self._using_tpu) and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    if self._using_tpu:
      # Extract a list of callable learning rates also in fixed order. Each
      # table in the config proto will get an index into this list, and we
      # will pass this list in the same order after evaluation to the
      # send_tpu_embedding_gradients op.
      self._dynamic_learning_rates = list({
          table.optimizer.learning_rate for table in self._table_config if
          callable(table.optimizer.learning_rate)})

      # We need the list of host devices for the load/retrieve operations.
      self._hosts = get_list_of_hosts(self._strategy)

    self._built = False
    self._verify_batch_size_on_enqueue = True

  def build(self, per_replica_batch_size: Optional[int] = None):
    """Creates the underlying variables and initializes the TPU for embeddings.

    This method creates the underlying variables (including slot variables). If
    created under a TPUStrategy, this will also initialize the TPU for
    embeddings.

    This function will automatically get called by enqueue, which will try to
    determine your batch size automatically. If this fails, you must manually
    call this method before you call enqueue.

    Args:
      per_replica_batch_size: The per replica batch size that you intend to
        use. Note that this is fixed and the same batch size must be used for
        both training and evaluation. If you want to calculate this from the
        global batch size, you can use the `num_replicas_in_sync` property of
        your strategy object. May be set to None if not created under a
        TPUStrategy.

    Raises:
      ValueError: If per_replica_batch_size is None and object was created in a
        TPUStrategy scope.
    """
    if self._built:
      return

    if self._using_tpu:
      if per_replica_batch_size is None:
        raise ValueError(
            "When calling TPUEmbedding.build under TPUStrategy you must "
            "specify a per_replica_batch_size argument.")

      self._batch_size = per_replica_batch_size

      self._config_proto = self._create_config_proto()

      logging.info("Initializing TPU Embedding engine.")
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)

      @def_function.function
      def load_config():
        tpu.initialize_system_for_tpu_embedding(self._config_proto)

      load_config()
      logging.info("Done initializing TPU Embedding engine.")

    # Create and load variables and slot variables into the TPU.
    # Note that this is a dict of dicts. Keys to the first dict are table
    # names. We would prefer to use TableConfigs, but then these variables
    # won't be properly tracked by the tracking API.
    self._variables = self._create_variables_and_slots()

    self._built = True

    # This is internally conditioned on self._built and self._using_tpu.
    self._load_variables()

  def _maybe_build(self, batch_size: Optional[int]):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means
      # that it will not be included in the function graph generated by
      # tracing, so that we can be sure that we only initialize the TPU for
      # embeddings exactly once.
      with ops.init_scope():
        self.build(batch_size)

  @property
  def embedding_tables(
      self
  ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`.

    This property only works when the `TPUEmbedding` object is created under a
    non-TPU strategy. This is intended to be used for CPU based lookup when
    creating a serving checkpoint.

    Returns:
      A dict of embedding tables, keyed by `TableConfig`.

    Raises:
      RuntimeError: If object was created under a `TPUStrategy`.
    """
    # We don't support returning tables on TPU due to their sharded nature and
    # the fact that when using a TPUStrategy:
    # 1. Variables are stale and are only updated when a checkpoint is made.
    # 2. Updating the variables won't affect the actual tables on the TPU.
    if self._using_tpu:
      if save_context.in_save_context():
        return {table: self._variables[table.name]["parameters"].variables[0]
                for table in self._table_config}
      raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
                         "strategy. If you need access, save your model, "
                         "create this object under a CPU strategy and restore.")

    self._maybe_build(None)

    # Only return the tables and not the slot variables. On CPU these are
    # honest tf.Variables.
    return {table: self._variables[table.name]["parameters"]
            for table in self._table_config}

  def _create_config_proto(
      self
  ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration:
    """Creates the TPUEmbeddingConfiguration proto.

    This proto is used to initialize the TPU embedding engine.

    Returns:
      A TPUEmbeddingConfiguration proto.
    """

    config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()

    # There are several things that need to be computed here:
    # 1. Each table has a num_features, which corresponds to the number of
    #    output rows per example for this table. Sequence features count for
    #    their maximum sequence length.
    # 2. Learning rate index: the index of the dynamic learning rate for this
    #    table (if it exists) in the list we created at initialization.
    #    We don't simply create one learning rate index per table as this has
    #    extremely bad performance characteristics. The more separate
    #    optimization configurations we have, the worse the performance will
    #    be.
    num_features = {table: 0 for table in self._table_config}
    for feature in nest.flatten(self._feature_config):
      num_features[feature.table] += (1 if feature.max_sequence_length == 0
                                      else feature.max_sequence_length)

    # Map each callable dynamic learning rate to its index in the list.
    learning_rate_index = {r: i for i, r in enumerate(
        self._dynamic_learning_rates)}

    for table in self._table_config:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table.name

      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table.vocabulary_size,
                                             self._strategy.extended.num_hosts)
      table_descriptor.dimension = table.dim

      table_descriptor.num_features = num_features[table]

      parameters = table_descriptor.optimization_parameters

      # We handle the learning rate separately here and don't allow the
      # optimization class to handle this, as it doesn't know about dynamic
      # rates.
      if callable(table.optimizer.learning_rate):
        parameters.learning_rate.dynamic.tag = (
            learning_rate_index[table.optimizer.learning_rate])
      else:
        parameters.learning_rate.constant = table.optimizer.learning_rate

      # Use optimizer to handle the rest of the parameters.
      table.optimizer._set_optimization_parameters(parameters)  # pylint: disable=protected-access

    # Always set mode to training; we override the mode during enqueue.
    config_proto.mode = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING)

    config_proto.batch_size_per_tensor_core = self._batch_size
    config_proto.num_hosts = self._strategy.extended.num_hosts
    config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync

    # TODO(bfontain): Allow users to pick MOD for the host sharding.
    config_proto.sharding_strategy = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def _compute_per_table_gradients(
      self,
      gradients
  ) -> Dict[Text, List[core.Tensor]]:
    """Computes a dict of lists of gradients, keyed by table name.

    Args:
      gradients: A nested structure of Tensors (and Nones) with the same
        structure as the feature config.

    Returns:
      A dict of lists of tensors, keyed by the table names, containing the
      gradients in the correct order with None gradients replaced by zeros.
    """

    nest.assert_same_structure(self._feature_config, gradients)

    per_table_gradients = {table: [] for table in self._table_config}
    for (path, gradient), feature in zip(
        nest.flatten_with_joined_string_paths(gradients),
        nest.flatten(self._feature_config)):
      if gradient is not None and not isinstance(gradient, ops.Tensor):
        raise ValueError(
            f"When computing per-table gradients, found non-tensor type: "
            f"{type(gradient)} at path {path}.")

      # Expected tensor shape differs for sequence and non-sequence features.
      if feature.max_sequence_length > 0:
        shape = [self._batch_size, feature.max_sequence_length,
                 feature.table.dim]
      else:
        shape = [self._batch_size, feature.table.dim]

      if gradient is not None:
        if gradient.shape != shape:
          raise ValueError("Found gradient of shape {} at path {}. Expected "
                           "shape {}.".format(gradient.shape, path, shape))

        # We expand dims on non-sequence features so that all features are
        # of rank 3 and we can concat on axis=1.
        if len(shape) == 2:
          gradient = array_ops.expand_dims(gradient, axis=1)
      else:
        # No gradient for this feature. Since we must give a gradient for all
        # features, pass in a zero tensor here. Note that this is not correct
        # for all optimizers.
        logging.warn("No gradient passed for feature %s, sending zero "
                     "gradient. This may not be correct behavior for certain "
                     "optimizers like Adam.", path)
        # Create a shape to mimic the expand_dims above for non-sequence
        # features.
        if len(shape) == 2:
          shape = [shape[0], 1, shape[1]]
        gradient = array_ops.zeros(shape, dtype=dtypes.float32)
      per_table_gradients[feature.table].append(gradient)

    return per_table_gradients

  def apply_gradients(self, gradients, name: Optional[Text] = None):
    """Applies the gradient update to the embedding tables.

    If a gradient of `None` is passed in any position of the nested structure,
    then a gradient update with a zero gradient is applied for that feature.
    For optimizers like SGD or Adagrad, this is the same as applying no update
    at all. For lazy Adam and other sparsely applied optimizers with decay,
    ensure you understand the effect of applying a zero gradient.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      gradients: A nested structure of gradients, with structure matching the
        `feature_config` passed to this object.
      name: A name for the underlying op.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
      ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a
        `tf.Tensor` of the incorrect shape is passed in. Also if
        the size of any sequence in `gradients` does not match corresponding
        sequence in `feature_config`.
      TypeError: If the type of any sequence in `gradients` does not match
        corresponding sequence in `feature_config`.
    """
    if not self._using_tpu:
      raise RuntimeError("apply_gradients is not valid when TPUEmbedding "
                         "object is not created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding "
                         "object. Please either call enqueue first or manually "
                         "call the build method.")

    # send_tpu_embedding_gradients requires per table gradients. If we only
    # have one feature per table this isn't an issue. When multiple features
    # share the same table, the order of the features in the per table tensor
    # returned by recv_tpu_embedding_activations matches the order in which
    # they were passed to enqueue.
    # In all three places, we use the fixed order given by nest.flatten to have
    # a consistent feature order.

    # First construct a dict of tensors one for each table.
    per_table_gradients = self._compute_per_table_gradients(gradients)

    # Now that we have a list of gradients we can compute a list of gradients
    # in the fixed order of self._table_config which interleave the gradients
    # of the individual features. We concat on axis 1 and then reshape into a
    # 2d tensor. The send gradients op expects a tensor of shape
    # [num_features*batch_size, dim] for each table.
    interleaved_gradients = []
    for table in self._table_config:
      interleaved_gradients.append(array_ops.reshape(
          array_ops.concat(per_table_gradients[table], axis=1),
          [-1, table.dim]))
    op = tpu_ops.send_tpu_embedding_gradients(
        inputs=interleaved_gradients,
        learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32)
                        for fn in self._dynamic_learning_rates],
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(op, name)

  def dequeue(self, name: Optional[Text] = None):
    """Gets the embedding results.

    Returns a nested structure of `tf.Tensor` objects, matching the structure
    of the `feature_config` argument to the `TPUEmbedding` class. The output
    shape of the tensors is `(batch_size, dim)`, where `batch_size` is the per
    core batch size and `dim` is the dimension of the corresponding
    `TableConfig`. If the feature's corresponding `FeatureConfig` has
    `max_sequence_length` greater than 0, the output will be a sequence of
    shape `(batch_size, max_sequence_length, dim)` instead.

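    As a sketch of the shapes involved (the sizes below are illustrative
    only):

    ```python
    # Inside a function passed to strategy.run, with a per-core batch size of
    # 16 and a table with dim=8: a feature with max_sequence_length=0 yields
    # an activation of shape (16, 8), while a feature with
    # max_sequence_length=3 yields shape (16, 3, 8).
    activations = embedding.dequeue()
    tf.nest.map_structure(lambda t: print(t.shape), activations)
    ```
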
    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      name: A name for the underlying op.

    Returns:
      A nested structure of tensors, with the same structure as
      `feature_config` passed to this instance of the `TPUEmbedding` object.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
    """
    if not self._using_tpu:
      raise RuntimeError("dequeue is not valid when TPUEmbedding object is not "
                         "created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. "
                         "Please either call enqueue first or manually call "
                         "the build method.")

    # The activations returned by this op are per table. So we must separate
    # them out into per feature activations. The activations are interleaved:
    # for each table, we expect a [num_features*batch_size, dim] tensor.
    # E.g. we expect the slice [:num_features, :] to contain the lookups for
    # the first example of all features using this table.
    activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_config),
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(activations[0].op, name)

    # Compute the number of features for each table.
    num_features = {table: 0 for table in self._table_config}
    for feature in nest.flatten(self._feature_config):
      num_features[feature.table] += (1 if feature.max_sequence_length == 0
                                      else feature.max_sequence_length)

    # Activations are reshaped so that they are indexed by batch size and then
    # by the 'feature' index within the batch. The final dimension should equal
    # the dimension of the table.
    table_to_activation = {
        table: array_ops.reshape(activation,
                                 [self._batch_size, num_features[table], -1])
        for table, activation in zip(self._table_config, activations)}

    # We process the features in the same order we enqueued them.
    # For each feature we take the next slice of the activations, so we need
    # to track the activations and the current position we are in.
    table_to_position = {table: 0 for table in self._table_config}

    per_feature_activations = []
    for feature in nest.flatten(self._feature_config):
      activation = table_to_activation[feature.table]
      feature_index = table_to_position[feature.table]
      # We treat non-sequence and sequence features differently here as
      # sequence features have rank 3 while non-sequence features have rank 2.
      if feature.max_sequence_length == 0:
        per_feature_activations.append(
            activation[:, feature_index, :])
        table_to_position[feature.table] += 1
      else:
        per_feature_activations.append(
            activation[:, feature_index:(
                feature_index+feature.max_sequence_length), :])
        table_to_position[feature.table] += feature.max_sequence_length

    # Pack the list back into the same nested structure as the features.
    return nest.pack_sequence_as(self._feature_config, per_feature_activations)

  def _create_variables_and_slots(
      self
  ) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note under TPUStrategy this will ensure that all creations happen within a
    variable creation scope of the sharded variable creator.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
    """

    def create_variables(table):
      """Create all variables."""
      variable_shape = (table.vocabulary_size, table.dim)

      def getter(name, shape, dtype, initializer, trainable):
        del shape
        # _add_variable_with_custom_getter clears the shape sometimes, so we
        # take the global shape from outside the getter.
        initial_value = functools.partial(initializer, variable_shape,
                                          dtype=dtype)
        return tf_variables.Variable(
            name=name,
            initial_value=initial_value,
            shape=variable_shape,
            dtype=dtype,
            trainable=trainable)

      def variable_creator(name, initializer, trainable=True):
        # Use add_variable_with_custom_getter here so that we take advantage
        # of the checkpoint loading to allow restore before the variables get
        # created, which avoids double initialization.
        return self._add_variable_with_custom_getter(
            name=name,
            initializer=initializer,
            shape=variable_shape,
            dtype=dtypes.float32,
            getter=getter,
            trainable=trainable)

      parameters = variable_creator(table.name, table.initializer,
                                    trainable=not self._using_tpu)

      def slot_creator(name, initializer):
        return variable_creator(table.name + "/" + name,
                                initializer,
                                False)

      if table.optimizer is not None:
        slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
      else:
        slot_vars = {}
      slot_vars["parameters"] = parameters
      return slot_vars

    # Store tables based on name rather than TableConfig as we can't track
    # through dicts with non-string keys, i.e. we won't be able to save.
    variables = {}
    for table in self._table_config:
      if not self._using_tpu:
        variables[table.name] = create_variables(table)
      else:
        with variable_scope.variable_creator_scope(
            make_sharded_variable_creator(self._hosts)):
          variables[table.name] = create_variables(table)

    return variables

  def _load_variables(self):
    # Only load the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _load_variables_impl(self._config_proto.SerializeToString(),
                           self._hosts,
                           self._variables,
                           self._table_config)

  def _retrieve_variables(self):
    # Only retrieve the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _retrieve_variables_impl(self._config_proto.SerializeToString(),
                               self._hosts,
                               self._variables,
                               self._table_config)

  def _gather_saveables_for_checkpoint(
      self
  ) -> Dict[Text, Callable[[Text], "TPUEmbeddingSaveable"]]:
    """Overrides default Trackable implementation to add load/retrieve hook."""
    # This saveable should be here in both TPU and CPU checkpoints, so when on
    # CPU, we add the hook with no functions.
    # TODO(bfontain): Update restore logic in saver so that these hooks are
    # always executed. Once that is done, we can output an empty list when on
    # CPU.

    def factory(name=_HOOK_KEY):
      return TPUEmbeddingSaveable(name, self._load_variables,
                                  self._retrieve_variables)
    return {_HOOK_KEY: factory}

  # Some helper functions for the below enqueue function.
  def _add_data_for_tensor(self, tensor, weight, indices, values, weights,
                           int_zeros, float_zeros, path):
    if weight is not None:
      raise ValueError(
          "Weight specified for dense input {}, which is not allowed. "
          "Weight will always be 1 in this case.".format(path))
    # For tensors, there are no indices and no weights.
    indices.append(int_zeros)
    values.append(math_ops.cast(tensor, dtypes.int32))
    weights.append(float_zeros)

  def _add_data_for_sparse_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path):
    indices.append(math_ops.cast(tensor.indices, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int32))
    # If we have weights they must be a SparseTensor.
    if weight is not None:
      if not isinstance(weight, sparse_tensor.SparseTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the type of the input, which is SparseTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _add_data_for_ragged_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path):
    indices.append(math_ops.cast(tensor.row_splits, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int32))
    # If we have weights they must be a RaggedTensor.
    if weight is not None:
      if not isinstance(weight, ragged_tensor.RaggedTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the type of the input, which is RaggedTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _generate_enqueue_op(
      self,
      flat_inputs: List[internal_types.NativeObject],
      flat_weights: List[Optional[internal_types.NativeObject]],
      flat_features: List[tpu_embedding_v2_utils.FeatureConfig],
      device_ordinal: int,
      mode_override: Text
  ) -> ops.Operation:
    """Outputs the enqueue op given the inputs and weights.

    Args:
      flat_inputs: A list of input tensors.
      flat_weights: A list of input weights (or None) of the same length as
        flat_inputs.
      flat_features: A list of FeatureConfigs of the same length as flat_inputs.
      device_ordinal: The device to create the enqueue op for.
      mode_override: A tensor containing the string "train" or "inference".

    Returns:
      The enqueue op.
    """

    # First we need to understand which op to use. This depends on if sparse
    # or ragged tensors are in the flat_inputs.
    sparse = False
    ragged = False
    for inp in flat_inputs:
      if isinstance(inp, sparse_tensor.SparseTensor):
        sparse = True
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        ragged = True
    if sparse and ragged:
      raise ValueError(
          "Found both SparseTensors and RaggedTensors in the input to the "
          "enqueue operation. Please ensure that your data does not include "
          "both SparseTensors and RaggedTensors. It is ok to have Tensors in "
          "combination with one of the previous types.")

    # Combiners are per table, listed in the same order as the table order.
    combiners = [table.combiner for table in self._table_config]

    # Reverse mapping of self._table_config, so that we can lookup the table
    # index.
    table_to_id = {table: i for i, table in enumerate(self._table_config)}

    # These parallel arrays will be the inputs to the enqueue op.
    indices = []  # sample_indices for sparse, sample_splits for ragged.
    values = []
    weights = []
    table_ids = []
    max_sequence_lengths = []

    # We have to supply an empty/zero tensor in a list position where we don't
    # have data (e.g. indices for standard Tensor input, weight when no weight
    # is specified). We create one op here per call, so that we reduce the
    # graph size.
    int_zeros = array_ops.zeros((0,), dtype=dtypes.int32)
    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)

    # In the following loop we insert casts so that everything is either int32
    # or float32. This is because op inputs which are lists of tensors must be
    # of the same type within the list. Moreover the CPU implementations of
    # these ops cast to these types anyway, so we don't lose any data by
    # casting early.
    for inp, weight, (path, feature) in zip(
        flat_inputs, flat_weights, flat_features):
      table_ids.append(table_to_id[feature.table])
      max_sequence_lengths.append(feature.max_sequence_length)
      if isinstance(inp, ops.Tensor):
        self._add_data_for_tensor(inp, weight, indices, values, weights,
                                  int_zeros, float_zeros, path)
      elif isinstance(inp, sparse_tensor.SparseTensor):
        self._add_data_for_sparse_tensor(inp, weight, indices, values, weights,
                                         int_zeros, float_zeros, path)
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        self._add_data_for_ragged_tensor(inp, weight, indices, values, weights,
                                         int_zeros, float_zeros, path)
      else:
        raise ValueError("Input {} is of unknown type {}. Please only pass "
                         "Tensor, SparseTensor or RaggedTensor as input to "
                         "enqueue.".format(path, type(inp)))

    if ragged:
      return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
          sample_splits=indices,
          embedding_indices=values,
          aggregation_weights=weights,
          mode_override=mode_override,
          device_ordinal=device_ordinal,
          combiners=combiners,
          table_ids=table_ids,
          max_sequence_lengths=max_sequence_lengths)
    return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
        sample_indices=indices,
        embedding_indices=values,
        aggregation_weights=weights,
        mode_override=mode_override,
        device_ordinal=device_ordinal,
        combiners=combiners,
        table_ids=table_ids,
        max_sequence_lengths=max_sequence_lengths)

  def _raise_error_for_incorrect_control_flow_context(self):
    """Raises an error if we are not in the TPUReplicateContext."""
    # Do not allow any XLA control flow (i.e. control flow in between a
    # TPUStrategy's run call and the call to this function), as we can't
    # extract the enqueue from the head when in XLA control flow.
    graph = ops.get_default_graph()
    in_tpu_ctx = False
    while graph is not None:
      ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
      while ctx is not None:
        if isinstance(ctx, tpu.TPUReplicateContext):
          in_tpu_ctx = True
          break
        ctx = ctx.outer_context
      if in_tpu_ctx:
        break
      graph = getattr(graph, "outer_graph", None)
    if graph != ops.get_default_graph() and in_tpu_ctx:
      raise RuntimeError(
          "Current graph {} does not match graph which contains "
          "TPUReplicateContext {}. This is most likely due to the fact that "
          "enqueueing embedding data is called inside control flow or a "
          "nested function inside `strategy.run`. This is not supported "
          "because outside compilation fails to extract the enqueue ops as "
          "head of computation.".format(ops.get_default_graph(), graph))
    return in_tpu_ctx

  def _raise_error_for_non_direct_inputs(self, features):
    """Checks all tensors in features to see if they are a direct input."""

    # expand_composites here is important: as composite tensors pass through
    # tpu.replicate, they get 'flattened' into their component tensors and then
    # repacked before being passed to the tpu function. This means that it is
    # the component tensors which are produced by an op with the
    # "_tpu_input_identity" attribute.
    for path, input_tensor in nest.flatten_with_joined_string_paths(
        features, expand_composites=True):
      if input_tensor.op.type == "Placeholder":
        continue
      try:
        is_input = input_tensor.op.get_attr("_tpu_input_identity")
      except ValueError:
        is_input = False
      if not is_input:
        raise ValueError(
            "Received input tensor {} which is the output of op {} (type {}) "
            "which does not have the `_tpu_input_identity` attr. Please "
            "ensure that the inputs to this layer are taken directly from "
            "the arguments of the function called by "
            "strategy.run. Two possible causes are: dynamic batch size "
            "support or you are using a keras layer and are not passing "
            "tensors which match the dtype of the `tf.keras.Input`s. "
            "If you are triggering dynamic batch size support, you can "
            "disable it by passing tf.distribute.RunOptions("
            "experimental_enable_dynamic_batch_size=False) to the options "
            "argument of strategy.run().".format(path,
                                                 input_tensor.op.name,
                                                 input_tensor.op.type))

  def _raise_error_for_inputs_not_on_cpu(self, flat_inputs, flat_paths):
    """Checks all tensors in features to see if they are placed on the CPU."""

    def check_device(path, device_string):
      spec = tf_device.DeviceSpec.from_string(device_string)
      if spec.device_type == "TPU":
        raise ValueError(
            "Received input tensor {} which is on a TPU input device {}. Input "
            "tensors for TPU embeddings must be placed on the CPU. Please "
            "ensure that your dataset is prefetching tensors to the host by "
            "setting the 'experimental_fetch_to_device' option of the "
            "dataset distribution function. See the documentation of the "
            "enqueue method for an example.".format(path, device_string))

    # expand_composites here is important, we need to check the device of each
    # underlying tensor.
    for input_tensor, input_path in zip(flat_inputs, flat_paths):
      if nest.is_sequence_or_composite(input_tensor):
        input_tensors = nest.flatten(input_tensor, expand_composites=True)
      else:
        input_tensors = [input_tensor]
      for t in input_tensors:
        if (t.op.type == "Identity" and
            t.op.inputs[0].op.type == "TPUReplicatedInput"):
          for tensor in t.op.inputs[0].op.inputs:
            check_device(input_path, tensor.device)
        else:
          check_device(input_path, t.device)

  def enqueue(
      self,
      features,
      weights=None,
      training: bool = True,
      name: Optional[Text] = None,
      device: Optional[Text] = None):
    """Enqueues id tensors for embedding lookup.

    This function enqueues a structure of features to be looked up in the
    embedding tables. We expect that the batch size of each of the tensors in
    features matches the per core batch size. This will automatically happen if
    your input dataset is batched to the global batch size and you use
    `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
    or if you use `distribute_datasets_from_function` and batch
    to the per core batch size computed by the context passed to your input
    function.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features,))

    training_step()
    ```

    NOTE: You should specify `training=True` when using
    `embedding.apply_gradients` as above and `training=False` when not using
    `embedding.apply_gradients` (e.g. for frozen embeddings or when doing
    evaluation).

    For finer grained control, in the above example the line

    ```
      embedding.enqueue(embedding_features, training=True)
    ```

    may be replaced with

    ```
      per_core_embedding_features = strategy.experimental_local_results(
          embedding_features)

      def per_core_enqueue(ctx):
        core_id = ctx.replica_id_in_sync_group
        device = strategy.extended.worker_devices[core_id]
        embedding.enqueue(per_core_embedding_features[core_id],
                          device=device)

      strategy.experimental_distribute_values_from_function(
          per_core_enqueue)
    ```

    Args:
      features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or
        `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs
        will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor`
        or `tf.RaggedTensor` is supported per call.
      weights: If not `None`, a nested structure of `tf.Tensor`s,
        `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except
        that the tensors should be of float type (and they will be downcast to
        `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the
        same for the parallel entries from `features` and similarly for
        `tf.RaggedTensor`s we assume the row_splits are the same.
      training: Defaults to `True`. If `False`, enqueue the batch as inference
        batch (forward pass only). Do not call `apply_gradients` when this is
        `False` as this may lead to a deadlock.
      name: A name for the underlying op.
      device: The device name (e.g. '/task:0/device:TPU:2') where this batch
        should be enqueued. This should be set if and only if features is not a
        `tf.distribute.DistributedValues` and enqueue is not being called
        inside a TPU context (e.g. inside `TPUStrategy.run`).

    Raises:
      ValueError: When called inside a strategy.run call and input is not
        directly taken from the args of the `strategy.run` call. Also if
        the size of any sequence in `features` does not match corresponding
        sequence in `feature_config`. Similarly for `weights`, if not `None`.
        If batch size of features is unequal or different from a previous call.
      RuntimeError: When called inside a strategy.run call and inside XLA
        control flow. If batch_size is not able to be determined and build was
        not called.
      TypeError: If the type of any sequence in `features` does not match
        corresponding sequence in `feature_config`. Similarly for `weights`, if
        not `None`.
1233    """
    if not self._using_tpu:
      raise RuntimeError("enqueue is not valid when TPUEmbedding object is not "
                         "created under a TPUStrategy.")

    in_tpu_context = self._raise_error_for_incorrect_control_flow_context()

    if not self._verify_batch_size_on_enqueue:
      if not self._batch_size or not self._built:
        raise ValueError(
            "Configured not to check batch size on each enqueue() call; please "
            "ensure build() was called with global batch size to initialize "
            "the TPU for embeddings.")
    else:
      # Should we also get batch_size from weights if they exist?
      # Since features is assumed to be batched at the per replica batch size,
      # the batch size returned here is per replica and not global.
      batch_size = self._get_batch_size(features, in_tpu_context)
      if batch_size is None and not self._built:
        raise RuntimeError("Unable to determine batch size from input "
                           "features. Please call build() with global batch "
                           "size to initialize the TPU for embeddings.")
      if batch_size is not None:
        self._maybe_build(batch_size)
        if self._batch_size != batch_size:
          raise ValueError("Multiple calls to enqueue with different batch "
                           "sizes {} and {}.".format(self._batch_size,
                                                     batch_size))

    nest.assert_same_structure(self._feature_config, features)

    flat_inputs = nest.flatten(features)
    flat_weights = [None] * len(flat_inputs)
    if weights is not None:
      nest.assert_same_structure(self._feature_config, weights)
      flat_weights = nest.flatten(weights)
    flat_features = nest.flatten_with_joined_string_paths(self._feature_config)
    flat_paths, _ = zip(*flat_features)

    self._raise_error_for_inputs_not_on_cpu(flat_inputs, flat_paths)
    # If we are in a tpu_context, automatically apply outside compilation.
    if in_tpu_context:
      self._raise_error_for_non_direct_inputs(features)

      def generate_enqueue_ops():
        """Generate enqueue ops for outside compilation."""
        # Note that we use array_ops.where_v2 rather than a python if so that
        # the op is explicitly created and the constant ops for both modes are
        # in the graph, even though we don't expect training to be a tensor
        # (and thus to generate control flow automatically). This is needed to
        # make it easier to rewrite the graph later if we need to fix which
        # mode is used.
        mode_override = array_ops.where_v2(training,
                                           constant_op.constant("train"),
                                           constant_op.constant("inference"))

        # Device ordinal is -1 here, a later rewrite will fix this once the op
        # is expanded by outside compilation.
        enqueue_op = self._generate_enqueue_op(
            flat_inputs, flat_weights, flat_features, device_ordinal=-1,
            mode_override=mode_override)

        # Apply the name tag to the op.
        if name is not None:
          _add_key_attr(enqueue_op, name)

        # Ensure that this op has outbound control flow, otherwise it won't be
        # executed.
        ops.get_default_graph().control_outputs.append(enqueue_op)

      tpu.outside_compilation(generate_enqueue_ops)

    elif device is None:
      mode_override = "train" if training else "inference"
      # We generate enqueue ops per device, so we need to gather all the
      # features for a single device into a dict.
      # We rely here on the fact that the devices in the PerReplica value occur
      # in the same (standard) order as self._strategy.extended.worker_devices.
      enqueue_ops = []
      for replica_id in range(self._strategy.num_replicas_in_sync):
        replica_inputs = distribute_utils.select_replica(replica_id,
                                                         flat_inputs)
        replica_weights = distribute_utils.select_replica(replica_id,
                                                          flat_weights)
        tpu_device = self._strategy.extended.worker_devices[replica_id]
        # TPU device strings look like /job:worker/replica:0/task:0/device:TPU:0
        # and the device ordinal is the last number.
        device_ordinal = (
            tf_device.DeviceSpec.from_string(tpu_device).device_index)
        with ops.device(device_util.get_host_for_device(tpu_device)):
          enqueue_op = self._generate_enqueue_op(
              replica_inputs, replica_weights, flat_features,
              device_ordinal=device_ordinal, mode_override=mode_override)

          # Apply the name tag to the op.
          if name is not None:
            _add_key_attr(enqueue_op, name)
          enqueue_ops.append(enqueue_op)
      ops.get_default_graph().control_outputs.extend(enqueue_ops)
    else:
      mode_override = "train" if training else "inference"
      device_spec = tf_device.DeviceSpec.from_string(device)
      if device_spec.device_type != "TPU":
        raise ValueError(
            "Non-TPU device {} passed to enqueue.".format(device))
      with ops.device(device_util.get_host_for_device(device)):
        enqueue_op = self._generate_enqueue_op(
            flat_inputs, flat_weights, flat_features,
            device_ordinal=device_spec.device_index,
            mode_override=mode_override)

        # Apply the name tag to the op.
        if name is not None:
          _add_key_attr(enqueue_op, name)
        ops.get_default_graph().control_outputs.append(enqueue_op)

  def _get_batch_size(self, tensors, in_tpu_context: bool):
    """Gets the batch size from a nested structure of features."""
    batch_size = None
    for path, maybe_tensor in nest.flatten_with_joined_string_paths(tensors):
      tensor_list = []
      if not in_tpu_context:
        # If we are not in a tpu context, then this is a PerReplica value and
        # we need to check each replica's batch size.
        for replica_id in range(self._strategy.num_replicas_in_sync):
          tensor_list.append(distribute_utils.select_replica(replica_id,
                                                             maybe_tensor))
      else:
        tensor_list = [maybe_tensor]

      for tensor in tensor_list:
        if tensor.shape.rank < 1:
          raise ValueError(
              "Input {} has rank 0; rank must be at least 1.".format(path))
        shape = tensor.shape.as_list()
        if shape[0] is not None:
          if batch_size is None:
            batch_size = shape[0]
          elif batch_size != shape[0]:
            raise ValueError("Found multiple batch sizes {} and {}. All inputs "
                             "must have the same batch dimension size.".format(
                                 batch_size, shape[0]))
    return batch_size


@def_function.function
def _load_variables_impl(
    config: Text,
    hosts: List[Tuple[int, Text]],
    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
    table_config: tpu_embedding_v2_utils.TableConfig):
  """Load embedding tables onto the TPU for each table and host.

  Args:
    config: A serialized TPUEmbeddingConfiguration proto.
    hosts: A list of CPU devices, one per host.
    variables: A dictionary of dictionaries of TPUShardedVariables. The first
      key is the table name; the second is 'parameters' or the optimizer slot
      name (see the sketch below).
    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
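
  For example (a sketch of the expected structure; the table and slot names
  here are hypothetical):

  ```python
  variables = {
      "table_one": {"parameters": sharded_var_a, "momenta": sharded_var_b},
      "table_two": {"parameters": sharded_var_c},
  }
  ```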
  """
  def select_fn(host_id):

    def select_or_zeros(x):
      if host_id >= len(x.variables):
        # In the edge case where we have more hosts than variables, due to
        # using a small number of rows, we load zeros for the later hosts. We
        # copy the shape of the first host's variables, which we assume is
        # defined because TableConfig guarantees at least one row.
        return array_ops.zeros_like(x.variables[0])
      return x.variables[host_id]

    return select_or_zeros

  for host_id, host in enumerate(hosts):
    with ops.device(host):
      host_variables = nest.map_structure(select_fn(host_id), variables)
      for table in table_config:
        table.optimizer._load()(  # pylint: disable=protected-access
            table_name=table.name,
            num_shards=len(hosts),
            shard_id=host_id,
            config=config,
            **host_variables[table.name])
        # Ensure that only the first table/first host gets a config so that we
        # don't bloat the graph by attaching this large string to each op.
        # We have num_tables * num_hosts of these, so for models with a large
        # number of tables training on a large slice this can be an issue.
        config = None


@def_function.function
def _retrieve_variables_impl(
    config: Text,
    hosts: List[Tuple[int, Text]],
    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
    table_config: tpu_embedding_v2_utils.TableConfig):
  """Retrieve embedding tables from the TPU to host memory.

  Args:
    config: A serialized TPUEmbeddingConfiguration proto.
    hosts: A list of all the host CPU devices.
    variables: A dictionary of dictionaries of TPUShardedVariables. The first
      key is the table name; the second is 'parameters' or the optimizer slot
      name.
    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
  """
  for host_id, host in enumerate(hosts):
    with ops.device(host):
      for table in table_config:
        retrieved = table.optimizer._retrieve()(  # pylint: disable=protected-access
            table_name=table.name,
            num_shards=len(hosts),
            shard_id=host_id,
            config=config)
        # When there are no slot variables (e.g. with SGD) this returns a
        # single tensor rather than a tuple. In this case we put the tensor in
        # a tuple to make the following code easier to write.
        if not isinstance(retrieved, tuple):
          retrieved = (retrieved,)

        for i, slot in enumerate(["parameters"] +
                                 table.optimizer._slot_names()):  # pylint: disable=protected-access
          # We must assign the CPU variables the values of the tensors that
          # were returned from the TPU.
          sharded_var = variables[table.name][slot]
          if host_id < len(sharded_var.variables):
            # In the edge case where we have more hosts than variables, due to
            # using a small number of rows, we skip the later hosts.
            sharded_var.variables[host_id].assign(retrieved[i])
        # Ensure that only the first table/first host gets a config so that we
        # don't bloat the graph by attaching this large string to each op.
        # We have num_tables * num_hosts of these, so for models with a large
        # number of tables training on a large slice this can be an issue.
        config = None


class TPUEmbeddingSaveable(saveable_hook.SaveableHook):
  """Save/Restore hook to Retrieve/Load TPUEmbedding variables."""

  def __init__(
      self,
      name: Text,
      load: Callable[[], Any],
      retrieve: Callable[[], Any]):
    self._load = load
    self._retrieve = retrieve
    super(TPUEmbeddingSaveable, self).__init__(name=name)

  def before_save(self):
    if self._retrieve is not None:
      self._retrieve()

  def after_restore(self):
    if self._load is not None:
      self._load()


def _ragged_embedding_lookup_with_reduce(
    table: tf_variables.Variable,
    ragged: ragged_tensor.RaggedTensor,
    weights: ragged_tensor.RaggedTensor,
    combiner: Text) -> core.Tensor:
  """Compute a ragged lookup followed by a reduce on axis 1.

  Args:
    table: The embedding table.
    ragged: A RaggedTensor of ids to look up.
    weights: A RaggedTensor of weights (or None).
    combiner: One of "mean", "sum", "sqrtn".

  Returns:
    A Tensor.
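
  For example (a sketch with hypothetical values): for a table of shape
  `[4, 2]`, ids `[[0, 1], [2]]` and unit weights, `combiner="mean"` returns
  `[(table[0] + table[1]) / 2, table[2]]`, a dense Tensor of shape `[2, 2]`.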
  """
  if weights is None:
    weights = array_ops.ones_like(ragged, dtype=table.dtype)
  weights = array_ops.expand_dims(weights, axis=2)
  ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged)
  ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1)
  if combiner == "mean":
    ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1)
  elif combiner == "sqrtn":
    ragged_result = ragged_result / math_ops.sqrt(math_ops.reduce_sum(
        weights * weights, axis=1))
  return ragged_result


@tf_export("tpu.experimental.embedding.serving_embedding_lookup")
def cpu_embedding_lookup(inputs, weights, tables, feature_config):
  """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs.

  This function is a utility which allows using the
  `tf.tpu.experimental.embedding` config objects with standard lookup
  functions. This can be used when exporting a model which uses
  `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In
  particular, `tf.tpu.experimental.embedding.TPUEmbedding` only supports
  lookups on TPUs and should not be part of your serving graph.

  Note that TPU-specific options (such as `max_sequence_length`) in the
  configuration objects will be ignored.

  In the following example we take a trained model (see the documentation for
  `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a
  saved model with a serving function that will perform the embedding lookup
  and pass the results to your model:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=1024,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                                 'feature_two': tf.TensorSpec(...),
                                 'feature_three': tf.TensorSpec(...)}])
  def serve_tensors(embedding_features):
    embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
        embedding_features, None, embedding.embedding_tables,
        feature_config)
    return model(embedded_features)

  model.embedding_api = embedding
  tf.saved_model.save(model,
                      export_dir=...,
                      signatures={'serving_default': serve_tensors})
  ```

  NOTE: It's important to assign the embedding API object to a member of your
  model as `tf.saved_model.save` only supports saving variables of one
  `Trackable` object. Since the model's weights are in `model` and the
  embedding tables are managed by `embedding`, we assign `embedding` to an
  attribute of `model` so that `tf.saved_model.save` can find the embedding
  variables.

  NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will
  work directly from training.

  Args:
    inputs: a nested structure of Tensors, SparseTensors or RaggedTensors.
    weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
      None for no weights. If not None, structure must match that of inputs,
      but entries are allowed to be None.
    tables: a dict mapping TableConfig objects to Variables.
    feature_config: a nested structure of FeatureConfig objects with the same
      structure as inputs.

  Returns:
    A nested structure of Tensors with the same structure as inputs.
  """

  nest.assert_same_structure(inputs, feature_config)

  flat_inputs = nest.flatten(inputs)
  flat_weights = [None] * len(flat_inputs)
  if weights is not None:
    nest.assert_same_structure(inputs, weights)
    flat_weights = nest.flatten(weights)
  flat_features = nest.flatten_with_joined_string_paths(feature_config)

  outputs = []
  for inp, weight, (path, feature) in zip(
      flat_inputs, flat_weights, flat_features):
    table = tables[feature.table]

    if weight is not None:
      if isinstance(inp, ops.Tensor):
        raise ValueError(
            "Weight specified for {}, but input is dense.".format(path))
      elif type(weight) is not type(inp):
        raise ValueError(
            "Weight for {} is of type {} but it does not match type of the "
            "input which is {}.".format(path, type(weight), type(inp)))
      elif feature.max_sequence_length > 0:
        raise ValueError("Weight specified for {}, but this is a sequence "
                         "feature.".format(path))

    if isinstance(inp, ops.Tensor):
      if feature.max_sequence_length > 0:
        raise ValueError("Feature {} is a sequence feature but a dense tensor "
                         "was passed.".format(path))
      outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

    elif isinstance(inp, sparse_tensor.SparseTensor):
      if feature.max_sequence_length > 0:
        batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
        sparse_shape = array_ops.stack(
            [batch_size, feature.max_sequence_length], axis=0)
        # TPU Embedding truncates sequences to max_sequence_length, and if we
        # don't truncate, scatter_nd will error out if an index is out of
        # bounds.
        truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0],
                                                size=sparse_shape)

        dense_output_shape = array_ops.stack(
            [batch_size, feature.max_sequence_length, feature.table.dim],
            axis=0)
        outputs.append(
            array_ops.scatter_nd(
                truncated_inp.indices,
                array_ops.gather(table.read_value(), truncated_inp.values),
                dense_output_shape))
      else:
        inp_rank = inp.dense_shape.get_shape()[0]
        if (not feature.validate_weights_and_indices and
            inp_rank is not None and inp_rank <= 2):
          outputs.append(
              embedding_ops.embedding_lookup_sparse_v2(
                  table,
                  inp,
                  sp_weights=weight,
                  combiner=feature.table.combiner))
        else:
          outputs.append(
              embedding_ops.safe_embedding_lookup_sparse_v2(
                  table,
                  inp,
                  sparse_weights=weight,
                  combiner=feature.table.combiner))

    elif isinstance(inp, ragged_tensor.RaggedTensor):
      if feature.max_sequence_length > 0:
        batch_size = inp.shape[0]
        dense_output_shape = [
            batch_size, feature.max_sequence_length, feature.table.dim]
        ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp)
        # Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given
        # shape.
        outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape))
      else:
        outputs.append(_ragged_embedding_lookup_with_reduce(
            table, inp, weight, feature.table.combiner))

    else:
      raise ValueError("Input {} is of type {}. Tensor, SparseTensor or "
                       "RaggedTensor expected.".format(path, type(inp)))
  return nest.pack_sequence_as(feature_config, outputs)


def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]:
  """Returns a sorted list of CPU devices for the remote jobs.

  Args:
    strategy: A TPUStrategy object.

  Returns:
    A sorted list of device strings.
  """
  list_of_hosts = []
  # Assume this is sorted by task.
  for tpu_device in strategy.extended.worker_devices:
    host = device_util.get_host_for_device(tpu_device)
    if host not in list_of_hosts:
      list_of_hosts.append(host)
  assert len(list_of_hosts) == strategy.extended.num_hosts
  return list_of_hosts


def extract_variable_info(
    kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]:
  """Extracts the variable creation attributes from the kwargs.

  Args:
    kwargs: a dict of keyword arguments that were passed to a variable creator
      scope.

  Returns:
    A tuple of variable name, shape, dtype, initialization function.
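
  For example (a minimal sketch; the exact kwargs are supplied by the variable
  creator scope): for a call equivalent to
  `tf.Variable(functools.partial(tf.ones, shape=(8, 4)), name="table",
  dtype=tf.float32)`, this returns `("table", (8, 4), tf.float32, tf.ones)`.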
  """
  if (isinstance(kwargs["initial_value"], functools.partial) and (
      "shape" in kwargs["initial_value"].keywords or
      kwargs["initial_value"].args)):
    # Sometimes shape is passed positionally, sometimes it's passed as a kwarg.
    if "shape" in kwargs["initial_value"].keywords:
      shape = kwargs["initial_value"].keywords["shape"]
    else:
      shape = kwargs["initial_value"].args[0]
    return (kwargs["name"], shape,
            kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
            kwargs["initial_value"].func)
  elif "shape" not in kwargs or kwargs["shape"] is None or not callable(
      kwargs["initial_value"]):
    raise ValueError(
        "Unable to extract initializer function and shape from {}. Please "
        "either pass a function that expects a shape and dtype as the "
        "initial value for your variable or a functools.partial object with "
        "the shape and dtype kwargs set. This is needed so that we can "
        "initialize the shards of the ShardedVariable locally.".format(
            kwargs["initial_value"]))
  else:
    return (kwargs["name"], kwargs["shape"], kwargs["dtype"],
            kwargs["initial_value"])


def make_sharded_variable_creator(
    hosts: List[Text]) -> Callable[..., TPUShardedVariable]:
  """Makes a sharded variable creator given a list of hosts.

  Args:
    hosts: a list of tensorflow devices on which to shard the tensors.

  Returns:
    A variable creator function.
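
  For example (a minimal sketch using this module's imports; `strategy` and the
  initializer `init_fn`, a callable taking `shape` and `dtype`, are
  hypothetical):

  ```python
  creator = make_sharded_variable_creator(get_list_of_hosts(strategy))
  with variable_scope.variable_creator_scope(creator):
    table = tf_variables.Variable(
        initial_value=functools.partial(init_fn, shape=(1024, 8)),
        name="table", dtype=dtypes.float32)
  # `table` is a TPUShardedVariable with one shard per host.
  ```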
  """

  def sharded_variable_creator(
      next_creator: Callable[..., tf_variables.Variable], *args, **kwargs):
    """The sharded variable creator."""
    kwargs["skip_mirrored_creator"] = True

    num_hosts = len(hosts)
    name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs)
    initial_value = kwargs["initial_value"]
    rows = shape[0]
    cols = shape[1]
    partial_partition = rows % num_hosts
    full_rows_per_host = rows // num_hosts
    # We partition as if we were using MOD sharding: at least
    # `full_rows_per_host` rows to `num_hosts` hosts, where the first
    # `partial_partition` hosts get an additional row when the number of rows
    # is not cleanly divisible. Note that `full_rows_per_host` may be zero.
    partitions = (
        [full_rows_per_host + 1] * partial_partition
        + [full_rows_per_host] * (num_hosts - partial_partition))
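    # For example (hypothetical values): rows=10 and num_hosts=4 gives
    # partial_partition=2 and full_rows_per_host=2, so partitions=[3, 3, 2, 2].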
    variables = []
    sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args

    # Keep track of offset for sharding aware initializers.
    offset = 0
    kwargs["dtype"] = dtype
    for i, p in enumerate(partitions):
      if p == 0:
        # Skip variable creation for empty partitions, resulting from the edge
        # case of 'rows < num_hosts'. This is safe because both load and
        # restore can handle the missing values.
        continue
      with ops.device(hosts[i]):
        kwargs["name"] = "{}_{}".format(name, i)
        kwargs["shape"] = (p, cols)
        if sharding_aware:
          shard_info = base.ShardInfo(kwargs["shape"], (offset, 0))
          kwargs["initial_value"] = functools.partial(
              initial_value, shard_info=shard_info)
          offset += p
        else:
          kwargs["initial_value"] = functools.partial(
              unwrapped_initial_value, kwargs["shape"], dtype=dtype)
        variables.append(next_creator(*args, **kwargs))
    return TPUShardedVariable(variables, name=name)
  return sharded_variable_creator