# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

16"""Operations for TPUs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25# pylint: disable=wildcard-import,unused-import
26from tensorflow.python.ops import gen_tpu_ops
27from tensorflow.python.ops.gen_tpu_ops import *
28# pylint: enable=wildcard-import,unused-import
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.tpu import tpu_function
31from tensorflow.python.util.tf_export import tf_export
32
33
def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the sub-group
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` containing the data concatenated from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)


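# Illustrative sketch (not part of the library): inside a replicated TPU
# computation, all_to_all redistributes a per-replica tensor. Assuming 4
# replicas and a hypothetical per-replica tensor `x` of shape [4, 8]:
#
#   y = all_to_all(x, concat_dimension=0, split_dimension=1, split_count=4)
#
# Each replica splits `x` into 4 pieces of shape [4, 2] along dimension 1,
# exchanges one piece with every replica, and concatenates the received
# pieces along dimension 0 into a [16, 2] result.

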
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)


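# Illustrative sketch (not part of the library): with 8 replicas arranged in
# two groups of 4, a hypothetical per-replica scalar can be summed within
# each group by passing an explicit group_assignment:
#
#   group_assignment = [[0, 1, 2, 3], [4, 5, 6, 7]]
#   total = cross_replica_sum(per_replica_loss, group_assignment)
#
# With group_assignment=None, all replicas in the enclosing tpu_shard_context
# form a single group.

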
def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear only once in the source column and only once in
  the target column. For a replica id that does not appear in the target
  column, this op returns a zero tensor with the same shape and dtype as the
  input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` yields the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: A 2d int list with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)


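# Illustrative sketch (not part of the library): the docstring example above
# expressed as code, shifting each replica's tensor to the next replica:
#
#   y = collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])
#
# Replica 0 receives no data, so it gets a zero tensor with x's shape and
# dtype.

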
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]


# This extra type checking exists to give a more helpful error message in
# the common case that unsupported types such as uint8 are infed. Remove
# when all types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we return zeros(shape=[1]). The actual gradient w.r.t. the
      # embedding activations (grad_wrt_activations) has the same shape as
      # the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)


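# Illustrative sketch (not part of the library): a TPU computation typically
# pairs infeed_dequeue on the device with an infeed enqueue on the host.
# Assuming a hypothetical computation body:
#
#   def tpu_step():
#     images = infeed_dequeue(dtypes.float32, shape=[128, 28, 28, 1])
#     labels = infeed_dequeue(dtypes.int32, shape=[128])
#     ...  # build the model from the dequeued tensors
#
# The dtype must be one of _SUPPORTED_INFEED_DTYPES, otherwise a TypeError is
# raised before the op is created.

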
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`.
      The element types of each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
      The shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name


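# Illustrative sketch (not part of the library): dequeueing a (features,
# labels) pair as a single tuple, with hypothetical shapes:
#
#   features, labels = infeed_dequeue_tuple(
#       dtypes=[tf.float32, tf.int32],
#       shapes=[[128, 224, 224, 3], [128]])
#
# The ith dequeued tensor has dtypes[i] and shapes[i]; the two lists must
# have the same length.

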
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
        This argument has the same length and shapes as the return value of
        RecvTPUEmbeddingActivations, but contains gradients of the model's
        loss with respect to the embedding activations. The embedding tables
        are updated from these gradients via the optimizers specified in the
        TPU embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
        learning rate tag: see the comments in
        //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
        Multiple tables can share the same dynamic learning rate tag as
        specified in the configuration. If the learning rates for all tables
        are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)


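# Illustrative sketch (not part of the library): the gradients sent here are
# typically the ones collected by _embedding_activations_grad above. Assuming
# a hypothetical serialized TPUEmbeddingConfiguration `config_proto_str` and a
# list of gradient tensors `grads`, one per embedding activation:
#
#   send_op = send_tpu_embedding_gradients(
#       inputs=grads, config=config_proto_str)
#
# learning_rates is left at its default (an empty list), which is appropriate
# when every table uses a constant learning rate.

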
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)


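# Illustrative sketch (not part of the library): enqueueing one 1D id tensor
# per embedding table on the host, targeting TPU core 0. The tensor names are
# hypothetical:
#
#   enqueue_op = enqueue_tpu_embedding_integer_batch(
#       batch=[user_ids, item_ids],  # one 1D int32 tensor per table
#       device_ordinal=0)
#
# mode_override defaults to 'unspecified', so the mode comes from the
# TPUEmbeddingConfiguration.

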
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e. per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table, that
      specifies how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)


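# Illustrative sketch (not part of the library): for a single table with one
# feature, the three parallel lists can be derived from a hypothetical
# tf.SparseTensor `sp_ids` of shape [batch_size, max_ids_per_example]:
#
#   sample_indices = [sp_ids.indices[:, 0]]  # which example each id belongs to
#   embedding_indices = [sp_ids.values]      # the ids themselves
#   aggregation_weights = [tf.ones_like(sp_ids.values, dtype=tf.float32)]
#
#   enqueue_op = enqueue_tpu_embedding_sparse_batch(
#       sample_indices, embedding_indices, aggregation_weights,
#       device_ordinal=0, combiners=['mean'])

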
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      look up the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_indices. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature. If greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    combiners: A list of string scalars, one for each embedding table, that
      specifies how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)

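
# Illustrative sketch (not part of the library): enqueueing two sparse
# features that feed tables 0 and 1, assuming hypothetical SparseTensors
# sp_a and sp_b produced by the input pipeline:
#
#   enqueue_op = enqueue_tpu_embedding_sparse_tensor_batch(
#       sample_indices=[sp_a.indices, sp_b.indices],
#       embedding_indices=[sp_a.values, sp_b.values],
#       aggregation_weights=[[], []],  # empty: all weights default to 1
#       table_ids=[0, 1],
#       device_ordinal=0)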