# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import save_context
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.types import core
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_HOOK_KEY = "TPUEmbedding_saveable"
_NAME_KEY = "_tpu_embedding_layer"


# TODO(bfontain): Cleanup and remove this once there is an implementation of
# sharded variables that can be used in the PSStrategy with optimizers.
# We implement just enough of a tf.Variable interface so that instances can
# be passed to an optimizer.
class TPUShardedVariable(sharded_variable.ShardedVariableMixin):
  """A ShardedVariable class for TPU."""

  @property
  def _in_graph_mode(self):
    return self.variables[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self.variables[0]._unique_id  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name


def _add_key_attr(op, name):
  op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name)))  # pylint: disable=protected-access


@tf_export("tpu.experimental.embedding.TPUEmbedding")
class TPUEmbedding(tracking.AutoTrackable):
  """The TPUEmbedding mid level API.

  NOTE: When instantiated under a TPUStrategy, this class can only be created
  once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
  re-initialize the embedding engine you must re-initialize the TPU as well.
  Doing this will clear any variables from the TPU, so ensure you have
  checkpointed before you do this. If further instances of the class are
  needed, set the `initialize_tpu_embedding` argument to `False`.
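
  For example, a sketch of the save-then-reinitialize sequence described
  above (the `resolver` and `checkpoint` names here are placeholders for your
  own setup):

  ```python
  checkpoint.save(...)  # Persist the tables before wiping the TPU.
  tf.tpu.experimental.initialize_tpu_system(resolver)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
  ```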

  This class can be used to support training large embeddings on TPU. When
  creating an instance of this class, you must specify the complete set of
  tables and features you expect to look up in those tables. See the
  documentation of `tf.tpu.experimental.embedding.TableConfig` and
  `tf.tpu.experimental.embedding.FeatureConfig` for more details on the
  complete set of options. We will cover the basic usage here.

  NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
  allowing different features to share the same table:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  ```

  There are two modes under which the `TPUEmbedding` class can be used,
  depending on whether the class was created under a `TPUStrategy` scope or
  not.

  Under `TPUStrategy`, we allow access to the methods `enqueue`, `dequeue` and
  `apply_gradients`. We will show examples below of how to use these to train
  and evaluate your model. Under CPU, only the `embedding_tables` property is
  available, which gives access to the embedding tables so that you can use
  them to run model evaluation/prediction on CPU.

  First let's look at the `TPUStrategy` mode. Initial setup looks like:

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the enqueue
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_prefetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  NOTE: All batches passed to the layer must have the same batch size for each
  input; moreover, once you have called the layer with one batch size, all
  subsequent calls must use the same batch size. In the event that the batch
  size cannot be automatically determined by the enqueue method, you must call
  the build method with the batch size, as sketched below, to initialize the
  layer.
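
  For example, to set the batch size up front (here 16 is an arbitrary per
  core batch size):

  ```python
  embedding.build(per_replica_batch_size=16)
  ```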

  To use this API on TPU you should use a custom training loop. Below is an
  example of a training and evaluation step:

  ```python
  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      activations = embedding.dequeue()
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=False)
      strategy.run(tpu_step, args=(tpu_features, ))
  ```

  NOTE: The calls to `enqueue` have `training` set to `True` when
  `embedding.apply_gradients` is used and set to `False` when
  `embedding.apply_gradients` is not present in the function. If you don't
  follow this pattern you may cause an error to be raised or the TPU may
  deadlock.

  In the above examples, we assume that the user has a dataset which returns
  a tuple where the first element of the tuple matches the structure of what
  was passed as the `feature_config` argument to the object initializer. Also,
  we utilize `tf.range` to get a `tf.while_loop`, which improves performance.
  A sketch of such a dataset function follows this paragraph.
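
  A sketch of such a dataset function (the feature names and
  `global_batch_size` here are placeholders for your own setup):

  ```python
  def dataset_fn(ctx):
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    dataset = ...  # yields ({'feature_one': ..., ...}, dense_features)
    return dataset.batch(batch_size, drop_remainder=True)
  ```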

  When checkpointing your model, you should include your
  `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is
  a trackable object and saving it will save the embedding tables and their
  optimizer slot variables:

  ```python
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.save(...)
  ```

  On CPU, only the `embedding_tables` property is usable. This will allow you
  to restore a checkpoint to the object and have access to the table variables:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  tables = embedding.embedding_tables
  ```

  You can now use these tables in functions like `tf.nn.embedding_lookup` to
  perform your embedding lookups and pass the results to your model, as
  sketched below.
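
  For example, a minimal lookup sketch (`ids` here stands in for your own id
  tensor):

  ```python
  activations = tf.nn.embedding_lookup(
      params=tables[table_config_one], ids=ids)
  ```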

  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=tf.tpu.experimental.embedding.FeatureConfig(
              table=tf.tpu.experimental.embedding.TableConfig(
                  dim=...,
                  vocabulary_size=...)))
    ```

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. When not created under
        TPUStrategy, this may be set to None to avoid the creation of the
        optimizer slot variables; this is useful for reducing memory
        consumption when exporting the model for serving, where slot variables
        aren't needed.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old). Set to True for improved performance.

    Raises:
      ValueError: If optimizer is not one of
        tf.tpu.experimental.embedding.(SGD, Adam or Adagrad) or None when
        created under a TPUStrategy.
    """
    self._strategy = distribution_strategy_context.get_strategy()
    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
                                                  tpu_strategy.TPUStrategyV2))
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._feature_config = feature_config

    # The TPU embedding ops are slightly inconsistent with how they refer to
    # tables:
    # * The enqueue op takes a parallel list of tensors for input, one of those
    #   is the table id for the feature which matches the integer index of the
    #   table in the proto created by _create_config_proto().
    # * The recv_tpu_embedding_activations op emits lookups per table in the
    #   order from the config proto.
    # * The send_tpu_embedding_gradients op expects input tensors to be per
    #   table in the same order as the config proto.
    # * Per optimizer load and retrieve ops are specified per table and take
    #   the table name rather than the table id.
    # Thus we must fix a common order to tables and ensure they have unique
    # names.

    # Set table order here to the order of the first occurrence of the table
    # in a feature provided by the user. The order of this struct must be
    # fixed to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if ((table.optimizer is not None or self._using_tpu) and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Multiple tables with name {} found.".format(
            table.name))
      table_names.append(table.name)

    if self._using_tpu:
      # Extract a list of callable learning rates, also in fixed order. Each
      # table in the config proto will get an index into this list, and we
      # will pass this list in the same order after evaluation to the
      # send_tpu_embedding_gradients op.
      self._dynamic_learning_rates = list({
          table.optimizer.learning_rate for table in self._table_config if
          callable(table.optimizer.learning_rate)})

      # We need the list of host devices for the load/retrieve operations.
      self._hosts = get_list_of_hosts(self._strategy)

    self._built = False

  def build(self, per_replica_batch_size: Optional[int] = None):
    """Creates the underlying variables and initializes the TPU for embeddings.

    This method creates the underlying variables (including slot variables). If
    created under a TPUStrategy, this will also initialize the TPU for
    embeddings.

    This function will automatically get called by enqueue, which will try to
    determine your batch size automatically. If this fails, you must manually
    call this method before you call enqueue.
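
    For example, assuming a `global_batch_size` defined elsewhere:

    ```python
    embedding.build(
        per_replica_batch_size=(
            global_batch_size // strategy.num_replicas_in_sync))
    ```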

    Args:
      per_replica_batch_size: The per replica batch size that you intend to
        use. Note that this is fixed and the same batch size must be used for
        both training and evaluation. If you want to calculate this from the
        global batch size, you can use the `num_replicas_in_sync` property of
        your strategy object. May be set to None if not created under a
        TPUStrategy.

    Raises:
      ValueError: If per_replica_batch_size is None and object was created in a
        TPUStrategy scope.
    """
    if self._built:
      return

    if self._using_tpu:
      if per_replica_batch_size is None:
        raise ValueError("You must specify a per_replica_batch_size when "
                         "calling build if object is created under a "
                         "TPUStrategy.")

      self._batch_size = per_replica_batch_size

      self._config_proto = self._create_config_proto()

      logging.info("Initializing TPU Embedding engine.")
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)

      @def_function.function
      def load_config():
        tpu.initialize_system_for_tpu_embedding(self._config_proto)

      load_config()
      logging.info("Done initializing TPU Embedding engine.")

    # Create and load variables and slot variables into the TPU.
    # Note that this is a dict of dicts. Keys to the first dict are table names.
    # We would prefer to use TableConfigs, but then these variables won't be
    # properly tracked by the tracking API.
    self._variables = self._create_variables_and_slots()

    self._built = True

    # This is internally conditioned on self._built and self._using_tpu.
    self._load_variables()

  def _maybe_build(self, batch_size: Optional[int]):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means that
      # it will not be included in the function graph generated by tracing, so
      # we can be sure that we only initialize the TPU for embeddings exactly
      # once.
      with ops.init_scope():
        self.build(batch_size)

  @property
  def embedding_tables(
      self
  ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`.

    This property only works when the `TPUEmbedding` object is created under a
    non-TPU strategy. This is intended to be used for CPU based lookup when
    creating a serving checkpoint.

    Returns:
      A dict of embedding tables, keyed by `TableConfig`.

    Raises:
      RuntimeError: If object was created under a `TPUStrategy`.
    """
    # We don't support returning tables on TPU due to their sharded nature and
    # the fact that when using a TPUStrategy:
    # 1. Variables are stale and are only updated when a checkpoint is made.
    # 2. Updating the variables won't affect the actual tables on the TPU.
    if self._using_tpu:
      if save_context.in_save_context():
        return {table: self._variables[table.name]["parameters"].variables[0]
                for table in self._table_config}
      raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
                         "strategy. If you need access, save your model, "
                         "create this object under a CPU strategy and restore.")

    self._maybe_build(None)

    # Only return the tables and not the slot variables. On CPU these are
    # regular tf.Variables.
    return {table: self._variables[table.name]["parameters"]
            for table in self._table_config}

  def _create_config_proto(
      self
  ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration:
    """Creates the TPUEmbeddingConfiguration proto.

    This proto is used to initialize the TPU embedding engine.

    Returns:
      A TPUEmbeddingConfiguration proto.
    """

    config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()

    # There are several things that need to be computed here:
    # 1. Each table has a num_features, which corresponds to the number of
    #    output rows per example for this table. Sequence features count for
    #    their maximum sequence length.
    # 2. Learning rate index: the index of the dynamic learning rate for this
    #    table (if it exists) in the list we created at initialization.
    #    We don't simply create one learning rate index per table as this has
    #    extremely bad performance characteristics. The more separate
    #    optimization configurations we have, the worse the performance will be.
    num_features = {table: 0 for table in self._table_config}
    for feature in nest.flatten(self._feature_config):
      num_features[feature.table] += (1 if feature.max_sequence_length == 0
                                      else feature.max_sequence_length)

    # Map each callable dynamic learning rate to its index in the list.
    learning_rate_index = {r: i for i, r in enumerate(
        self._dynamic_learning_rates)}

    for table in self._table_config:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table.name

      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table.vocabulary_size,
                                             self._strategy.extended.num_hosts)
      table_descriptor.dimension = table.dim

      table_descriptor.num_features = num_features[table]

      parameters = table_descriptor.optimization_parameters

      # We handle the learning rate separately here and don't allow the
      # optimization class to handle this, as it doesn't know about dynamic
      # rates.
      if callable(table.optimizer.learning_rate):
        parameters.learning_rate.dynamic.tag = (
            learning_rate_index[table.optimizer.learning_rate])
      else:
        parameters.learning_rate.constant = table.optimizer.learning_rate

      # Use the optimizer to handle the rest of the parameters.
      table.optimizer._set_optimization_parameters(parameters)  # pylint: disable=protected-access

    # Always set mode to training; we override the mode during enqueue.
    config_proto.mode = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING)

    config_proto.batch_size_per_tensor_core = self._batch_size
    config_proto.num_hosts = self._strategy.extended.num_hosts
    config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync

    # TODO(bfontain): Allow users to pick MOD for the host sharding.
    config_proto.sharding_strategy = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def _compute_per_table_gradients(
      self,
      gradients
  ) -> Dict[Text, List[core.Tensor]]:
    """Computes a dict of lists of gradients, keyed by table name.

    Args:
      gradients: A nested structure of Tensors (and Nones) with the same
        structure as the feature config.

    Returns:
      A dict of lists of tensors, keyed by the table names, containing the
      gradients in the correct order with None gradients replaced by zeros.
    """

    nest.assert_same_structure(self._feature_config, gradients)

    per_table_gradients = {table: [] for table in self._table_config}
    for (path, gradient), feature in zip(
        nest.flatten_with_joined_string_paths(gradients),
        nest.flatten(self._feature_config)):
      if gradient is not None and not isinstance(gradient, ops.Tensor):
        raise ValueError(
            "Found {} at path {} in gradients. Expected Tensor.".format(
                type(gradient), path))

      # Expected tensor shape differs for sequence and non-sequence features.
      if feature.max_sequence_length > 0:
        shape = [self._batch_size, feature.max_sequence_length,
                 feature.table.dim]
      else:
        shape = [self._batch_size, feature.table.dim]

      if gradient is not None:
        if gradient.shape != shape:
          raise ValueError("Found gradient of shape {} at path {}. Expected "
                           "shape {}.".format(gradient.shape, path, shape))

        # We expand dims on non-sequence features so that all features are
        # of rank 3 and we can concat on axis=1.
        if len(shape) == 2:
          gradient = array_ops.expand_dims(gradient, axis=1)
      else:
        # No gradient for this feature. Since we must give a gradient for all
        # features, pass in a zero tensor here. Note that this is not correct
        # for all optimizers.
        logging.warning("No gradient passed for feature %s, sending zero "
                        "gradient. This may not be correct behavior for "
                        "certain optimizers like Adam.", path)
        # Create a shape to mimic the expand_dims above for non-sequence
        # features.
        if len(shape) == 2:
          shape = [shape[0], 1, shape[1]]
        gradient = array_ops.zeros(shape, dtype=dtypes.float32)
      per_table_gradients[feature.table].append(gradient)

    return per_table_gradients

  def apply_gradients(self, gradients, name: Text = None):
    """Applies the gradient update to the embedding tables.

    If a gradient of `None` is passed in any position of the nested structure,
    then a gradient update with a zero gradient is applied for that feature.
    For optimizers like SGD or Adagrad, this is the same as applying no update
    at all. For lazy Adam and other sparsely applied optimizers with decay,
    ensure you understand the effect of applying a zero gradient.

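    For example, a sketch of skipping the update for one feature by passing
    `None` (the feature names here are placeholders matching your
    `feature_config`):

    ```python
    embedding.apply_gradients({
        'feature_one': gradient_for_feature_one,
        'feature_two': None})  # 'feature_two' receives a zero gradient.
    ```
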
    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_prefetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      gradients: A nested structure of gradients, with structure matching the
        `feature_config` passed to this object.
      name: A name for the underlying op.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
      ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a
        `tf.Tensor` of the incorrect shape is passed in. Also if
        the size of any sequence in `gradients` does not match corresponding
        sequence in `feature_config`.
      TypeError: If the type of any sequence in `gradients` does not match
        corresponding sequence in `feature_config`.
    """
    if not self._using_tpu:
      raise RuntimeError("apply_gradients is not valid when TPUEmbedding "
                         "object is not created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding "
                         "object. Please either call enqueue first or manually "
                         "call the build method.")

    # send_tpu_embedding_gradients requires per table gradients. If we only
    # have one feature per table this isn't an issue. When multiple features
    # share the same table, the order of the features in the per table tensor
    # returned by recv_tpu_embedding_activations matches the order in which
    # they were passed to enqueue.
    # In all three places, we use the fixed order given by nest.flatten to have
    # a consistent feature order.

    # First construct a dict of lists of gradients, one list per table.
    per_table_gradients = self._compute_per_table_gradients(gradients)

    # Now that we have a dict of lists of gradients, we can compute a single
    # list of gradients in the fixed order of self._table_config, interleaving
    # the gradients of the individual features. We concat on axis 1 and then
    # reshape into a 2d tensor. The send gradients op expects a tensor of shape
    # [num_features*batch_size, dim] for each table.
    interleaved_gradients = []
    for table in self._table_config:
      interleaved_gradients.append(array_ops.reshape(
          array_ops.concat(per_table_gradients[table], axis=1),
          [-1, table.dim]))
    op = tpu_ops.send_tpu_embedding_gradients(
        inputs=interleaved_gradients,
        learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32)
                        for fn in self._dynamic_learning_rates],
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(op, name)

  def dequeue(self, name: Text = None):
    """Get the embedding results.

    Returns a nested structure of `tf.Tensor` objects, matching the structure
    of the `feature_config` argument to the `TPUEmbedding` class. The output
    shape of the tensors is `(batch_size, dim)`, where `batch_size` is the per
    core batch size and `dim` is the dimension of the corresponding
    `TableConfig`. If the feature's corresponding `FeatureConfig` has
    `max_sequence_length` greater than 0, the output will be a sequence of
    shape `(batch_size, max_sequence_length, dim)` instead.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_prefetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      name: A name for the underlying op.

    Returns:
      A nested structure of tensors, with the same structure as
      `feature_config` passed to this instance of the `TPUEmbedding` object.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
    """
    if not self._using_tpu:
      raise RuntimeError("dequeue is not valid when TPUEmbedding object is not "
                         "created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. "
                         "Please either call enqueue first or manually call "
                         "the build method.")

    # The activations returned by this op are per table. So we must separate
    # them out into per feature activations. The activations are interleaved:
    # for each table, we expect a [num_features*batch_size, dim] tensor.
    # E.g. we expect the slice [:num_features, :] to contain the lookups for the
    # first example of all features using this table.
    activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_config),
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(activations[0].op, name)

    # Compute the number of features for each table.
    num_features = {table: 0 for table in self._table_config}
    for feature in nest.flatten(self._feature_config):
      num_features[feature.table] += (1 if feature.max_sequence_length == 0
                                      else feature.max_sequence_length)

    # Activations are reshaped so that they are indexed by batch size and then
    # by the 'feature' index within the batch. The final dimension should equal
    # the dimension of the table.
    table_to_activation = {
        table: array_ops.reshape(activation,
                                 [self._batch_size, num_features[table], -1])
        for table, activation in zip(self._table_config, activations)}

    # We process the features in the same order we enqueued them.
    # For each feature we take the next slice of the activations, so we need to
    # track the activations and the current position we are in.
    table_to_position = {table: 0 for table in self._table_config}

    per_feature_activations = []
    for feature in nest.flatten(self._feature_config):
      activation = table_to_activation[feature.table]
      feature_index = table_to_position[feature.table]
      # We treat non-sequence and sequence features differently here as
      # sequence features have rank 3 while non-sequence features have rank 2.
      if feature.max_sequence_length == 0:
        per_feature_activations.append(
            activation[:, feature_index, :])
        table_to_position[feature.table] += 1
      else:
        per_feature_activations.append(
            activation[:, feature_index:(
                feature_index+feature.max_sequence_length), :])
        table_to_position[feature.table] += feature.max_sequence_length

    # Pack the list back into the same nested structure as the features.
    return nest.pack_sequence_as(self._feature_config, per_feature_activations)

  def _create_variables_and_slots(
      self
  ) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note that under TPUStrategy this will ensure that all creations happen
    within a variable creation scope of the sharded variable creator.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
    """

    def create_variables(table):
      """Create all variables."""
      variable_shape = (table.vocabulary_size, table.dim)

      def getter(name, shape, dtype, initializer, trainable):
        del shape
        # _add_variable_with_custom_getter clears the shape sometimes, so we
        # take the global shape from outside the getter.
        initial_value = functools.partial(initializer, variable_shape,
                                          dtype=dtype)
        return tf_variables.Variable(
            name=name,
            initial_value=initial_value,
            shape=variable_shape,
            dtype=dtype,
            trainable=trainable)

      def variable_creator(name, initializer, trainable=True):
        # Use _add_variable_with_custom_getter here so that we take advantage
        # of the checkpoint loading to allow restore before the variables get
        # created, which avoids double initialization.
        return self._add_variable_with_custom_getter(
            name=name,
            initializer=initializer,
            shape=variable_shape,
            dtype=dtypes.float32,
            getter=getter,
            trainable=trainable)

      parameters = variable_creator(table.name, table.initializer,
                                    trainable=not self._using_tpu)

      def slot_creator(name, initializer):
        return variable_creator(table.name + "/" + name,
                                initializer,
                                False)

      if table.optimizer is not None:
        slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
      else:
        slot_vars = {}
      slot_vars["parameters"] = parameters
      return slot_vars

    # Store tables based on name rather than TableConfig as we can't track
    # through dicts with non-string keys, i.e. we won't be able to save.
    variables = {}
    for table in self._table_config:
      if not self._using_tpu:
        variables[table.name] = create_variables(table)
      else:
        with variable_scope.variable_creator_scope(
            make_sharded_variable_creator(self._hosts)):
          variables[table.name] = create_variables(table)

    return variables

  def _load_variables(self):
    # Only load the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _load_variables_impl(self._config_proto.SerializeToString(),
                           self._hosts,
                           self._variables,
                           self._table_config)

  def _retrieve_variables(self):
    # Only retrieve the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _retrieve_variables_impl(self._config_proto.SerializeToString(),
                               self._hosts,
                               self._variables,
                               self._table_config)

  def _gather_saveables_for_checkpoint(
      self
  ) -> Dict[Text, Callable[[Text], "TPUEmbeddingSaveable"]]:
    """Overrides default Trackable implementation to add load/retrieve hook."""
    # This saveable should be here in both TPU and CPU checkpoints, so when on
    # CPU, we add the hook with no functions.
    # TODO(bfontain): Update restore logic in saver so that these hooks are
    # always executed. Once that is done, we can output an empty list when on
    # CPU.

    def factory(name=_HOOK_KEY):
      return TPUEmbeddingSaveable(name, self._load_variables,
                                  self._retrieve_variables)
    return {_HOOK_KEY: factory}

  # Some helper functions for the enqueue function below.
  def _add_data_for_tensor(self, tensor, weight, indices, values, weights,
                           int_zeros, float_zeros, path):
    if weight is not None:
      raise ValueError(
          "Weight specified for dense input {}, which is not allowed. "
          "Weight will always be 1 in this case.".format(path))
    # For tensors, there are no indices and no weights.
    indices.append(int_zeros)
    values.append(math_ops.cast(tensor, dtypes.int32))
    weights.append(float_zeros)

  def _add_data_for_sparse_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path):
    indices.append(math_ops.cast(tensor.indices, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int32))
    # If we have weights, they must be a SparseTensor.
    if weight is not None:
      if not isinstance(weight, sparse_tensor.SparseTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the input type, which is SparseTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _add_data_for_ragged_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path):
    indices.append(math_ops.cast(tensor.row_splits, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int32))
    # If we have weights, they must be a RaggedTensor.
    if weight is not None:
      if not isinstance(weight, ragged_tensor.RaggedTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the input type, which is RaggedTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _generate_enqueue_op(
      self,
      flat_inputs: List[internal_types.NativeObject],
      flat_weights: List[Optional[internal_types.NativeObject]],
      flat_features: List[tpu_embedding_v2_utils.FeatureConfig],
      device_ordinal: int,
      mode_override: Text
  ) -> ops.Operation:
    """Builds the enqueue op given the inputs and weights.

    Args:
      flat_inputs: A list of input tensors.
      flat_weights: A list of input weights (or None) of the same length as
        flat_inputs.
      flat_features: A list of FeatureConfigs of the same length as flat_inputs.
      device_ordinal: The device to create the enqueue op for.
      mode_override: A tensor containing the string "train" or "inference".

    Returns:
      The enqueue op.
    """

    # First we need to understand which op to use. This depends on whether
    # sparse or ragged tensors are present in the flat_inputs.
    sparse = False
    ragged = False
    for inp in flat_inputs:
      if isinstance(inp, sparse_tensor.SparseTensor):
        sparse = True
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        ragged = True
    if sparse and ragged:
      raise ValueError(
          "Found both SparseTensors and RaggedTensors in the input to the "
          "enqueue operation. Please ensure that your data does not include "
          "both SparseTensors and RaggedTensors. It is ok to have Tensors in "
          "combination with one of the previous types.")

    # Combiners are per table, listed in the same order as the tables.
    combiners = [table.combiner for table in self._table_config]

    # Reverse mapping of self._table_config, so that we can look up the table
    # index.
    table_to_id = {table: i for i, table in enumerate(self._table_config)}

    # These parallel arrays will be the inputs to the enqueue op.
    indices = []  # sample_indices for sparse, sample_splits for ragged.
    values = []
    weights = []
    table_ids = []
    max_sequence_lengths = []

    # We have to supply an empty/zero tensor in a list position where we don't
    # have data (e.g. indices for standard Tensor input, weight when no weight
    # is specified). We create one op here per call, so that we reduce the
    # graph size.
    int_zeros = array_ops.zeros((0,), dtype=dtypes.int32)
    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)

    # In the following loop we insert casts so that everything is either int32
    # or float32. This is because op inputs which are lists of tensors must be
    # of the same type within the list. Moreover the CPU implementations of
    # these ops cast to these types anyway, so we don't lose any data by casting
    # early.
    for inp, weight, (path, feature) in zip(
        flat_inputs, flat_weights, flat_features):
      table_ids.append(table_to_id[feature.table])
      max_sequence_lengths.append(feature.max_sequence_length)
      if isinstance(inp, ops.Tensor):
        self._add_data_for_tensor(inp, weight, indices, values, weights,
                                  int_zeros, float_zeros, path)
      elif isinstance(inp, sparse_tensor.SparseTensor):
        self._add_data_for_sparse_tensor(inp, weight, indices, values, weights,
                                         int_zeros, float_zeros, path)
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        self._add_data_for_ragged_tensor(inp, weight, indices, values, weights,
                                         int_zeros, float_zeros, path)
      else:
        raise ValueError("Input {} is of unknown type {}. Please only pass "
                         "Tensor, SparseTensor or RaggedTensor as input to "
                         "enqueue.".format(path, type(inp)))

    if ragged:
      return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
          sample_splits=indices,
          embedding_indices=values,
          aggregation_weights=weights,
          mode_override=mode_override,
          device_ordinal=device_ordinal,
          combiners=combiners,
          table_ids=table_ids,
          max_sequence_lengths=max_sequence_lengths)
    return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
        sample_indices=indices,
        embedding_indices=values,
        aggregation_weights=weights,
        mode_override=mode_override,
        device_ordinal=device_ordinal,
        combiners=combiners,
        table_ids=table_ids,
        max_sequence_lengths=max_sequence_lengths)

  def _raise_error_for_incorrect_control_flow_context(self):
    """Raises an error if we are not in the TPUReplicateContext."""
    # Do not allow any XLA control flow (i.e. control flow in between a
    # TPUStrategy's run call and the call to this function), as we can't
    # extract the enqueue from the head when in XLA control flow.
    graph = ops.get_default_graph()
    in_tpu_ctx = False
    while graph is not None:
      ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
      while ctx is not None:
        if isinstance(ctx, tpu.TPUReplicateContext):
          in_tpu_ctx = True
          break
        ctx = ctx.outer_context
      if in_tpu_ctx:
        break
      graph = getattr(graph, "outer_graph", None)
    if graph != ops.get_default_graph() and in_tpu_ctx:
      raise RuntimeError(
          "Current graph {} does not match graph which contains "
          "TPUReplicateContext {}. This is most likely due to the fact that "
          "enqueueing embedding data is called inside control flow or a "
          "nested function inside `strategy.run`. This is not supported "
          "because outside compilation fails to extract the enqueue ops as "
          "head of computation.".format(ops.get_default_graph(), graph))
    return in_tpu_ctx

  def _raise_error_for_non_direct_inputs(self, features):
    """Checks all tensors in features to see if they are a direct input."""

    # expand_composites here is important: as composite tensors pass through
    # tpu.replicate, they get 'flattened' into their component tensors and then
    # repacked before being passed to the tpu function. This means that it is
    # the component tensors which are produced by an op with the
    # "_tpu_input_identity" attribute.
    for path, input_tensor in nest.flatten_with_joined_string_paths(
        features, expand_composites=True):
      if input_tensor.op.type == "Placeholder":
        continue
      try:
        is_input = input_tensor.op.get_attr("_tpu_input_identity")
      except ValueError:
        is_input = False
      if not is_input:
        raise ValueError(
            "Received input tensor {} which is the output of op {} (type {}) "
            "which does not have the `_tpu_input_identity` attr. Please "
            "ensure that the inputs to this layer are taken directly from "
            "the arguments of the function called by "
            "strategy.run. Two possible causes are: dynamic batch size "
            "support or you are using a keras layer and are not passing "
            "tensors which match the dtype of the `tf.keras.Input`s. "
            "If you are triggering dynamic batch size support, you can "
            "disable it by passing tf.distribute.RunOptions("
            "experimental_enable_dynamic_batch_size=False) to the options "
            "argument of strategy.run().".format(path,
                                                 input_tensor.op.name,
                                                 input_tensor.op.type))

  def _raise_error_for_inputs_not_on_cpu(self, features):
    """Checks that all tensors in features are placed on the CPU."""

    def check_device(path, device_string):
      spec = tf_device.DeviceSpec.from_string(device_string)
      if spec.device_type == "TPU":
        raise ValueError(
            "Received input tensor {} which is on a TPU input device {}. Input "
            "tensors for TPU embeddings must be placed on the CPU. Please "
            "ensure that your dataset is prefetching tensors to the host by "
            "setting the 'experimental_prefetch_to_device' option of the "
            "dataset distribution function. See the documentation of the "
            "enqueue method for an example.".format(
                path, device_string))

    # expand_composites here is important: we need to check the device of each
    # underlying tensor.
    for path, input_tensor in nest.flatten_with_joined_string_paths(
        features, expand_composites=True):
      if (input_tensor.op.type == "Identity" and
          input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"):
        for tensor in input_tensor.op.inputs[0].op.inputs:
          check_device(path, tensor.device)
      else:
        check_device(path, input_tensor.device)

  def enqueue(
      self,
      features,
      weights=None,
      training: bool = True,
      name: Optional[Text] = None,
      device: Optional[Text] = None):
    """Enqueues id tensors for embedding lookup.

    This function enqueues a structure of features to be looked up in the
    embedding tables. We expect that the batch size of each of the tensors in
    features matches the per core batch size. This will automatically happen if
    your input dataset is batched to the global batch size and you use
    `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
    or if you use `distribute_datasets_from_function` and batch
    to the per core batch size computed by the context passed to your input
    function.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_prefetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features,))

    training_step()
    ```

    NOTE: You should specify `training=True` when using
    `embedding.apply_gradients` as above and `training=False` when not using
    `embedding.apply_gradients` (e.g. for frozen embeddings or when doing
    evaluation).

    For finer grained control, in the above example the line

    ```
      embedding.enqueue(embedding_features, training=True)
    ```

    may be replaced with

    ```
      per_core_embedding_features = strategy.experimental_local_results(
          embedding_features)

      def per_core_enqueue(ctx):
        core_id = ctx.replica_id_in_sync_group
        device = strategy.extended.worker_devices[core_id]
        embedding.enqueue(per_core_embedding_features[core_id],
                          device=device)

      strategy.experimental_distribute_values_from_function(
          per_core_enqueue)
    ```

    Args:
      features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or
        `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs
        will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor`
        or `tf.RaggedTensor` is supported per call.
      weights: If not `None`, a nested structure of `tf.Tensor`s,
        `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except
        that the tensors should be of float type (and they will be downcast to
        `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the
        same for the parallel entries from `features` and similarly for
        `tf.RaggedTensor`s we assume the row_splits are the same. See the
        sketch following this argument list.
      training: Defaults to `True`. If `False`, enqueue the batch as inference
        batch (forward pass only). Do not call `apply_gradients` when this is
        `False` as this may lead to a deadlock.
      name: A name for the underlying op.
      device: The device name (e.g. '/task:0/device:TPU:2') where this batch
        should be enqueued. This should be set if and only if features is not a
        `tf.distribute.DistributedValues` and enqueue is not being called
        inside a TPU context (e.g. inside `TPUStrategy.run`).
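
    A sketch of enqueueing weighted sparse features (the names here are
    placeholders; each weight `tf.SparseTensor` must share its `indices` with
    the corresponding feature):

    ```python
    embedding.enqueue(sparse_features, weights=sparse_weights, training=True)
    ```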
1216
1217    Raises:
1218      ValueError: When called inside a strategy.run call and input is not
1219        directly taken from the args of the `strategy.run` call. Also if
1220        the size of any sequence in `features` does not match corresponding
1221        sequence in `feature_config`. Similarly for `weights`, if not `None`.
1222        If batch size of features is unequal or different from a previous call.
1223      RuntimeError: When called inside a strategy.run call and inside XLA
1224        control flow. If batch_size is not able to be determined and build was
1225        not called.
1226      TypeError: If the type of any sequence in `features` does not match
1227        corresponding sequence in `feature_config`. Similarly for `weights`, if
1228        not `None`.
1229    """
    if not self._using_tpu:
      raise RuntimeError("enqueue is not valid when TPUEmbedding object is not "
                         "created under a TPUStrategy.")

    in_tpu_context = self._raise_error_for_incorrect_control_flow_context()

    # Should we also get batch_size from weights if they exist?
    # Since features is assumed to be batched at the per replica batch size,
    # the returned batch size here is per replica and not global.
    batch_size = self._get_batch_size(features, in_tpu_context)
    if batch_size is None and not self._built:
      raise RuntimeError("Unable to determine batch size from input features. "
                         "Please call build() with the global batch size to "
                         "initialize the TPU for embeddings.")
    if batch_size is not None:
      self._maybe_build(batch_size)
      if self._batch_size != batch_size:
        raise ValueError("Multiple calls to enqueue with different batch sizes "
                         "{} and {}.".format(self._batch_size,
                                             batch_size))

    nest.assert_same_structure(self._feature_config, features)

    flat_inputs = nest.flatten(features)
    flat_weights = [None] * len(flat_inputs)
    if weights is not None:
      nest.assert_same_structure(self._feature_config, weights)
      flat_weights = nest.flatten(weights)
    flat_features = nest.flatten_with_joined_string_paths(self._feature_config)

    self._raise_error_for_inputs_not_on_cpu(features)
    # If we are in a tpu_context, automatically apply outside compilation.
    if in_tpu_context:
      self._raise_error_for_non_direct_inputs(features)

      def generate_enqueue_ops():
        """Generate enqueue ops for outside compilation."""
        # Note that we use array_ops.where_v2 rather than a Python if, so that
        # the op is explicitly created and both constant ops are in the graph,
        # even though we don't expect training to be a tensor (which would
        # generate control flow automatically). This is needed to make it
        # easier to rewrite the graph later if we need to fix which mode is
        # used.
        mode_override = array_ops.where_v2(training,
                                           constant_op.constant("train"),
                                           constant_op.constant("inference"))

        # Device ordinal is -1 here; a later rewrite will fix this once the op
        # is expanded by outside compilation.
        enqueue_op = self._generate_enqueue_op(
            flat_inputs, flat_weights, flat_features, device_ordinal=-1,
            mode_override=mode_override)

        # Apply the name tag to the op.
        if name is not None:
          _add_key_attr(enqueue_op, name)

        # Ensure that this op has outbound control flow, otherwise it won't be
        # executed.
        ops.get_default_graph().control_outputs.append(enqueue_op)

      tpu.outside_compilation(generate_enqueue_ops)

    elif device is None:
      mode_override = "train" if training else "inference"
      # We generate enqueue ops per device, so we need to gather all the
      # features for a single device into a dict.
      # We rely here on the fact that the devices in the PerReplica value occur
      # in the same (standard) order as self._strategy.extended.worker_devices.
      enqueue_ops = []
      for replica_id in range(self._strategy.num_replicas_in_sync):
        replica_inputs = distribute_utils.select_replica(replica_id,
                                                         flat_inputs)
        replica_weights = distribute_utils.select_replica(replica_id,
                                                          flat_weights)
        tpu_device = self._strategy.extended.worker_devices[replica_id]
        # TPU device strings look like
        # /job:worker/replica:0/task:0/device:TPU:0; the device ordinal is the
        # last number.
        device_ordinal = (
            tf_device.DeviceSpec.from_string(tpu_device).device_index)
        with ops.device(device_util.get_host_for_device(tpu_device)):
          enqueue_op = self._generate_enqueue_op(
              replica_inputs, replica_weights, flat_features,
              device_ordinal=device_ordinal, mode_override=mode_override)

          # Apply the name tag to the op.
          if name is not None:
            _add_key_attr(enqueue_op, name)
          enqueue_ops.append(enqueue_op)
      ops.get_default_graph().control_outputs.extend(enqueue_ops)
    else:
      mode_override = "train" if training else "inference"
      device_spec = tf_device.DeviceSpec.from_string(device)
      if device_spec.device_type != "TPU":
        raise ValueError(
            "Non-TPU device {} passed to enqueue.".format(device))
      with ops.device(device_util.get_host_for_device(device)):
        enqueue_op = self._generate_enqueue_op(
            flat_inputs, flat_weights, flat_features,
            device_ordinal=device_spec.device_index,
            mode_override=mode_override)

        # Apply the name tag to the op.
        if name is not None:
          _add_key_attr(enqueue_op, name)
        ops.get_default_graph().control_outputs.append(enqueue_op)

  def _get_batch_size(self, tensors, in_tpu_context: bool):
    """Gets the batch size from a nested structure of features."""
    batch_size = None
    for path, maybe_tensor in nest.flatten_with_joined_string_paths(tensors):
      tensor_list = []
      if not in_tpu_context:
        # If we are not in a TPU context, then this is a PerReplica value and
        # we need to check each replica's batch size.
        for replica_id in range(self._strategy.num_replicas_in_sync):
          tensor_list.append(distribute_utils.select_replica(replica_id,
                                                             maybe_tensor))
      else:
        tensor_list = [maybe_tensor]

      for tensor in tensor_list:
        if tensor.shape.rank < 1:
          raise ValueError(
              "Input {} has rank 0, rank must be at least 1.".format(path))
        shape = tensor.shape.as_list()
        if shape[0] is not None:
          if batch_size is None:
            batch_size = shape[0]
          elif batch_size != shape[0]:
            raise ValueError("Found multiple batch sizes {} and {}. All inputs "
                             "must have the same batch dimension size.".format(
                                 batch_size, shape[0]))
    return batch_size


@def_function.function
def _load_variables_impl(
    config: Text,
    hosts: List[Text],
    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
    table_config: List[tpu_embedding_v2_utils.TableConfig]):
  """Load the embedding tables onto the TPU, for each table and host.

  Args:
    config: A serialized TPUEmbeddingConfiguration proto.
    hosts: A list of CPU devices, one per host.
    variables: A dictionary of dictionaries of TPUShardedVariables. First key
      is the table name, second key is 'parameters' or the optimizer slot name.
    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
  """
  def select_fn(host_id):

    def select_or_zeros(x):
      if host_id >= len(x.variables):
        # In the edge case where we have more hosts than variables, due to
        # using a small number of rows, we load zeros for the later hosts. We
        # copy the shape of the first host's variables, which we assume is
        # defined because TableConfig guarantees at least one row.
        return array_ops.zeros_like(x.variables[0])
      return x.variables[host_id]

    return select_or_zeros

  for host_id, host in enumerate(hosts):
    with ops.device(host):
      host_variables = nest.map_structure(select_fn(host_id), variables)
      for table in table_config:
        table.optimizer._load()(  # pylint: disable=protected-access
            table_name=table.name,
            num_shards=len(hosts),
            shard_id=host_id,
            config=config,
            **host_variables[table.name])
        # Ensure that only the first table/first host gets a config so that we
        # don't bloat graph by attaching this large string to each op.
        # We have num tables * num hosts of these, so for models with a large
        # number of tables training on a large slice this can be an issue.
        config = None


@def_function.function
def _retrieve_variables_impl(
    config: Text,
    hosts: List[Text],
    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
    table_config: List[tpu_embedding_v2_utils.TableConfig]):
  """Retrieve embedding tables from TPU to host memory.

  Args:
    config: A serialized TPUEmbeddingConfiguration proto.
    hosts: A list of all the host CPU devices.
    variables: A dictionary of dictionaries of TPUShardedVariables. First key
      is the table name, second key is 'parameters' or the optimizer slot name.
    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
  """
  for host_id, host in enumerate(hosts):
    with ops.device(host):
      for table in table_config:
        retrieved = table.optimizer._retrieve()(  # pylint: disable=protected-access
            table_name=table.name,
            num_shards=len(hosts),
            shard_id=host_id,
            config=config)
        # When there are no slot variables (e.g. with SGD) this returns a
        # single tensor rather than a tuple. In this case we put the tensor in
        # a tuple to make the following code easier to write.
        if not isinstance(retrieved, tuple):
          retrieved = (retrieved,)

        for i, slot in enumerate(["parameters"] +
                                 table.optimizer._slot_names()):  # pylint: disable=protected-access
          # We must assign the CPU variables the values of tensors that were
          # returned from the TPU.
          sharded_var = variables[table.name][slot]
          if host_id < len(sharded_var.variables):
            # In the edge case where we have more hosts than variables, due to
            # using a small number of rows, we skip the later hosts.
            sharded_var.variables[host_id].assign(retrieved[i])
        # Ensure that only the first table/first host gets a config so that we
        # don't bloat graph by attaching this large string to each op.
        # We have num tables * num hosts of these, so for models with a large
        # number of tables training on a large slice this can be an issue.
        config = None


class TPUEmbeddingSaveable(saveable_hook.SaveableHook):
  """Save/Restore hook to Retrieve/Load TPUEmbedding variables."""

  def __init__(
      self,
      name: Text,
      load: Callable[[], Any],
      retrieve: Callable[[], Any]):
    self._load = load
    self._retrieve = retrieve
    super(TPUEmbeddingSaveable, self).__init__(name=name)

  def before_save(self):
    if self._retrieve is not None:
      self._retrieve()

  def after_restore(self):
    if self._load is not None:
      self._load()


def _ragged_embedding_lookup_with_reduce(
    table: tf_variables.Variable,
    ragged: ragged_tensor.RaggedTensor,
    weights: ragged_tensor.RaggedTensor,
    combiner: Text) -> core.Tensor:
  """Compute a ragged lookup followed by a reduce on axis 1.

  Args:
    table: The embedding table.
    ragged: A RaggedTensor of ids to look up.
    weights: A RaggedTensor of weights (or None).
    combiner: One of "mean", "sum", "sqrtn".

  Returns:
    A Tensor.
  """
  if weights is None:
    weights = array_ops.ones_like(ragged, dtype=table.dtype)
  weights = array_ops.expand_dims(weights, axis=2)
  ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged)
  ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1)
  if combiner == "mean":
    ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1)
  elif combiner == "sqrtn":
    ragged_result = ragged_result / math_ops.sqrt(math_ops.reduce_sum(
        weights * weights, axis=1))
  return ragged_result


@tf_export("tpu.experimental.embedding.serving_embedding_lookup")
def cpu_embedding_lookup(inputs, weights, tables, feature_config):
  """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs.

  This function is a utility which allows using the
  `tf.tpu.experimental.embedding` config objects with standard lookup
  functions. This can be used when exporting a model which uses
  `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In
  particular `tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups
  on TPUs and should not be part of your serving graph.

  Note that TPU specific options (such as `max_sequence_length`) in the
  configuration objects will be ignored.

  In the following example we take a trained model (see the documentation for
  `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a
  saved model with a serving function that will perform the embedding lookup
  and pass the results to your model:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=1024,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                                 'feature_two': tf.TensorSpec(...),
                                 'feature_three': tf.TensorSpec(...)}])
  def serve_tensors(embedding_features):
    embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
        embedding_features, None, embedding.embedding_tables,
        feature_config)
    return model(embedded_features)

  model.embedding_api = embedding
  tf.saved_model.save(model,
                      export_dir=...,
                      signatures={'serving_default': serve_tensors})
  ```

  NOTE: It's important to assign the embedding API object to a member of your
  model as `tf.saved_model.save` only supports saving variables reachable from
  one `Trackable` object. Since the model's weights are in `model` and the
  embedding tables are managed by `embedding`, we assign `embedding` to an
  attribute of `model` so that `tf.saved_model.save` can find the embedding
  variables.

  NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will
  work directly from training.

  Args:
    inputs: a nested structure of Tensors, SparseTensors or RaggedTensors.
    weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
      None for no weights. If not None, structure must match that of inputs,
      but entries are allowed to be None.
    tables: a dict mapping TableConfig objects to Variables.
    feature_config: a nested structure of FeatureConfig objects with the same
      structure as inputs.

  Returns:
    A nested structure of Tensors with the same structure as inputs.
  """

  nest.assert_same_structure(inputs, feature_config)

  flat_inputs = nest.flatten(inputs)
  flat_weights = [None] * len(flat_inputs)
  if weights is not None:
    nest.assert_same_structure(inputs, weights)
    flat_weights = nest.flatten(weights)
  flat_features = nest.flatten_with_joined_string_paths(feature_config)

  outputs = []
  for inp, weight, (path, feature) in zip(
      flat_inputs, flat_weights, flat_features):
    table = tables[feature.table]

    if weight is not None:
      if isinstance(inp, ops.Tensor):
        raise ValueError(
            "Weight specified for {}, but input is dense.".format(path))
      elif type(weight) is not type(inp):
        raise ValueError(
            "Weight for {} is of type {} but it does not match type of the "
            "input which is {}.".format(path, type(weight), type(inp)))
      elif feature.max_sequence_length > 0:
        raise ValueError("Weight specified for {}, but this is a sequence "
                         "feature.".format(path))

    if isinstance(inp, ops.Tensor):
      if feature.max_sequence_length > 0:
        raise ValueError("Feature {} is a sequence feature but a dense tensor "
                         "was passed.".format(path))
      outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

    elif isinstance(inp, sparse_tensor.SparseTensor):
      if feature.max_sequence_length > 0:
        batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
        sparse_shape = array_ops.stack(
            [batch_size, feature.max_sequence_length], axis=0)
        # TPU Embedding truncates sequences to max_sequence_length, and if we
        # don't truncate, scatter_nd will error out if the index was out of
        # bounds.
        truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0],
                                                size=sparse_shape)

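        # The gathered embeddings are scattered back into a dense tensor of
        # shape [batch_size, max_sequence_length, dim]; positions with no id
        # in the (truncated) input remain zero.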
        dense_output_shape = array_ops.stack(
            [batch_size, feature.max_sequence_length, feature.table.dim],
            axis=0)
        outputs.append(
            array_ops.scatter_nd(
                truncated_inp.indices,
                array_ops.gather(table, truncated_inp.values),
                dense_output_shape))
      else:
        outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2(
            table, inp, sparse_weights=weight, combiner=feature.table.combiner))

    elif isinstance(inp, ragged_tensor.RaggedTensor):
      if feature.max_sequence_length > 0:
        batch_size = inp.shape[0]
        dense_output_shape = [
            batch_size, feature.max_sequence_length, feature.table.dim]
        ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp)
        # Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given
        # shape.
        outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape))
      else:
        outputs.append(_ragged_embedding_lookup_with_reduce(
            table, inp, weight, feature.table.combiner))

    else:
      raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
                       "RaggedTensor expected.".format(path, type(inp)))
  return nest.pack_sequence_as(feature_config, outputs)


def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]:
  """Returns a sorted list of CPU devices for the remote jobs.

  Args:
    strategy: A TPUStrategy object.

  Returns:
    A sorted list of device strings.
  """
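  # For example (hypothetical device names for illustration), with worker
  # devices ['/job:worker/replica:0/task:0/device:TPU:0',
  #          '/job:worker/replica:0/task:0/device:TPU:1',
  #          '/job:worker/replica:0/task:1/device:TPU:0'],
  # this returns ['/job:worker/replica:0/task:0/device:CPU:0',
  #               '/job:worker/replica:0/task:1/device:CPU:0'].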
  list_of_hosts = []
  # Assume this is sorted by task.
  for tpu_device in strategy.extended.worker_devices:
    host = device_util.get_host_for_device(tpu_device)
    if host not in list_of_hosts:
      list_of_hosts.append(host)
  assert len(list_of_hosts) == strategy.extended.num_hosts
  return list_of_hosts


def extract_variable_info(
    kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]:
  """Extracts the variable creation attributes from the kwargs.

  Args:
    kwargs: a dict of keyword arguments that were passed to a variable creator
      scope.

  Returns:
    A tuple of variable name, shape, dtype, initialization function.
  """
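  # For example (a hypothetical caller), a variable creator may receive
  #   initial_value=functools.partial(initializer, shape=(100, 16),
  #                                   dtype=dtypes.float32)
  # in which case we extract the shape and dtype from the partial's kwargs and
  # return `initializer` as the initialization function.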
  if (isinstance(kwargs["initial_value"], functools.partial) and (
      "shape" in kwargs["initial_value"].keywords or
      kwargs["initial_value"].args)):
    # Sometimes shape is passed positionally, sometimes it's passed as a kwarg.
    if "shape" in kwargs["initial_value"].keywords:
      shape = kwargs["initial_value"].keywords["shape"]
    else:
      shape = kwargs["initial_value"].args[0]
    return (kwargs["name"], shape,
            kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
            kwargs["initial_value"].func)
  elif "shape" not in kwargs or kwargs["shape"] is None or not callable(
      kwargs["initial_value"]):
    raise ValueError(
        "Unable to extract initializer function and shape from {}. Please "
        "either pass a function that expects a shape and dtype as the "
        "initial value for your variable or a functools.partial object with "
        "the shape and dtype kwargs set. This is needed so that we can "
        "initialize the shards of the ShardedVariable locally.".format(
            kwargs["initial_value"]))
  else:
    return (kwargs["name"], kwargs["shape"], kwargs["dtype"],
            kwargs["initial_value"])


def make_sharded_variable_creator(
    hosts: List[Text]) -> Callable[..., TPUShardedVariable]:
  """Makes a sharded variable creator given a list of hosts.

  Args:
    hosts: a list of TensorFlow devices on which to shard the tensors.

  Returns:
    A variable creator function.
  """

  def sharded_variable_creator(
      next_creator: Callable[..., tf_variables.Variable], *args, **kwargs):
    """The sharded variable creator."""
    kwargs["skip_mirrored_creator"] = True

    num_hosts = len(hosts)
    name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs)
    initial_value = kwargs["initial_value"]
    rows = shape[0]
    cols = shape[1]
    partial_partition = rows % num_hosts
    full_rows_per_host = rows // num_hosts
    # We partition as if we were using MOD sharding: at least
    # `full_rows_per_host` rows to `num_hosts` hosts, where the first
    # `partial_partition` hosts get an additional row when the number of rows
    # is not cleanly divisible. Note that `full_rows_per_host` may be zero.
    partitions = (
        [full_rows_per_host + 1] * partial_partition
        + [full_rows_per_host] * (num_hosts - partial_partition))
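    # For example, with rows=10 and num_hosts=4 we get partial_partition=2 and
    # full_rows_per_host=2, so partitions=[3, 3, 2, 2].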
    variables = []
    sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args

    # Keep track of offset for sharding aware initializers.
    offset = 0
    kwargs["dtype"] = dtype
    for i, p in enumerate(partitions):
      if p == 0:
        # Skip variable creation for empty partitions, resulting from the edge
        # case of 'rows < num_hosts'. This is safe because both load/restore
        # can handle the missing values.
        continue
      with ops.device(hosts[i]):
        kwargs["name"] = "{}_{}".format(name, i)
        kwargs["shape"] = (p, cols)
        if sharding_aware:
          shard_info = base.ShardInfo(kwargs["shape"], (offset, 0))
          kwargs["initial_value"] = functools.partial(
              initial_value, shard_info=shard_info)
          offset += p
        else:
          kwargs["initial_value"] = functools.partial(
              unwrapped_initial_value, kwargs["shape"], dtype=dtype)
        variables.append(next_creator(*args, **kwargs))
    return TPUShardedVariable(variables, name=name)
  return sharded_variable_creator