# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import math
import re

import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
#  as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size', 'dimension', 'initializer', 'combiner',
        'hot_id_replication', 'learning_rate', 'learning_rate_fn'
    ])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean',
              hot_id_replication=False,
              learning_rate=None,
              learning_rate_fn=None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
        deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more information,
        see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
        than sparse tensors.
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
        learning_rate and learning_rate_fn are both `None`, the global
        static learning rate as specified in `optimization_parameters` in the
        `TPUEmbedding` constructor will be used. `learning_rate_fn` must be
        `None` if `learning_rate` is not `None`.
      learning_rate_fn: a function giving the dynamic learning rate for this
        table. The function will be passed the current global step. If
        learning_rate and learning_rate_fn are both `None`, the global static
        learning rate as specified in `optimization_parameters` in the
        `TPUEmbedding` constructor will be used. `learning_rate` must be `None`
        if `learning_rate_fn` is not `None`.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError('Invalid dimension {}.'.format(dimension))

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError('initializer must be callable if specified.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set; got {} and {}'
                       .format(learning_rate, learning_rate_fn))

    return super(TableConfig, cls).__new__(
        cls, vocabulary_size, dimension, initializer, combiner,
        hot_id_replication, learning_rate, learning_rate_fn)

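# The sketch below is illustrative only (it is not part of the original API
# surface); the vocabulary size, dimension and the `video_initializer` name
# are assumed placeholders. It shows how a `TableConfig` might be constructed.
#
#   video_initializer = init_ops.truncated_normal_initializer(
#       mean=0.0, stddev=0.1)
#   video_table_config = TableConfig(
#       vocabulary_size=10000,      # rows in the embedding table
#       dimension=64,               # width of each embedding vector
#       initializer=video_initializer,
#       combiner='sum')             # reduce multiple ids per sample by summing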

class FeatureConfig(
    collections.namedtuple(
        'FeatureConfig',
        ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls,
              table_id,
              max_sequence_length=0,
              weight_key=None):
    """Feature configuration.

    Args:
      table_id: Which table the feature uses for embedding lookups.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      weight_key: If using weights for the combiner, this key specifies which
        input feature contains the weights.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `max_sequence_length` is not a non-negative integer.
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError('Invalid max_sequence_length {}.'.format(
          max_sequence_length))

    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
                                             weight_key)

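# Illustrative sketch (assumed feature and table names, not original code):
# two features doing lookups in the same embedding table, one of them treated
# as a sequence feature truncated to length 10.
#
#   feature_to_config_dict = {
#       'watched': FeatureConfig('video'),
#       'watch_history': FeatureConfig('video', max_sequence_length=10),
#   }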

class EnqueueData(
    collections.namedtuple(
        'EnqueueData',
        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
  """Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_indices=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to int32
        internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights.
        It corresponds to sp_weights.values in embedding_lookup_sparse(). If it
        is None, we assume all weights are 1. Both float32 and float64 are
        allowed and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.

    """
    return super(EnqueueData, cls).__new__(cls, embedding_indices,
                                           sample_indices, aggregation_weights)

  @staticmethod
  def from_sparse_tensor(sp_tensor, weights=None):
    return EnqueueData(
        sp_tensor.values,
        sp_tensor.indices,
        aggregation_weights=weights.values if weights is not None else None)

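# Illustrative sketch (the SparseTensor names are assumed placeholders):
# building `EnqueueData` from a `SparseTensor` of ids, optionally with a
# parallel `SparseTensor` of combiner weights.
#
#   watched_data = EnqueueData.from_sparse_tensor(watched_sp_tensor)
#   favorited_data = EnqueueData.from_sparse_tensor(
#       favorited_sp_tensor, weights=favorited_weights_sp_tensor)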

def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    sp_tensors_list: a list of dictionaries mapping from strings of feature
      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from strings
      of feature names to EnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.

  """
  enqueue_datas_list = []
  for sp_tensors in sp_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, EnqueueData.from_sparse_tensor(v))
        for k, v in six.iteritems(sp_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list

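# Illustrative sketch (assumes `sparse_features_list` holds one dict of
# SparseTensors per TPU core and `embedding` is a constructed `TPUEmbedding`):
# converting the per-core SparseTensors into EnqueueData and generating the
# enqueue ops from the result.
#
#   enqueue_datas_list = get_enqueue_datas_list_from_sparse_tensors_list(
#       sparse_features_list)
#   enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)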

AdamSlotVariableNames = collections.namedtuple(
    'AdamSlotVariableNames', ['m', 'v'])

AdagradSlotVariableName = collections.namedtuple(
    'AdagradSlotVariableName', ['accumulator'])

FtrlSlotVariableName = collections.namedtuple(
    'FtrlSlotVariableName', ['accumulator', 'linear'])

AdamSlotVariables = collections.namedtuple(
    'AdamSlotVariables', ['m', 'v'])

AdagradSlotVariable = collections.namedtuple(
    'AdagradSlotVariable', ['accumulator'])

FtrlSlotVariable = collections.namedtuple(
    'FtrlSlotVariable', ['accumulator', 'linear'])

VariablesAndOps = collections.namedtuple(
    'VariablesAndOps',
    ['embedding_variables_by_table', 'slot_variables_by_table',
     'load_ops', 'retrieve_ops']
)


class _OptimizationParameters(object):
  """Parameters common to all optimizations."""

  def __init__(self, learning_rate, use_gradient_accumulation,
               clip_weight_min, clip_weight_max):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max


@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               initial_accumulator=0.1,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(AdagradParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if initial_accumulator <= 0:
      raise ValueError('Adagrad initial_accumulator must be positive')
    self.initial_accumulator = initial_accumulator


@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-08,
               lazy_adam=True,
               sum_inside_sqrt=True,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
        Please see `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(AdamParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling Lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


@tf_export(v1=['tpu.experimental.FtrlParameters'])
class FtrlParameters(_OptimizationParameters):
  """Optimization parameters for Ftrl with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Ftrl.

    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(FtrlParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if learning_rate_power > 0.:
      raise ValueError('learning_rate_power must be less than or equal to 0. '
                       'got {}.'.format(learning_rate_power))

    if initial_accumulator_value < 0.:
      raise ValueError('initial_accumulator_value must be greater than or equal'
                       ' to 0. got {}.'.format(initial_accumulator_value))

    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.learning_rate_power = learning_rate_power
    self.initial_accumulator_value = initial_accumulator_value
    self.initial_linear_value = 0.0
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength


@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=(
              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
  ```

  """

  def __init__(self, learning_rate, clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: a floating point value. The learning rate.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(StochasticGradientDescentParameters,
          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max)


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])

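# Illustrative sketch (the values and job name are assumed placeholders): a
# `DeviceConfig` is only needed when `TPUEmbedding` is constructed with both
# `master` and `cluster_def` set to `None`.
#
#   device_config = DeviceConfig(
#       num_hosts=1, num_cores=8, job_name='/job:tpu_worker')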

class TPUEmbedding(object):
  """API for using TPU for embedding.

    Example:
    ```
    table_config_video = tpu_embedding.TableConfig(
        vocabulary_size=8, dimension=2,
        initializer=initializer, combiner='mean')
    table_config_user = tpu_embedding.TableConfig(
        vocabulary_size=4, dimension=2,
        initializer=initializer, combiner='mean')
    table_to_config_dict = {'video': table_config_video,
                            'user': table_config_user}
    feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
                              'favorited': tpu_embedding.FeatureConfig('video'),
                              'friends': tpu_embedding.FeatureConfig('user')}
    batch_size = 4
    num_hosts = 1
    optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
    mode = tpu_embedding.TRAINING
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict, feature_to_config_dict,
        batch_size, mode,
        optimization_parameters=optimization_parameters)

    batch_size_per_core = embedding.batch_size_per_core
    sparse_features_list = []
    for host in hosts:
      with ops.device(host):
        for _ in range(embedding.num_cores_per_host):
          sparse_features = {}
          sparse_features['watched'] = sparse_tensor.SparseTensor(...)
          sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
          sparse_features['friends'] = sparse_tensor.SparseTensor(...)
          sparse_features_list.append(sparse_features)

    enqueue_datas_list = (
        tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
            sparse_features_list))
    enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
    embedding_variables_and_ops = embedding.create_variables_and_ops()

    def computation():
      activations = embedding.get_activations()
      loss = compute_loss(activations)

      base_optimizer = gradient_descent.GradientDescentOptimizer(
          learning_rate=1)
      cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
          base_optimizer)

      train_op = cross_shard_optimizer.minimize(loss)
      gradients = (
          tpu_embedding_gradient.get_gradients_through_compute_gradients(
              cross_shard_optimizer, loss, activations))
      send_gradients_op = embedding.generate_send_gradients_op(gradients)
      with ops.control_dependencies([train_op, send_gradients_op]):
        loss = array_ops.identity(loss)
      return loss

    loss = tpu.shard(computation,
                     num_shards=embedding.num_cores)

    with self.test_session() as sess:
      sess.run(tpu.initialize_system(embedding_config=
                                     embedding.config_proto))
      sess.run(variables.global_variables_initializer())
      sess.run(embedding_variables_and_ops.load_ops())
      sess.run(enqueue_ops)
      loss_val = sess.run(loss)
    ```
  """

  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
  # the feature should not be used to update embedding table (cr/204852758,
  # cr/204940540). Also, this can support different combiners for different
  # features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
  # for-loops around construction of inputs.

  # `optimization_parameter` applies to all tables. If the need arises,
  # we can add `optimization_parameters` to `TableConfig` to override this
  # global setting.
  def __init__(self,
               table_to_config_dict,
               feature_to_config_dict,
               batch_size,
               mode,
               master=None,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_config_dict: A dictionary mapping from string of feature name
        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `FtrlParameters` or `StochasticGradientDescentParameters`. Must be set
        in training and must be `None` in inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes training
        faster, but trained model will be different if step N and step N+1
        involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
      partition_strategy: A string, either 'mod' or 'div', specifying how to map
        the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse`.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: if set, overrides the master job name used to schedule
        embedding ops.

    Raises:
      ValueError: if any input is invalid.
    """
    if partition_strategy not in ('div', 'mod'):
      raise ValueError(
          'Invalid partition_strategy {}'.format(partition_strategy))
    self._partition_strategy = partition_strategy

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict)
    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
    self._table_to_features_dict, self._table_to_num_features_dict = (
        _create_table_to_features_and_num_features_dicts(
            self._feature_to_config_dict))
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    if master is None and cluster_def is None:
      if device_config is None:
        raise ValueError('When master and cluster_def are both None, '
                         'device_config must be set but is not.')
      if device_config.num_cores % device_config.num_hosts:
        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
                         'but does not.'.format(device_config.num_hosts,
                                                device_config.num_cores))
      self._num_hosts = device_config.num_hosts
      self._num_cores = device_config.num_cores
      self._num_cores_per_host = self._num_cores // self._num_hosts
      self._hosts = [
          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
          for i in range(self._num_hosts)
      ]
    else:
      tpu_system_metadata = (
          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
              master,
              cluster_def=cluster_def))
      if tpu_system_metadata.num_cores == 0:
        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                         'TPUs.'.format(master))
      self._num_hosts = tpu_system_metadata.num_hosts
      if master_job_name is None:
        try:
          master_job_name = tpu_system_metadata_lib.master_job(master,
                                                               cluster_def)
        except ValueError as e:
          raise ValueError(str(e) + ' Please specify a master_job_name.')
      self._hosts = []
      for device in tpu_system_metadata.devices:
        if 'device:CPU:' in device.name and (
            master_job_name is None or master_job_name in device.name):
          self._hosts.append(device.name)
      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
      self._num_cores = tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores

    # TODO(shizhiw): remove `mode`?
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError('`optimization_parameters` should be `None` '
                         'for inference mode.')
      self._optimization_parameters = (
          StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'
                       .format(TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
    self._optimizer_handler = _get_optimization_handler(
        self._optimization_parameters)
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
    self._learning_rate_fn = list(set(
        c.learning_rate_fn for c in self._table_to_config_dict.values()
        if c.learning_rate_fn is not None))
    self._learning_rate_fn_to_tag = {
        fn: id for id, fn in enumerate(self._learning_rate_fn)}

    self._config_proto = self._create_config_proto()

  @property
  def hosts(self):
    """A list of device names for CPU hosts.

    Returns:
      A list of device names for CPU hosts.
    """
    return copy.copy(self._hosts)

  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
  # to be consistent with `tpu_embedding_configuration.proto`.
  @property
  def num_cores_per_host(self):
    """Number of TPU cores on a CPU host.

    Returns:
      Number of TPU cores on a CPU host.
    """
    return self._num_cores_per_host

  @property
  def num_cores(self):
    """Total number of TPU cores on all hosts.

    Returns:
      Total number of TPU cores on all hosts.
    """
    return self._num_cores

  @property
  def batch_size_per_core(self):
    """Batch size for each TPU core.

    The enqueue data in the `enqueue_datas_list` passed to
      `generate_enqueue_ops` must have a batch dimension equal to this.

    Returns:
      Batch size for each TPU core.
    """
    return self._batch_size_per_core

  @property
  def config_proto(self):
    """Create embedding config proto for `tpu.initialize_system()`.

    Returns:
      a `TPUEmbeddingConfiguration` proto describing the desired
         configuration of the hardware embedding lookup tables, which
         is passed to `tpu.initialize_system()`.
    """
    return self._config_proto

  @property
  def table_to_config_dict(self):
    return copy.copy(self._table_to_config_dict)

  @property
  def feature_to_config_dict(self):
    return copy.copy(self._feature_to_config_dict)

  @property
  def table_to_features_dict(self):
    return copy.copy(self._table_to_features_dict)

  @property
  def optimization_parameters(self):
    return self._optimization_parameters

  def _create_config_proto(self):
    """Create `TPUEmbeddingConfiguration`."""
    config_proto = elc.TPUEmbeddingConfiguration()
    for table in self._table_to_config_dict:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table

      table_config = self._table_to_config_dict[table]
      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
                                             len(self.hosts))
      table_descriptor.dimension = table_config.dimension

      table_descriptor.num_features = self._table_to_num_features_dict[table]

      parameters = table_descriptor.optimization_parameters
      if table_config.learning_rate:
        parameters.learning_rate.constant = (table_config.learning_rate)
      elif table_config.learning_rate_fn:
        parameters.learning_rate.dynamic.tag = (
            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
      else:
        parameters.learning_rate.constant = (
            self._optimization_parameters.learning_rate)
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
          if self._optimization_parameters.use_gradient_accumulation else
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
      if self._optimization_parameters.clip_weight_min is not None:
        parameters.clipping_limits.lower.value = (
            self._optimization_parameters.clip_weight_min)
      if self._optimization_parameters.clip_weight_max is not None:
        parameters.clipping_limits.upper.value = (
            self._optimization_parameters.clip_weight_max)
      if table_config.hot_id_replication:
        parameters.hot_id_replication_configuration.status = (
            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
      self._optimizer_handler.set_optimization_parameters(table_descriptor)

    config_proto.mode = self._mode
    config_proto.batch_size_per_tensor_core = self._batch_size_per_core
    config_proto.num_hosts = self._num_hosts
    config_proto.num_tensor_cores = self._num_cores
    config_proto.sharding_strategy = (
        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
        if self._partition_strategy == 'div' else
        elc.TPUEmbeddingConfiguration.MOD)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

    N.B.: the ops to retrieve embedding variables (including slot variables)
    are returned as lambda functions, as the call site might want to impose
    control dependencies between the TPU computation and retrieving actions.
    For example, the following code snippet ensures the TPU computation
    finishes first, and then we pull the variables back from TPU to CPU.

    ```
    update_ops = []
    with ops.control_dependencies([loss]):
      for op_fn in retrieve_parameters_op_fns:
        update_ops.append(op_fn())
    ```

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the table
        name will be used as the embedding variable name.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.

    Returns:
      `tpu_embedding.VariablesAndOps` with:
        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to AdagradSlotVariable,
         AdamSlotVariables etc with slot variables,
        A function which returns a list of ops to load embedding and slot
         variables from CPU to TPU.
        A function which returns a list of ops to retrieve embedding and slot
         variables from TPU to CPU.
    """
    embedding_variables_by_table = {}
    slot_variables_by_table = {}
    load_op_fns = []
    retrieve_op_fns = []

    for i, table in enumerate(self._table_to_config_dict):
      if embedding_variable_name_by_table:
        embedding_variable_name = embedding_variable_name_by_table[table]
      else:
        embedding_variable_name = table
      if slot_variable_names_by_table:
        slot_variable_names = slot_variable_names_by_table[table]
      else:
        slot_variable_names = (
            self._optimizer_handler.get_default_slot_variable_names(table))

      # TODO(b/139144091): Multi-host support for mid-level API in
      #  eager context (TF 2.0)
      # Workaround below allows single-host use case in TF 2.0
      if context.executing_eagerly():
        device = ''
      else:
        device = _create_device_fn(self._hosts)

      with ops.device(device):
        table_variables = _create_partitioned_variables(
            name=embedding_variable_name,
            num_hosts=self._num_hosts,
            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
            embedding_dimension=self._table_to_config_dict[table].dimension,
            initializer=self._table_to_config_dict[table].initializer,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        embedding_variables_by_table[table] = table_variables

        # Only pass the embedding config to the load/retrieve ops of the first
        # table; the remaining ops reuse the config from the first one.
        config = None if i else self.config_proto.SerializeToString()
        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
            self._optimizer_handler.create_variables_and_ops(
                table, slot_variable_names, self._num_hosts,
                self._table_to_config_dict[table], table_variables, config))
        slot_variables_by_table[table] = slot_variables_for_table
        load_op_fns.append(load_ops_fn)
        retrieve_op_fns.append(retrieve_ops_fn)

    def load_ops():
      """Calls and returns the load ops for each embedding table.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_ops_list = []
      for load_op_fn in load_op_fns:
        load_ops_list.extend(load_op_fn())
      return load_ops_list

    def retrieve_ops():
      """Calls and returns the retrieve ops for each embedding table.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_ops_list = []
      for retrieve_op_fn in retrieve_op_fns:
        retrieve_ops_list.extend(retrieve_op_fn())
      return retrieve_ops_list

    return VariablesAndOps(embedding_variables_by_table,
                           slot_variables_by_table,
                           load_ops, retrieve_ops)

  def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
    """Generate enqueue ops.

    Args:
      enqueue_datas_list: a list of dictionaries mapping from strings
        of feature names to EnqueueData. Each dictionary is for one
        TPU core. Dictionaries for the same host should be contiguous
        in the list.
      mode_override: A string input that overrides the mode specified in the
        TPUEmbeddingConfiguration. Supported values are {'unspecified',
        'inference', 'training', 'backward_pass_only'}. When set to
        'unspecified', the mode set in TPUEmbeddingConfiguration is used,
        otherwise mode_override is used (optional).

    Returns:
      Ops to enqueue to TPU for embedding.
    """
    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
    return [
        self._generate_enqueue_op(  # pylint: disable=g-complex-comprehension
            enqueue_datas,
            device_ordinal=i % self._num_cores_per_host,
            mode_override=mode_override,
        ) for i, enqueue_datas in enumerate(enqueue_datas_list)
    ]

  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
                                                        enqueue_datas_list):
    """Validate `enqueue_datas_list`."""
    feature_set = set(self._feature_to_config_dict.keys())
    contiguous_device = None
    for i, enqueue_datas in enumerate(enqueue_datas_list):
      used_feature_set = set(enqueue_datas.keys())

      # Check features are valid.
      missing_feature_set = feature_set - used_feature_set
      if missing_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, missing_feature_set))

      extra_feature_set = used_feature_set - feature_set
      if extra_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, extra_feature_set))

      device = None
      device_feature = None
      for feature, enqueue_data in six.iteritems(enqueue_datas):
        combiner = self._table_to_config_dict[
            self._feature_to_config_dict[feature].table_id].combiner
        if not isinstance(enqueue_data, EnqueueData):
          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
                           'not mapped to `EnqueueData`. `feature`: {}'.format(
                               i, feature))

        if enqueue_data.sample_indices is None and combiner:
          logging.warn('No sample indices set for feature %s table %s but '
                       'combiner is set to %s.', feature,
                       self._feature_to_config_dict[feature].table_id, combiner)

        if (enqueue_data.sample_indices is not None and
            enqueue_data.sample_indices.device !=
            enqueue_data.embedding_indices.device):
          raise ValueError(
              'Device of sample_indices does not agree with '
              'that of embedding_indices for feature {}.'.format(feature))
        if (enqueue_data.aggregation_weights is not None and
            enqueue_data.aggregation_weights.device !=
            enqueue_data.embedding_indices.device):
          raise ValueError(
              'Device of aggregation_weights does not agree with '
              'that of embedding_indices for feature {}.'.format(feature))
        # Check all features are on the same device.
        if device is None:
          device = enqueue_data.embedding_indices.device
          device_feature = feature
        else:
          if device != enqueue_data.embedding_indices.device:
            raise ValueError('Devices are different between features in '
                             '`enqueue_datas_list[{}]`; '
                             'devices: {}, {}; features: {}, {}.'.format(
                                 i, device,
                                 enqueue_data.embedding_indices.device, feature,
                                 device_feature))

      if i % self._num_cores_per_host:
        if device != contiguous_device:
          raise ValueError('We expect the `enqueue_datas` which are on the '
                           'same host to be contiguous in '
                           '`enqueue_datas_list`, '
                           '`enqueue_datas_list[{}]` is on device {}, '
                           'but is expected to be on device {}.'.format(
                               i, device, contiguous_device))
      else:
        contiguous_device = device

  def _generate_enqueue_op(
      self, enqueue_datas, device_ordinal, mode_override=None):
    enqueue_data0 = list(enqueue_datas.values())[0]
    with ops.colocate_with(enqueue_data0.embedding_indices):
      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
          device_ordinal=device_ordinal,
          combiners=self._combiners,
          mode_override=mode_override,
          **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
      )

  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
    """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.

    Args:
      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
        dense.

    Returns:
      Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
    """
    kwargs = {
        'sample_indices': [],
        'embedding_indices': [],
        'aggregation_weights': [],
        'table_ids': [],
        'max_sequence_lengths': [],
    }
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      for feature in features:
        enqueue_data = enqueue_datas[feature]

        kwargs['sample_indices'].append(
            enqueue_data.sample_indices
            if enqueue_data.sample_indices is not None else array_ops.zeros(
                (0,), dtype=dtypes.int64))

        kwargs['aggregation_weights'].append(
            enqueue_data.aggregation_weights if
            enqueue_data.aggregation_weights is not None else array_ops.zeros(
                (0,), dtype=dtypes.float32))

        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)

        kwargs['table_ids'].append(table_id)
        kwargs['max_sequence_lengths'].append(
            self._feature_to_config_dict[feature].max_sequence_length)

    return kwargs

  def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
      `tpu.replicate` and friends.

    Returns:
      A dictionary mapping from `String` of feature name to `Tensor`
        of activation.
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      num_features = self._table_to_num_features_dict[table]
      feature_index = 0
      table_activations = array_ops.reshape(
          recv_activations[table_id],
          [self.batch_size_per_core, num_features, -1])
      for feature in features:
        seq_length = self._feature_to_config_dict[feature].max_sequence_length
        if not seq_length:
          activations[feature] = table_activations[:, feature_index, :]
          feature_index = feature_index + 1
        else:
          activations[feature] = (
              table_activations[:, feature_index:(feature_index+seq_length), :])
          feature_index = feature_index + seq_length

    return activations

  def generate_send_gradients_op(self,
                                 feature_to_gradient_dict,
                                 step=None):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.
      step: the current global step, used for dynamic learning rate.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
    """
    if self._mode != TRAINING:
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))
    if step is None and self._learning_rate_fn:
      raise ValueError('There are dynamic learning rates but step is None.')

    gradients = []
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      table_gradients = []
      for feature in features:
        gradient = feature_to_gradient_dict[feature]
        # Expand dims for non-sequence feature to match sequence features.
        if gradient.shape.ndims == 2:
          gradient = array_ops.expand_dims(gradient, 1)
        table_gradients.append(gradient)
      interleaved_table_grads = array_ops.reshape(
          array_ops.concat(table_gradients, axis=1),
          [-1, array_ops.shape(table_gradients[0])[-1]])
      gradients.append(interleaved_table_grads)

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
                        for fn in self._learning_rate_fn],
        config=self.config_proto.SerializeToString())

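# Illustrative sketch (assumes `my_decay`, `embedding`, `feature_to_gradient_dict`
# and `global_step` are user-defined placeholders): a table opts into a dynamic
# learning rate by passing a `learning_rate_fn` mapping the global step to a
# learning rate, and the same step tensor is then passed to
# `generate_send_gradients_op`.
#
#   def my_decay(step):
#     return 0.1 * math_ops.exp(-0.01 * math_ops.cast(step, dtypes.float32))
#
#   table_config = TableConfig(vocabulary_size=1000, dimension=16,
#                              learning_rate_fn=my_decay)
#   ...
#   send_op = embedding.generate_send_gradients_op(
#       feature_to_gradient_dict,
#       step=global_step)  # e.g. tf.compat.v1.train.get_or_create_global_step()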

def _validate_table_to_config_dict(table_to_config_dict):
  """Validate `table_to_config_dict`."""
  for k, v in six.iteritems(table_to_config_dict):
    if not isinstance(v, TableConfig):
      raise ValueError('Value of `table_to_config_dict` must be of type '
                       '`TableConfig`, got {} for {}.'.format(type(v), k))


def _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict):
  """Validate `feature_to_config_dict`."""
  used_table_set = set([feature.table_id
                        for feature in feature_to_config_dict.values()])
  table_set = set(table_to_config_dict.keys())

  unused_table_set = table_set - used_table_set
  if unused_table_set:
    raise ValueError('`table_to_config_dict` specifies a table that is not '
                     'used in `feature_to_config_dict`: {}.'
                     .format(unused_table_set))

  extra_table_set = used_table_set - table_set
  if extra_table_set:
    raise ValueError('`feature_to_config_dict` refers to a table that is not '
                     'specified in `table_to_config_dict`: {}.'
                     .format(extra_table_set))


def _validate_batch_size(batch_size, num_cores):
  if batch_size % num_cores:
    raise ValueError('`batch_size` is not a multiple of number of '
                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
                         batch_size, num_cores))


def _validate_optimization_parameters(optimization_parameters):
  if not isinstance(optimization_parameters, _OptimizationParameters):
    raise ValueError('`optimization_parameters` must inherit from '
                     '`_OptimizationParameters`. '
                     '`type(optimization_parameters)`={}'.format(
                         type(optimization_parameters)))


class _OptimizerHandler(object):
  """Interface class for handling optimizer specific logic."""

  def __init__(self, optimization_parameters):
    self._optimization_parameters = optimization_parameters

  def set_optimization_parameters(self, table_descriptor):
    raise NotImplementedError()

  def get_default_slot_variable_names(self, table):
    raise NotImplementedError()

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    raise NotImplementedError()


class _AdagradHandler(_OptimizerHandler):
  """Handles Adagrad specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdagradHandler, self).__init__(optimization_parameters)
    self._table_to_accumulator_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adagrad.SetInParent()

  def get_default_slot_variable_names(self, table):
    return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    slot_variables = AdagradSlotVariable(accumulator_variables)

    def load_ops_fn():
1256      """Returns the retrieve ops for AdaGrad embedding tables.
1257
1258      Returns:
1259        A list of ops to load embedding and slot variables from CPU to TPU.
1260      """
1261      config = config_proto
1262      load_op_list = []
1263      for host_id, table_variable, accumulator_variable in zip(
1264          range(num_hosts), table_variables, accumulator_variables):
1265        with ops.colocate_with(table_variable):
1266          load_parameters_op = (
1267              tpu_ops.load_tpu_embedding_adagrad_parameters(
1268                  parameters=table_variable,
1269                  accumulators=accumulator_variable,
1270                  table_name=table,
1271                  num_shards=num_hosts,
1272                  shard_id=host_id,
1273                  config=config))
1274        config = None
1275        load_op_list.append(load_parameters_op)
1276      return load_op_list
1277
1278    def retrieve_ops_fn():
1279      """Returns the retrieve ops for AdaGrad embedding tables.
1280
1281      Returns:
1282        A list of ops to retrieve embedding and slot variables from TPU to CPU.
1283      """
1284      config = config_proto
1285      retrieve_op_list = []
1286      for host_id, table_variable, accumulator_variable in (zip(
1287          range(num_hosts), table_variables, accumulator_variables)):
1288        with ops.colocate_with(table_variable):
1289          retrieved_table, retrieved_accumulator = (
1290              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
1291                  table_name=table,
1292                  num_shards=num_hosts,
1293                  shard_id=host_id,
1294                  config=config))
1295          retrieve_parameters_op = control_flow_ops.group(
1296              state_ops.assign(table_variable, retrieved_table),
1297              state_ops.assign(accumulator_variable, retrieved_accumulator))
1298        config = None
1299        retrieve_op_list.append(retrieve_parameters_op)
1300      return retrieve_op_list
1301
1302    return slot_variables, load_ops_fn, retrieve_ops_fn
1303
1304
class _AdamHandler(_OptimizerHandler):
  """Handles Adam-specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdamHandler, self).__init__(optimization_parameters)
    self._table_to_m_variables_dict = {}
    self._table_to_v_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adam.beta1 = (
        self._optimization_parameters.beta1)
    table_descriptor.optimization_parameters.adam.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.adam.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
        not self._optimization_parameters.lazy_adam)
    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
        self._optimization_parameters.sum_inside_sqrt)

  def get_default_slot_variable_names(self, table):
    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
                                 '{}/{}/v'.format(table, 'Adam'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    m_initializer = init_ops.zeros_initializer()
    m_variables = _create_partitioned_variables(
        name=slot_variable_names.m,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=m_initializer)
    v_initializer = init_ops.zeros_initializer()
    v_variables = _create_partitioned_variables(
        name=slot_variable_names.v,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=v_initializer)
    slot_variables = AdamSlotVariables(m_variables, v_variables)

    def load_ops_fn():
      """Returns the load ops for Adam embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_variable,
                  momenta=m_variable,
                  velocities=v_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        # Set config to None to enforce that config is only loaded to the first
        # table.
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Adam embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_m, retrieved_v = (
              tpu_ops.retrieve_tpu_embedding_adam_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(m_variable, retrieved_m),
              state_ops.assign(v_variable, retrieved_v))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _FtrlHandler(_OptimizerHandler):
  """Handles Ftrl-specific logic."""

  def __init__(self, optimization_parameters):
    super(_FtrlHandler, self).__init__(optimization_parameters)
    self._table_to_accumulator_variables_dict = {}
    self._table_to_linear_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.ftrl.lr_power = (
        self._optimization_parameters.learning_rate_power)
    table_descriptor.optimization_parameters.ftrl.l1 = (
        self._optimization_parameters.l1_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.l2 = (
        self._optimization_parameters.l2_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.initial_accum = (
        self._optimization_parameters.initial_accumulator_value)
    table_descriptor.optimization_parameters.ftrl.initial_linear = (
        self._optimization_parameters.initial_linear_value)

  def get_default_slot_variable_names(self, table):
    # These match the default slot variable names created by
    # tf.train.FtrlOptimizer.
    return FtrlSlotVariableName('{}/{}'.format(table, 'Ftrl'),  # accumulator
                                '{}/{}'.format(table, 'Ftrl_1'))  # linear

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator_value)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    linear_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_linear_value)
    linear_variables = _create_partitioned_variables(
        name=slot_variable_names.linear,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=linear_initializer)
    slot_variables = FtrlSlotVariable(accumulator_variables,
                                      linear_variables)

    def load_ops_fn():
      """Returns the load ops for Ftrl embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  linears=linear_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Ftrl embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator),
              state_ops.assign(linear_variable, retrieved_linear))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles logic specific to stochastic gradient descent."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.stochastic_gradient_descent
     .SetInParent())

  def get_default_slot_variable_names(self, table):
    return None

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    del table_config

    def load_ops_fn():
      """Returns the load ops for SGD embedding tables.

      Returns:
        A list of ops to load embedding variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for SGD embedding tables.

      Returns:
        A list of ops to retrieve embedding variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table = (
              tpu_ops
              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return None, load_ops_fn, retrieve_ops_fn


def _get_optimization_handler(optimization_parameters):
  """Gets the optimization handler given the parameter type."""
  if isinstance(optimization_parameters, AdagradParameters):
    return _AdagradHandler(optimization_parameters)
  elif isinstance(optimization_parameters, AdamParameters):
    return _AdamHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FtrlParameters):
    return _FtrlHandler(optimization_parameters)
  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
    return _StochasticGradientDescentHandler(optimization_parameters)
  else:
    raise NotImplementedError(
        'Unsupported optimization parameters of type {}.'.format(
            type(optimization_parameters)))


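# A minimal sketch of the dispatch above (the parameter values are only
# illustrative):
#
#   handler = _get_optimization_handler(AdagradParameters(learning_rate=0.1))
#   isinstance(handler, _AdagradHandler)  # -> True
#
#   handler = _get_optimization_handler(
#       StochasticGradientDescentParameters(learning_rate=0.1))
#   isinstance(handler, _StochasticGradientDescentHandler)  # -> True

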
def _create_ordered_dict(d):
  """Creates an OrderedDict from `d`, sorted by key."""
  return collections.OrderedDict((k, d[k]) for k in sorted(d))


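# For example, _create_ordered_dict above gives:
#
#   _create_ordered_dict({'b': 2, 'a': 1})
#   # -> OrderedDict([('a', 1), ('b', 2)])

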
def _create_combiners(table_to_config_dict, table_to_features_dict):
  """Creates a per-feature list of combiners, ordered by table."""
  combiners = []
  for table in table_to_config_dict:
    combiner = table_to_config_dict[table].combiner or 'sum'
    combiners.extend([combiner] * len(table_to_features_dict[table]))
  return combiners


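# Illustrative sketch of _create_combiners above, using hypothetical table and
# feature names: each table's combiner is repeated once per feature that looks
# up into it, and a combiner of None falls back to 'sum'.
#
#   table_to_config = collections.OrderedDict([
#       ('user', TableConfig(vocabulary_size=50, dimension=4, combiner=None)),
#       ('video', TableConfig(vocabulary_size=100, dimension=8,
#                             combiner='mean')),
#   ])
#   table_to_features = {'user': ['id'], 'video': ['favorited', 'watched']}
#   _create_combiners(table_to_config, table_to_features)
#   # -> ['sum', 'mean', 'mean']

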
def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
  """Creates mappings from each table to its features and its feature count."""
  table_to_features_dict_tmp = {}
  table_to_num_features_dict_tmp = {}
  for feature, feature_config in six.iteritems(feature_to_config_dict):
    if feature_config.table_id in table_to_features_dict_tmp:
      table_to_features_dict_tmp[feature_config.table_id].append(feature)
    else:
      table_to_features_dict_tmp[feature_config.table_id] = [feature]
      table_to_num_features_dict_tmp[feature_config.table_id] = 0
    if feature_config.max_sequence_length == 0:
      table_to_num_features_dict_tmp[feature_config.table_id] = (
          table_to_num_features_dict_tmp[feature_config.table_id] + 1)
    else:
      table_to_num_features_dict_tmp[feature_config.table_id] = (
          table_to_num_features_dict_tmp[feature_config.table_id] +
          feature_config.max_sequence_length)

  table_to_features_dict = collections.OrderedDict()
  table_to_num_features_dict = collections.OrderedDict()
  for table in sorted(table_to_features_dict_tmp):
    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
    table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
  return table_to_features_dict, table_to_num_features_dict


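# Illustrative sketch of the helper above (feature and table names are
# hypothetical, and this assumes FeatureConfig entries expose `table_id` and
# `max_sequence_length` as used above). A non-sequence feature counts as one
# feature, while a sequence feature counts as `max_sequence_length` features:
#
#   feature_to_config = {
#       'watched': FeatureConfig('video'),
#       'favorited': FeatureConfig('video', max_sequence_length=5),
#       'id': FeatureConfig('user'),
#   }
#   _create_table_to_features_and_num_features_dicts(feature_to_config)
#   # -> (OrderedDict([('user', ['id']),
#   #                  ('video', ['favorited', 'watched'])]),
#   #     OrderedDict([('user', 1), ('video', 6)]))

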
def _create_device_fn(hosts):
  """Creates a device_fn() to use with _create_partitioned_variables()."""

  def device_fn(op):
    """Returns the `device` for `op`."""
    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
    dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
    if not part_match and not dummy_match:
      raise RuntimeError(
          'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
              op.name))

    if part_match:
      idx = int(part_match.group(1))
    else:
      idx = int(dummy_match.group(1))

    device = hosts[idx]
    logging.debug('assigning %s to %s.', op, device)
    return device

  return device_fn


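# Illustrative sketch of _create_device_fn above (the host device strings are
# hypothetical): with
#
#   hosts = ['/job:worker/task:0/device:CPU:0',
#            '/job:worker/task:1/device:CPU:0']
#   device_fn = _create_device_fn(hosts)
#
# an op named 'my_table/part_1/Initializer' is placed on hosts[1], an op named
# 'dummy_1_my_table' is also placed on hosts[1], and an op whose name matches
# neither pattern raises a RuntimeError.

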
def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates PartitionedVariables based on `num_hosts` for `table`."""

  num_slices = min(vocabulary_size, num_hosts)

  var_list = list(
      variable_scope.get_variable(
          name,
          shape=(vocabulary_size, embedding_dimension),
          partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
          dtype=dtypes.float32,
          initializer=initializer,
          collections=collections,
          trainable=False))

  if vocabulary_size >= num_hosts:
    return var_list

  # When the vocabulary has fewer rows than there are hosts, define a dummy
  # variable for each remaining host so that every host has a shard to load
  # into the TPU system.
  for idx in range(num_hosts - vocabulary_size):
    var_list.append(
        variable_scope.get_variable(
            'dummy_{}_{}'.format(vocabulary_size + idx, name),
            shape=(1, embedding_dimension),
            dtype=dtypes.float32,
            initializer=initializer,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            trainable=False))

  return var_list
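

# Illustrative note on the sharding above (the numbers are examples): with
# num_hosts=4 and vocabulary_size=10, fixed_size_partitioner(4) splits the
# (10, embedding_dimension) table along its first axis into four row shards of
# sizes 3, 3, 2 and 2, one per host. With vocabulary_size=3 and num_hosts=4,
# only three real shards exist, so a single (1, embedding_dimension) variable
# named 'dummy_3_<name>' is added so that the fourth host still has something
# to load.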