1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU embedding APIs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import math
24import re
25import six
26
27from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
28from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import init_ops
35from tensorflow.python.ops import partitioned_variables
36from tensorflow.python.ops import state_ops
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
39from tensorflow.python.tpu.ops import tpu_ops
40
41TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
42INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
43
44
45class TableConfig(
46    collections.namedtuple(
47        'TableConfig',
48        ['vocabulary_size', 'dimension', 'initializer', 'combiner'])):
49  """Embedding table configuration."""
50
51  def __new__(cls,
52              vocabulary_size,
53              dimension,
54              initializer=None,
55              combiner='mean'):
56    """Embedding table configuration.
57
58    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
60      dimension: The embedding dimension.
61      initializer: A variable initializer function to be used in embedding
62        variable initialization. If not specified, defaults to
63        `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
64        `1/sqrt(dimension)`.
65      combiner: A string specifying how to reduce if there are multiple entries
66        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
67        supported, with 'mean' the default. 'sqrtn' often achieves good
68        accuracy, in particular with bag-of-words columns. For more information,
69        see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
70        than sparse tensors.
71
72    Returns:
73      `TableConfig`.
74
75    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
78      ValueError: if `initializer` is specified and is not callable.
79      ValueError: if `combiner` is not supported.
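
    A minimal usage sketch (the sizes below are made up for illustration):

    ```
    video_table = tpu_embedding.TableConfig(
        vocabulary_size=10000,
        dimension=64,
        combiner='sum')
    ```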
80    """
81    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
82      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
83
84    if not isinstance(dimension, int) or dimension < 1:
85      raise ValueError('Invalid dimension {}.'.format(dimension))
86
87    if (initializer is not None) and (not callable(initializer)):
88      raise ValueError('initializer must be callable if specified.')
89    if initializer is None:
90      initializer = init_ops.truncated_normal_initializer(
91          mean=0.0, stddev=1 / math.sqrt(dimension))
92
93    if combiner not in ('mean', 'sum', 'sqrtn', None):
94      raise ValueError('Invalid combiner {}'.format(combiner))
95
96    return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension,
97                                           initializer, combiner)
98
99
100AdamSlotVariableNames = collections.namedtuple(
101    'AdamSlotVariableNames', ['m', 'v'])
102
103AdagradSlotVariableName = collections.namedtuple(
104    'AdagradSlotVariableName', ['accumulator'])
105
106AdamSlotVariables = collections.namedtuple(
107    'AdamSlotVariables', ['m', 'v'])
108
109AdagradSlotVariable = collections.namedtuple(
110    'AdagradSlotVariable', ['accumulator'])
111
112VariablesAndOps = collections.namedtuple(
113    'VariablesAndOps',
114    ['embedding_variables_by_table', 'slot_variables_by_table',
115     'load_ops', 'retrieve_ops']
116)
117
118
119class _OptimizationParameters(object):
120  """Parameters common to all optimizations."""
121
122  def __init__(self, learning_rate, use_gradient_accumulation):
123    self.learning_rate = learning_rate
124    self.use_gradient_accumulation = use_gradient_accumulation
125
126
127class AdagradParameters(_OptimizationParameters):
128  """Optimization parameters for Adagrad."""
129
130  def __init__(self, learning_rate, initial_accumulator=0.1,
131               use_gradient_accumulation=True):
132    """Optimization parameters for Adagrad.
133
134    Args:
      learning_rate: used for updating the embedding table.
      initial_accumulator: initial accumulator value for Adagrad; must be
        positive.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
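
    A minimal usage sketch (the values are illustrative placeholders, not
    recommendations):

    ```
    opt = tpu_embedding.AdagradParameters(
        learning_rate=0.1, initial_accumulator=0.1)
    ```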
141    """
142    super(AdagradParameters, self).__init__(learning_rate,
143                                            use_gradient_accumulation)
144    if initial_accumulator <= 0:
145      raise ValueError('Adagrad initial_accumulator must be positive')
146    self.initial_accumulator = initial_accumulator
147
148
149class AdamParameters(_OptimizationParameters):
150  """Optimization parameters for Adam."""
151
152  def __init__(self, learning_rate,
153               beta1=0.9,
154               beta2=0.999,
155               epsilon=1e-08,
156               lazy_adam=True,
157               sum_inside_sqrt=True,
158               use_gradient_accumulation=True):
159    """Optimization parameters for Adam.
160
161    Args:
162      learning_rate: a floating point value. The learning rate.
163      beta1: A float value.
164        The exponential decay rate for the 1st moment estimates.
165      beta2: A float value.
166        The exponential decay rate for the 2nd moment estimates.
167      epsilon: A small constant for numerical stability.
168      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
169        Please see `optimization_parameters.proto` for details.
170      sum_inside_sqrt: This improves training speed. Please see
171        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
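
    A minimal usage sketch (the hyperparameter values are illustrative
    placeholders, not recommendations):

    ```
    opt = tpu_embedding.AdamParameters(
        learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08)
    ```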
176    """
177    super(AdamParameters, self).__init__(learning_rate,
178                                         use_gradient_accumulation)
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError(
          'beta1 must be in the range [0., 1.); got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError(
          'beta2 must be in the range [0., 1.); got {}.'.format(beta2))
183    if epsilon <= 0.:
184      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
185    if not use_gradient_accumulation and not lazy_adam:
186      raise ValueError(
187          'When disabling Lazy Adam, gradient accumulation must be used.')
188
189    self.beta1 = beta1
190    self.beta2 = beta2
191    self.epsilon = epsilon
192    self.lazy_adam = lazy_adam
193    self.sum_inside_sqrt = sum_inside_sqrt
194
195
196class StochasticGradientDescentParameters(_OptimizationParameters):
197  """Optimization parameters for stochastic gradient descent.
198
199  Args:
200    learning_rate: a floating point value. The learning rate.
201  """
202
203  def __init__(self, learning_rate):
204    super(StochasticGradientDescentParameters, self).__init__(
205        learning_rate, False)
206
207
208class TPUEmbedding(object):
209  """API for using TPU for embedding.
210
211    Example:
212    ```
    table_config_video = tpu_embedding.TableConfig(
        vocabulary_size=4, dimension=2,
        initializer=initializer, combiner='mean')
    table_config_user = tpu_embedding.TableConfig(
        vocabulary_size=4, dimension=2,
        initializer=initializer, combiner='mean')
    table_to_config_dict = {'video': table_config_video,
                            'user': table_config_user}
218    feature_to_table_dict = {'watched': 'video',
219                             'favorited': 'video',
220                             'friends': 'user'}
221    batch_size = 4
    optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
    mode = tpu_embedding.TRAINING
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict, feature_to_table_dict,
        batch_size, mode, master, optimization_parameters)
228
229    batch_size_per_core = embedding.batch_size_per_core
230    sparse_features_list = []
231    for host in hosts:
232      with ops.device(host):
233        for _ in range(embedding.num_cores_per_host):
234          sparse_features = {}
235          sparse_features['watched'] = sparse_tensor.SparseTensor(...)
236          sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
237          sparse_features['friends'] = sparse_tensor.SparseTensor(...)
238          sparse_features_list.append(sparse_features)
239
240    enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
241    embedding_variables_and_ops = embedding.create_variables_and_ops()
242
243    def computation():
244      activations = embedding.get_activations()
245      loss = compute_loss(activations)
246
247      base_optimizer = gradient_descent.GradientDescentOptimizer(
248          learning_rate=1)
249      cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
250          base_optimizer)
251
252      train_op = cross_shard_optimizer.minimize(loss)
      gradients = (
          tpu_embedding_gradient.get_gradients_through_compute_gradients(
              cross_shard_optimizer, loss, activations))
      send_gradients_op = embedding.generate_send_gradients_op(gradients)
257      with ops.control_dependencies([train_op, send_gradients_op]):
        loss = array_ops.identity(loss)
      return loss
259
260    loss = tpu.shard(computation,
261                     num_shards=embedding.num_cores)
262
263    with self.test_session() as sess:
264      sess.run(tpu.initialize_system(embedding_config=
265                                     embedding.config_proto))
266      sess.run(variables.global_variables_initializer())
267      sess.run(embedding_variables_and_ops.load_ops())
268      sess.run(enqueue_ops)
269      loss_val = sess.run(loss)
270    ```
271  """
272
273  # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table
274  # name, consider `feature_to_config_dict` which maps to `FeatureConfig`.
275  # `FeatureConfig` could have fields other than table name. For example, it
276  # could have a field to indicate that the feature should not be used to
277  # update embedding table (cr/204852758, cr/204940540). Also, this can support
278  # different combiners for different features within the same table.
279  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
280  # to `FeatureConfig`?
281
282  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
283  # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively?
284
285  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
286  # for-loops around construction of inputs.
287
288  # `optimization_parameter` applies to all tables. If the need arises,
289  # we can add `optimization_parameters` to `TableConfig` to override this
290  # global setting.
291  def __init__(self,
292               table_to_config_dict,
293               feature_to_table_dict,
294               batch_size,
295               mode,
296               master,
297               optimization_parameters=None,
298               cluster_def=None,
299               pipeline_execution_with_tensor_core=True):
300    """API for using TPU for embedding lookups.
301
302    Args:
303      table_to_config_dict: A dictionary mapping from string of table name to
304        `TableConfig`. Table refers to an embedding table, e.g. `params`
305        argument to `tf.nn.embedding_lookup_sparse()`.
306      feature_to_table_dict: A dictionary mapping from string of feature name
307        to string of table name. Feature refers to ids to lookup in embedding
308        table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
309      batch_size: An `int` representing the global batch size.
310      mode: `TRAINING` or `INFERENCE`.
311      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Must be set in training and must
        be `None` in inference.
315      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes
        training faster, but the trained model will be different if step N and
        step N+1 involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
320
321    Raises:
322      ValueError: if any input is invalid.
323    """
324    _validate_table_to_config_dict(table_to_config_dict)
325    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
326    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
327
328    _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict)
329    self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict)
330    self._table_to_features_dict = _create_table_to_features_dict(
331        self._feature_to_table_dict)
332    self._combiners = _create_combiners(self._table_to_config_dict,
333                                        self._table_to_features_dict)
334
335    self._batch_size = batch_size
336
337    self._master = master
338    self._cluster_def = cluster_def
339    self._tpu_system_metadata = (
340        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
341            self._master, cluster_def=self._cluster_def))
342    if self._tpu_system_metadata.num_cores == 0:
343      raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
344                       'TPUs.'.format(self._master))
345    self._num_hosts = self._tpu_system_metadata.num_hosts
346    master_job_name = tpu_system_metadata_lib.master_job(self._master,
347                                                         self._cluster_def)
348    self._hosts = sorted([
349        device.name for device in self._tpu_system_metadata.devices
350        if 'device:CPU:' in device.name and (master_job_name is None or
351                                             master_job_name in device.name)])
352    self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host
353    self._num_cores = self._tpu_system_metadata.num_cores
354
355    _validate_batch_size(self._batch_size, self._num_cores)
356    self._batch_size_per_core = self._batch_size // self._num_cores
357
358    # TODO(shizhiw): remove `mode`?
359    if mode == TRAINING:
360      _validate_optimization_parameters(optimization_parameters)
361      self._optimization_parameters = optimization_parameters
362    elif mode == INFERENCE:
363      if optimization_parameters is not None:
364        raise ValueError('`optimization_parameters` should be `None` '
365                         'for inference mode.')
366      self._optimization_parameters = (
367          StochasticGradientDescentParameters(1.))
368    else:
369      raise ValueError('`mode` only supports {} and {}; got {}.'
370                       .format(TRAINING, INFERENCE, mode))
371    self._mode = mode
372
373    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
374    # and create special handler for inference that inherits from
375    # StochasticGradientDescentHandler with more user-friendly error message
376    # on get_slot().
377    self._optimizer_handler = _get_optimization_handler(
378        self._optimization_parameters)
379    self._pipeline_execution_with_tensor_core = (
380        pipeline_execution_with_tensor_core)
381
382    self._config_proto = self._create_config_proto()
383
384  @property
385  def hosts(self):
386    """A list of device names for CPU hosts.
387
388    Returns:
389      A list of device names for CPU hosts.
390    """
391    return copy.copy(self._hosts)
392
393  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
394  # to be consistent with `tpu_embedding_configuration.proto`.
395  @property
396  def num_cores_per_host(self):
397    """Number of TPU cores on a CPU host.
398
399    Returns:
400      Number of TPU cores on a CPU host.
401    """
402    return self._num_cores_per_host
403
404  @property
405  def num_cores(self):
406    """Total number of TPU cores on all hosts.
407
408    Returns:
409      Total number of TPU cores on all hosts.
410    """
411    return self._num_cores
412
413  @property
414  def batch_size_per_core(self):
415    """Batch size for each TPU core.
416
    The sparse tensors in the `sparse_features_list` passed to
       `generate_enqueue_ops` must have a batch dimension equal to this.
419
420    Returns:
421      Batch size for each TPU core.
422    """
423    return self._batch_size_per_core
424
425  @property
426  def config_proto(self):
    """The embedding config proto to pass to `tpu.initialize_system()`.

    Returns:
      A `TPUEmbeddingConfiguration` proto describing the desired
         configuration of the hardware embedding lookup tables, which
         is passed to `tpu.initialize_system()`.
433    """
434    return self._config_proto
435
436  @property
437  def table_to_config_dict(self):
438    return copy.copy(self._table_to_config_dict)
439
440  @property
441  def feature_to_table_dict(self):
442    return copy.copy(self._feature_to_table_dict)
443
444  @property
445  def table_to_features_dict(self):
446    return copy.copy(self._table_to_features_dict)
447
448  @property
449  def optimization_parameters(self):
450    return self._optimization_parameters
451
452  def _create_config_proto(self):
453    """Create `TPUEmbeddingConfiguration`."""
454    config_proto = elc.TPUEmbeddingConfiguration()
455    for table in self._table_to_config_dict:
456      table_descriptor = config_proto.table_descriptor.add()
457      table_descriptor.name = table
458
459      table_config = self._table_to_config_dict[table]
460      table_descriptor.vocabulary_size = table_config.vocabulary_size
461      table_descriptor.dimension = table_config.dimension
462
463      features_for_table = self._table_to_features_dict[table]
464      table_descriptor.num_features = len(features_for_table)
465
466      table_descriptor.optimization_parameters.learning_rate.constant = (
467          self._optimization_parameters.learning_rate)
468      table_descriptor.optimization_parameters.gradient_accumulation_status = (
469          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
470          if self._optimization_parameters.use_gradient_accumulation else
471          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
472      self._optimizer_handler.set_optimization_parameters(table_descriptor)
473
474    config_proto.mode = self._mode
475    config_proto.batch_size_per_tensor_core = self._batch_size_per_core
476    config_proto.num_hosts = self._num_hosts
477    config_proto.num_tensor_cores = self._num_cores
478    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
479    config_proto.pipeline_execution_with_tensor_core = (
480        self._pipeline_execution_with_tensor_core)
481
482    return config_proto
483
484  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
485                               slot_variable_names_by_table=None):
486    """Create embedding and slot variables, with ops to load and retrieve them.
487
488    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the table
        name is used as the embedding variable name.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableName` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.
495
496    Returns:
497      `tpu_embedding.VariablesAndOps` with:
498        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to `AdagradSlotVariable`,
         `AdamSlotVariables` etc. holding the slot variables,
        A function which returns a list of ops to load embedding and slot
         variables from CPU to TPU.
503        A function which returns a list of ops to retrieve embedding and slot
504         variables from TPU to CPU.
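
    A sketch of a typical save path (`embedding`, `sess`, `saver` and
    `checkpoint_path` are assumed to be defined elsewhere); the retrieve ops
    copy the trained values back from the TPU so the CPU variables being
    checkpointed are up to date:

    ```
    vars_and_ops = embedding.create_variables_and_ops()
    sess.run(vars_and_ops.load_ops())
    # ... run training ...
    sess.run(vars_and_ops.retrieve_ops())
    saver.save(sess, checkpoint_path)
    ```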
505    """
506    embedding_variables_by_table = {}
507    slot_variables_by_table = {}
508    load_op_fns = []
509    retrieve_op_fns = []
510    for table in self._table_to_config_dict:
511      if embedding_variable_name_by_table:
512        embedding_variable_name = embedding_variable_name_by_table[table]
513      else:
514        embedding_variable_name = table
515      if slot_variable_names_by_table:
516        slot_variable_names = slot_variable_names_by_table[table]
517      else:
518        slot_variable_names = (
519            self._optimizer_handler.get_default_slot_variable_names(table))
520
521      device_fn = _create_device_fn(self._hosts)
522      with ops.device(device_fn):
523        table_variables = _create_partitioned_variables(
524            name=embedding_variable_name,
525            num_hosts=self._num_hosts,
526            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
527            embedding_dimension=self._table_to_config_dict[table].dimension,
528            initializer=self._table_to_config_dict[table].initializer,
529            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
530        embedding_variables_by_table[table] = table_variables
531
532        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
533            self._optimizer_handler.create_variables_and_ops(
534                table, slot_variable_names, self._num_hosts,
535                self._table_to_config_dict[table], table_variables)
536        )
537        slot_variables_by_table[table] = slot_variables_for_table
538        load_op_fns.append(load_ops_fn)
539        retrieve_op_fns.append(retrieve_ops_fn)
540
541    def load_ops():
542      """Calls and returns the load ops for each embedding table.
543
544      Returns:
545        A list of ops to load embedding and slot variables from CPU to TPU.
546      """
547      load_ops_list = []
548      for load_op_fn in load_op_fns:
549        load_ops_list.extend(load_op_fn())
550      return load_ops_list
551
552    def retrieve_ops():
553      """Calls and returns the retrieve ops for each embedding table.
554
555      Returns:
556        A list of ops to retrieve embedding and slot variables from TPU to CPU.
557      """
558      retrieve_ops_list = []
559      for retrieve_op_fn in retrieve_op_fns:
560        retrieve_ops_list.extend(retrieve_op_fn())
561      return retrieve_ops_list
562
563    return VariablesAndOps(embedding_variables_by_table,
564                           slot_variables_by_table,
565                           load_ops, retrieve_ops)
566
567  def generate_enqueue_ops(self, sparse_features_list):
568    """Generate enqueue ops.
569
570    Args:
      sparse_features_list: a list of dictionaries, each mapping from a string
        of feature name to a sparse tensor. Each dictionary is for one
        TPU core. Dictionaries for the same host should be contiguous
        in the list.
575
576    Returns:
577      Ops to enqueue to TPU for embedding.
578    """
579    self._validate_generate_enqueue_ops_sparse_features_list(
580        sparse_features_list)
581    return [
582        self._generate_enqueue_op(
583            sparse_features, device_ordinal=i % self._num_cores_per_host)
584        for i, sparse_features in enumerate(sparse_features_list)
585    ]
586
587  def _validate_generate_enqueue_ops_sparse_features_list(
588      self, sparse_features_list):
589    """Validate `sparse_features_list`."""
590    if len(sparse_features_list) != self._num_cores:
591      raise ValueError('Length of `sparse_features_list` should match the '
592                       'number of cores; '
593                       '`len(sparse_features_list)` is {}, '
594                       'number of cores is {}.'.format(
595                           len(sparse_features_list), self._num_cores))
596
597    feature_set = set(self._feature_to_table_dict.keys())
598    contiguous_device = None
599    for i, sparse_features in enumerate(sparse_features_list):
600      used_feature_set = set(sparse_features.keys())
601
602      # Check features are valid.
603      missing_feature_set = feature_set - used_feature_set
604      if missing_feature_set:
        raise ValueError('`sparse_features_list[{}]` misses a feature that is '
                         'in `feature_to_table_dict`: {}.'.format(
                             i, missing_feature_set))
608
609      extra_feature_set = used_feature_set - feature_set
610      if extra_feature_set:
        raise ValueError('`sparse_features_list[{}]` has a feature that is not '
                         'in `feature_to_table_dict`: {}.'.format(
                             i, extra_feature_set))
614
615      device = None
616      device_feature = None
617      for feature, tensor in six.iteritems(sparse_features):
618        combiner = self._table_to_config_dict[
619            self._feature_to_table_dict[feature]].combiner
620        if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
621          raise ValueError('`sparse_features_list[{}]` has a feature that is '
622                           'not mapped to `SparseTensor` and has a combiner. '
623                           '`feature`: {}, combiner: {}'.format(
624                               i, feature, combiner))
625
626        # Check all features are on the same device.
627        if device is None:
628          device = tensor.op.device
629          device_feature = feature
630        else:
631          if device != tensor.op.device:
632            raise ValueError('Devices are different between features in '
633                             '`sparse_features_list[{}]`; '
634                             'devices: {}, {}; features: {}, {}.'.format(
635                                 i, device, tensor.op.device, feature,
636                                 device_feature))
637
638      if i % self._num_cores_per_host:
639        if device != contiguous_device:
640          raise ValueError('We expect the `sparse_features` which are on the '
641                           'same host to be contiguous in '
642                           '`sparse_features_list`, '
643                           '`sparse_features_list[{}]` is on device {}, '
644                           'but is expected to be on device {}.'.format(
645                               i, device, contiguous_device))
646      else:
647        contiguous_device = device
648
649  def _generate_enqueue_op(self, sparse_features, device_ordinal):
650    with ops.colocate_with(list(sparse_features.values())[0]):
651      sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
652          self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
653      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
654          sample_idcs,
655          embedding_idcs,
656          aggregation_weights,
657          table_ids,
658          device_ordinal=device_ordinal,
659          combiners=self._combiners)
660
661  def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
662    """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
663
664    Args:
      sparse_features: a `Dict` of tensors for embedding. Can be sparse or
        dense.
667
668    Returns:
669      Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
670    """
671
672    sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
673        list(), list(), list(), list())
674    for table_id, table in enumerate(self._table_to_features_dict):
675      features = self._table_to_features_dict[table]
676      for feature in features:
677        tensor = sparse_features[feature]
678        if not isinstance(tensor, sparse_tensor.SparseTensor):
679          sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32))
680          embedding_idcs.append(tensor)
681        else:
682          sample_idcs.append(tensor.indices)
683          embedding_idcs.append(tensor.values)
684        aggregation_weights.append(array_ops.zeros([0]))
685        table_ids.append(table_id)
686
687    return sample_idcs, embedding_idcs, aggregation_weights, table_ids
688
689  def get_activations(self):
690    """Get activations for features.
691
692    This should be called within `computation` that is passed to
693      `tpu.replicate` and friends.
694
695    Returns:
696      A dictionary mapping from `String` of feature name to `Tensor`
697        of activation.
698    """
699    recv_activations = tpu_ops.recv_tpu_embedding_activations(
700        num_outputs=len(self._table_to_config_dict),
701        config=self._config_proto.SerializeToString())
702
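    # Each tensor in `recv_activations` holds the activations for one table,
    # with features of that table interleaved per sample (the feature index
    # varies fastest). The strided slice below therefore picks out rows
    # lookup_id, lookup_id + stride, ... for each feature.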
703    activations = collections.OrderedDict()
704    for table_id, table in enumerate(self._table_to_features_dict):
705      features = self._table_to_features_dict[table]
706      for lookup_id, feature in enumerate(features):
707        stride = len(self._table_to_features_dict[table])
708        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
709    return activations
710
711  def generate_send_gradients_op(self, feature_to_gradient_dict):
712    """Send gradient to TPU embedding.
713
714    Args:
715      feature_to_gradient_dict: dict mapping feature names to gradient wrt
716        activations.
717
718    Returns:
719      SendTPUEmbeddingGradients Op.
720
721    Raises:
722      RuntimeError: If `mode` is not `TRAINING`.
723    """
724    if self._mode != TRAINING:
      raise RuntimeError('Gradients can only be sent to TPU embedding '
                         'in training mode; got mode {}.'
                         .format(self._mode))
728    gradients = []
729    for table in self._table_to_features_dict:
730      features = self._table_to_features_dict[table]
731      table_gradients = [
732          feature_to_gradient_dict[feature] for feature in features
733      ]
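      # Stack the per-feature gradients along axis 1 and flatten, so rows end
      # up interleaved per sample in the same layout that `get_activations()`
      # receives them.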
734      interleaved_table_grads = array_ops.reshape(
735          array_ops.stack(table_gradients, axis=1),
736          [-1, table_gradients[0].shape[1]])
737      gradients.append(interleaved_table_grads)
738    return tpu_ops.send_tpu_embedding_gradients(
739        inputs=gradients, config=self.config_proto.SerializeToString())
740
741
742def _validate_table_to_config_dict(table_to_config_dict):
743  """Validate `table_to_config_dict`."""
744  for k, v in six.iteritems(table_to_config_dict):
745    if not isinstance(v, TableConfig):
746      raise ValueError('Value of `table_to_config_dict` must be of type '
747                       '`TableConfig`, got {} for {}.'.format(type(v), k))
748
749
750def _validate_feature_to_table_dict(table_to_config_dict,
751                                    feature_to_table_dict):
752  """Validate `feature_to_table_dict`."""
753  used_table_set = set(feature_to_table_dict.values())
754  table_set = set(table_to_config_dict.keys())
755
756  unused_table_set = table_set - used_table_set
757  if unused_table_set:
    raise ValueError('`table_to_config_dict` specifies a table that is not '
                     'used in `feature_to_table_dict`: {}.'
                     .format(unused_table_set))
761
762  extra_table_set = used_table_set - table_set
763  if extra_table_set:
764    raise ValueError('`feature_to_table_dict` refers to a table that is not '
765                     'specified in `table_to_config_dict`: {}.'
766                     .format(extra_table_set))
767
768
769def _validate_batch_size(batch_size, num_cores):
770  if batch_size % num_cores:
771    raise ValueError('`batch_size` is not a multiple of number of '
772                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
773                         batch_size, num_cores))
774
775
776def _validate_optimization_parameters(optimization_parameters):
777  if not isinstance(optimization_parameters, _OptimizationParameters):
    raise ValueError('`optimization_parameters` must inherit from '
                     '`_OptimizationParameters`. '
780                     '`type(optimization_parameters)`={}'.format(
781                         type(optimization_parameters)))
782
783
784class _OptimizerHandler(object):
785  """Interface class for handling optimizer specific logic."""
786
787  def __init__(self, optimization_parameters):
788    self._optimization_parameters = optimization_parameters
789
790  def set_optimization_parameters(self, table_descriptor):
791    raise NotImplementedError()
792
793  def get_default_slot_variable_names(self, table):
794    raise NotImplementedError()
795
796  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
797                               table_config, table_variables):
798    raise NotImplementedError()
799
800
801class _AdagradHandler(_OptimizerHandler):
802  """Handles Adagrad specific logic."""
803
804  def __init__(self, optimization_parameters):
805    super(_AdagradHandler, self).__init__(optimization_parameters)
806    self._table_to_accumulator_variables_dict = {}
807
808  def set_optimization_parameters(self, table_descriptor):
809    table_descriptor.optimization_parameters.adagrad.SetInParent()
810
811  def get_default_slot_variable_names(self, table):
812    return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad'))
813
814  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
815                               table_config, table_variables):
816    accumulator_initializer = init_ops.constant_initializer(
817        self._optimization_parameters.initial_accumulator)
818    accumulator_variables = _create_partitioned_variables(
819        name=slot_variable_names.accumulator,
820        num_hosts=num_hosts,
821        vocabulary_size=table_config.vocabulary_size,
822        embedding_dimension=table_config.dimension,
823        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
824        initializer=accumulator_initializer)
825    slot_variables = AdagradSlotVariable(accumulator_variables)
826
827    def load_ops_fn():
      """Returns the load ops for AdaGrad embedding tables.
829
830      Returns:
831        A list of ops to load embedding and slot variables from CPU to TPU.
832      """
833      load_op_list = []
834      for host_id, table_variable, accumulator_variable in (zip(
835          range(num_hosts), table_variables, accumulator_variables)):
836        with ops.colocate_with(table_variable):
837          load_parameters_op = (
838              tpu_ops.load_tpu_embedding_adagrad_parameters(
839                  parameters=table_variable,
840                  accumulators=accumulator_variable,
841                  table_name=table,
842                  num_shards=num_hosts,
843                  shard_id=host_id))
844        load_op_list.append(load_parameters_op)
845      return load_op_list
846
847    def retrieve_ops_fn():
848      """Returns the retrieve ops for AdaGrad embedding tables.
849
850      Returns:
851        A list of ops to retrieve embedding and slot variables from TPU to CPU.
852      """
853      retrieve_op_list = []
854      for host_id, table_variable, accumulator_variable in (zip(
855          range(num_hosts), table_variables, accumulator_variables)):
856        with ops.colocate_with(table_variable):
857          retrieved_table, retrieved_accumulator = (
858              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
859                  table_name=table,
860                  num_shards=num_hosts,
861                  shard_id=host_id))
862          retrieve_parameters_op = control_flow_ops.group(
863              state_ops.assign(table_variable, retrieved_table),
864              state_ops.assign(accumulator_variable, retrieved_accumulator))
865        retrieve_op_list.append(retrieve_parameters_op)
866      return retrieve_op_list
867
868    return slot_variables, load_ops_fn, retrieve_ops_fn
869
870
871class _AdamHandler(_OptimizerHandler):
872  """Handles Adam specific logic."""
873
874  def __init__(self, optimization_parameters):
875    super(_AdamHandler, self).__init__(optimization_parameters)
876    self._table_to_m_variables_dict = {}
877    self._table_to_v_variables_dict = {}
878
879  def set_optimization_parameters(self, table_descriptor):
880    table_descriptor.optimization_parameters.adam.beta1 = (
881        self._optimization_parameters.beta1)
882    table_descriptor.optimization_parameters.adam.beta2 = (
883        self._optimization_parameters.beta2)
884    table_descriptor.optimization_parameters.adam.epsilon = (
885        self._optimization_parameters.epsilon)
886    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
887        not self._optimization_parameters.lazy_adam)
888    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
889        self._optimization_parameters.sum_inside_sqrt)
890
891  def get_default_slot_variable_names(self, table):
892    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
893                                 '{}/{}/v'.format(table, 'Adam'))
894
895  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
896                               table_config, table_variables):
897    m_initializer = init_ops.zeros_initializer()
898    m_variables = _create_partitioned_variables(
899        name=slot_variable_names.m,
900        num_hosts=num_hosts,
901        vocabulary_size=table_config.vocabulary_size,
902        embedding_dimension=table_config.dimension,
903        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
904        initializer=m_initializer)
905    v_initializer = init_ops.zeros_initializer()
906    v_variables = _create_partitioned_variables(
907        name=slot_variable_names.v,
908        num_hosts=num_hosts,
909        vocabulary_size=table_config.vocabulary_size,
910        embedding_dimension=table_config.dimension,
911        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
912        initializer=v_initializer)
913    slot_variables = AdamSlotVariables(m_variables, v_variables)
914
915    def load_ops_fn():
      """Returns the load ops for Adam embedding tables.
917
918      Returns:
919        A list of ops to load embedding and slot variables from CPU to TPU.
920      """
921      load_op_list = []
922      for host_id, table_variable, m_variable, v_variable in (zip(
923          range(num_hosts), table_variables,
924          m_variables, v_variables)):
925        with ops.colocate_with(table_variable):
926          load_parameters_op = (
927              tpu_ops.load_tpu_embedding_adam_parameters(
928                  parameters=table_variable,
929                  momenta=m_variable,
930                  velocities=v_variable,
931                  table_name=table,
932                  num_shards=num_hosts,
933                  shard_id=host_id))
934        load_op_list.append(load_parameters_op)
935      return load_op_list
936
937    def retrieve_ops_fn():
938      """Returns the retrieve ops for Adam embedding tables.
939
940      Returns:
941        A list of ops to retrieve embedding and slot variables from TPU to CPU.
942      """
943
944      retrieve_op_list = []
945      for host_id, table_variable, m_variable, v_variable in (zip(
946          range(num_hosts), table_variables,
947          m_variables, v_variables)):
948        with ops.colocate_with(table_variable):
949          retrieved_table, retrieved_m, retrieved_v = (
950              tpu_ops.retrieve_tpu_embedding_adam_parameters(
951                  table_name=table,
952                  num_shards=num_hosts,
953                  shard_id=host_id))
954          retrieve_parameters_op = control_flow_ops.group(
955              state_ops.assign(table_variable, retrieved_table),
956              state_ops.assign(m_variable, retrieved_m),
957              state_ops.assign(v_variable, retrieved_v))
958
959        retrieve_op_list.append(retrieve_parameters_op)
960      return retrieve_op_list
961
962    return slot_variables, load_ops_fn, retrieve_ops_fn
963
964
965class _StochasticGradientDescentHandler(_OptimizerHandler):
966  """Handles stochastic gradient descent specific logic."""
967
968  def set_optimization_parameters(self, table_descriptor):
969    (table_descriptor.optimization_parameters.stochastic_gradient_descent
970     .SetInParent())
971
972  def get_default_slot_variable_names(self, table):
973    return None
974
975  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
976                               table_config, table_variables):
977    del table_config
978
979    def load_ops_fn():
      """Returns the load ops for SGD embedding tables.
981
982      Returns:
983        A list of ops to load embedding and slot variables from CPU to TPU.
984      """
985      load_op_list = []
986      for host_id, table_variable in (zip(
987          range(num_hosts), table_variables)):
988        with ops.colocate_with(table_variable):
989          load_parameters_op = (
990              tpu_ops
991              .load_tpu_embedding_stochastic_gradient_descent_parameters(
992                  parameters=table_variable,
993                  table_name=table,
994                  num_shards=num_hosts,
995                  shard_id=host_id))
996
997        load_op_list.append(load_parameters_op)
998      return load_op_list
999
1000    def retrieve_ops_fn():
1001      """Returns the retrieve ops for SGD embedding tables.
1002
1003      Returns:
1004        A list of ops to retrieve embedding and slot variables from TPU to CPU.
1005      """
1006
1007      retrieve_op_list = []
1008      for host_id, table_variable in (zip(
1009          range(num_hosts), table_variables)):
1010        with ops.colocate_with(table_variable):
1011          retrieved_table = (
1012              tpu_ops
1013              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
1014                  table_name=table,
1015                  num_shards=num_hosts,
1016                  shard_id=host_id))
1017          retrieve_parameters_op = control_flow_ops.group(
1018              state_ops.assign(table_variable, retrieved_table))
1019
1020        retrieve_op_list.append(retrieve_parameters_op)
1021      return retrieve_op_list
1022
1023    return None, load_ops_fn, retrieve_ops_fn
1024
1025
1026def _get_optimization_handler(optimization_parameters):
1027  if isinstance(optimization_parameters, AdagradParameters):
1028    return _AdagradHandler(optimization_parameters)
1029  elif isinstance(optimization_parameters, AdamParameters):
1030    return _AdamHandler(optimization_parameters)
1031  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
1032    return _StochasticGradientDescentHandler(optimization_parameters)
1033  else:
    raise NotImplementedError()
1035
1036
1037def _create_ordered_dict(d):
1038  """Create an OrderedDict from Dict."""
1039  return collections.OrderedDict((k, d[k]) for k in sorted(d))
1040
1041
1042def _create_combiners(table_to_config_dict, table_to_features_dict):
1043  """Create a per feature list of combiners, ordered by table."""
1044  combiners = []
1045  for table in table_to_config_dict:
1046    combiner = table_to_config_dict[table].combiner or 'sum'
1047    combiners.extend([combiner] * len(table_to_features_dict[table]))
1048  return combiners
1049
1050
1051def _create_table_to_features_dict(feature_to_table_dict):
1052  """Create mapping from table to a list of its features."""
1053  table_to_features_dict_tmp = {}
1054  for feature, table in six.iteritems(feature_to_table_dict):
1055    if table in table_to_features_dict_tmp:
1056      table_to_features_dict_tmp[table].append(feature)
1057    else:
1058      table_to_features_dict_tmp[table] = [feature]
1059
1060  table_to_features_dict = collections.OrderedDict()
1061  for table in sorted(table_to_features_dict_tmp):
1062    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
1063  return table_to_features_dict
1064
1065
1066def _create_device_fn(hosts):
1067  """Create device_fn() to use with _create_partitioned_variables()."""
1068
1069  def device_fn(op):
1070    """Returns the `device` for `op`."""
1071    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
1072
1073    if part_match:
1074      idx = int(part_match.group(1))
1075    else:
1076      raise RuntimeError('Internal Error: '
1077                         'Expected %s to contain /part_*.' % op.name)
1078
1079    device = hosts[idx]
1080    return device
1081
1082  return device_fn
1083
1084
1085def _create_partitioned_variables(name,
1086                                  num_hosts,
1087                                  vocabulary_size,
1088                                  embedding_dimension,
1089                                  initializer,
1090                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates PartitionedVariables based on `num_hosts` for the table `name`."""
1092  # TODO(shizhiw): automatically place embedding lookup elsewhere?
1093  if vocabulary_size < num_hosts:
    raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). '
                     'As TPU embedding is not optimized for small tables, '
                     'please consider other ways for this embedding lookup.'
                     .format(vocabulary_size, num_hosts))
1097
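  # `fixed_size_partitioner(num_hosts)` splits the variable along its first
  # (vocabulary) dimension into `num_hosts` shards named '<name>/part_<k>';
  # `_create_device_fn()` then places shard k on hosts[k].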
1098  return list(variable_scope.get_variable(
1099      name,
1100      shape=(vocabulary_size, embedding_dimension),
1101      partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
1102      dtype=dtypes.float32,
1103      initializer=initializer,
1104      collections=collections,
1105      trainable=False))
1106