# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment

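# Illustrative note (comment only, not part of the original source): the
# default group assignment places every replica in a single reduction group.
# For example, under a tpu_shard_context with 4 shards:
#
#   _create_default_group_assignment()  # -> [[0, 1, 2, 3]]
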

def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this must equal the sub-group size
      (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
      in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` built by concatenating data from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)

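# Illustrative sketch (comments only, not part of the original source): with
# two replicas each holding a tensor of shape [2, 4], splitting along
# dimension 1 and concatenating along dimension 0 exchanges half of each
# replica's data with its peer:
#
#   def tpu_step(x):  # x has shape [2, 4] on every replica
#     # split_count must match the sub-group size, here 2 replicas per group.
#     return all_to_all(x,
#                       concat_dimension=0,
#                       split_dimension=1,
#                       split_count=2,
#                       group_assignment=[[0, 1]])  # output shape [4, 2]
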

@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
      in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)

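# Illustrative sketch (comments only, not part of the original source):
# computing a cross-replica mean inside a replicated TPU computation. The
# helper name and the replica count below are hypothetical:
#
#   def _replica_mean(x, num_replicas=8):
#     # Default group assignment: one group containing every replica.
#     total = cross_replica_sum(x)
#     return total / num_replicas
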

def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear at most once in the source column and at most
  once in the target column. For any replica id not present in the target
  column, this op returns a zero tensor with the same shape and dtype as the
  input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int lists with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)

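# Illustrative sketch (comments only, not part of the original source):
# shifting data one replica to the right on a hypothetical 4-replica setup,
# mirroring the docstring example above. Replica 0 is not a target, so it
# receives zeros:
#
#   def tpu_step(x):
#     # Send replica 0 -> 1, 1 -> 2 and 2 -> 3.
#     return collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])
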

@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]


# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
# types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])

@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
                                         table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id {} "
        "and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.stop_gradient around tensors where "
        "variable update is not required. Previous gradients were generated by "
        "the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we return zeros(shape=[1]). The actual gradient w.r.t. the
      # embedding activations (grad_wrt_activations) has the same shape as the
      # activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)

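# Illustrative sketch (comments only, not part of the original source): a TPU
# computation body that reads a single float32 tensor from infeed. The shape
# is hypothetical and must match what the host-side enqueue supplies:
#
#   def computation():
#     features = infeed_dequeue(dtypes.float32, shape=[128, 224, 224, 3])
#     ...
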

# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
      each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)


# pylint: enable=redefined-outer-name


# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as specified
      in the configuration. If the learning rates for all tables are constant,
      this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
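# Illustrative sketch (comments only, not part of the original source): the
# typical pairing of embedding activations and their gradients inside a TPU
# training step. `recv_activations` and `embedding_config` are hypothetical
# placeholders for the op that produced the activations and the serialized
# TPUEmbeddingConfiguration proto:
#
#   activations = recv_activations()        # one tensor per embedding lookup
#   loss = model_fn(activations)
#   grads = tf.gradients(loss, activations)
#   send_op = send_tpu_embedding_gradients(inputs=grads,
#                                           config=embedding_config)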


# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
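# Illustrative sketch (comments only, not part of the original source):
# enqueueing one dense ID vector per embedding table from the host. The
# tensor names are hypothetical:
#
#   enqueue_op = enqueue_tpu_embedding_integer_batch(
#       batch=[user_ids, item_ids],   # one 1D int32 tensor per table
#       device_ordinal=0,
#       mode_override="train")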


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal b * nf +
      f, where nf is the number of features from the corresponding table, f is
      in [0, nf), and b is in [0, batch size). Both int32 and int64 are
      allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
      per (training example, feature) -- aggregation weights. Both float32 and
      float64 are allowed and will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
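# Illustrative sketch (comments only, not part of the original source): the
# sample_indices encoding for one table with nf = 2 features and a batch of 2
# examples. Example 0 has ids [3, 7] for feature 0 and [5] for feature 1;
# example 1 has [9] for feature 0 and [2] for feature 1:
#
#   # sample_indices[i] = b * nf + f for each id
#   sample_indices    = [0, 0, 1, 2, 3]   # (b=0,f=0), (b=0,f=0), (b=0,f=1), ...
#   embedding_indices = [3, 7, 5, 9, 2]
#   weights           = [1.0, 1.0, 1.0, 1.0, 1.0]
#   enqueue_op = enqueue_tpu_embedding_sparse_batch(
#       sample_indices=[sample_indices],
#       embedding_indices=[embedding_indices],
#       aggregation_weights=[weights],
#       device_ordinal=0)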


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example to
      which the corresponding embedding_indices and aggregation_weights values
      belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
      the size of its first dimension is 0, we assume each embedding_indices
      belongs to a different sample. Both int32 and int64 are allowed and will
      be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
      int32 and int64 are allowed and will be converted to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_indices. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal length.
      If None, then we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to that of
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
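# Illustrative sketch (comments only, not part of the original source):
# feeding one tf.sparse.SparseTensor of ids (with explicit unit weights) for a
# single table. `sp_ids` is a hypothetical SparseTensor built by the input
# pipeline:
#
#   enqueue_op = enqueue_tpu_embedding_sparse_tensor_batch(
#       sample_indices=[sp_ids.indices],
#       embedding_indices=[sp_ids.values],
#       aggregation_weights=[tf.ones_like(sp_ids.values, dtype=tf.float32)],
#       table_ids=[0],
#       device_ordinal=0,
#       combiners=["sum"])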


# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_splits. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal length.
      If None, then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to that
      of sample_splits. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
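# Illustrative sketch (comments only, not part of the original source):
# enqueueing ids held in a tf.RaggedTensor for a single table. `ragged_ids`
# is a hypothetical RaggedTensor produced by the input pipeline:
#
#   enqueue_op = enqueue_tpu_embedding_ragged_tensor_batch(
#       sample_splits=[ragged_ids.row_splits],
#       embedding_indices=[ragged_ids.values],
#       aggregation_weights=[tf.ones_like(ragged_ids.values, dtype=tf.float32)],
#       table_ids=[0],
#       device_ordinal=0)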