# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Operations for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the sub-group
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` containing the concatenated data from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)


@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]
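

# A minimal usage sketch, not part of the original module: a hypothetical
# helper showing how `all_to_all` is typically called to exchange a split
# dimension for a concat dimension across all replicas in the default group.
# The helper name and the choice of dimensions are illustrative assumptions.
def _example_all_to_all_transpose(x, num_replicas):
  # Split `x` along dimension 0 into `num_replicas` pieces, send one piece to
  # each replica, and concatenate the received pieces along dimension 1.
  return all_to_all(
      x,
      concat_dimension=1,
      split_dimension=0,
      split_count=num_replicas)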


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)


def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear at most once in the source column, and at most
  once in the target column. For a replica id that does not appear in the
  target column, this op returns a zero tensor with the same shape and dtype
  as the input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: A 2d int list with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)


@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
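

# A minimal sketch (an assumption, not part of the original module) of the
# most common use of `cross_replica_sum`: turning a per-replica value, such
# as a local gradient, into a cross-replica mean. The helper name
# `_example_cross_replica_mean` is hypothetical and is never called here.
def _example_cross_replica_mean(local_value):
  # Sum `local_value` over every replica in the default group, then divide
  # by the number of shards to convert the sum into a mean.
  num_shards = tpu_function.get_tpu_context().number_of_shards or 1
  return cross_replica_sum(local_value) / num_shards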


# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are fed through infeed. Remove
# when both types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training "
        "mode. This is not expected. Consider putting your Optimizer.minimize "
        "code behind the training mode condition check. For Estimator, you "
        "can do \n\n"
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "    train_op = opt.minimize(loss)\n"
        "\n")

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we will return zeros(shape=[1]). The actual gradient
      # w.r.t. the embedding activations (grad_wrt_activations) has the same
      # shape as the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)


# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types
      of each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
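

# A usage sketch under assumed shapes and dtypes (not part of the original
# module): dequeue a (features, labels) pair that a matching host-side infeed
# enqueue is expected to supply. `_example_dequeue_features_and_labels` is a
# hypothetical helper for illustration only.
def _example_dequeue_features_and_labels():
  # The returned tensors come back in the same order as the dtypes and shapes
  # lists passed to infeed_dequeue_tuple.
  features, labels = infeed_dequeue_tuple(
      dtypes=[dtypes.float32, dtypes.int32],
      shapes=[[128, 28, 28, 1], [128]])
  return features, labels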


# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/
      optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as
      specified in the configuration. If the learning rates for all tables
      are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
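

# A minimal sketch, under the assumption that `grads` already holds one
# gradient tensor per embedding activation (in the same order as
# RecvTPUEmbeddingActivations) and that `config_proto_str` is a serialized
# TPUEmbeddingConfiguration. The helper name is hypothetical and unused
# elsewhere in this module.
def _example_send_embedding_gradients(grads, config_proto_str):
  # With no dynamic learning rate tags configured, learning_rates is left as
  # the default empty list.
  return send_tpu_embedding_gradients(inputs=grads, config=config_proto_str)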


# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e. per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
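

# A minimal sketch under assumed inputs (not part of the original module):
# enqueue one sparse batch for a single embedding table on device 0, relying
# on the default 'sum' combiner. The helper name and argument names are
# hypothetical.
def _example_enqueue_sparse_batch(sample_idx, ids, weights):
  # Each argument is a list with one rank-1 tensor per embedding table; a
  # single table is assumed here.
  return enqueue_tpu_embedding_sparse_batch(
      sample_indices=[sample_idx],
      embedding_indices=[ids],
      aggregation_weights=[weights],
      device_ordinal=0,
      mode_override="train")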


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      look up the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which equals that
      of sample_indices. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal
      length. If None, then we assume a list of all zeroes.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
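

# A minimal sketch under assumed inputs (illustrative only, not part of the
# original module): enqueue two features whose sp_ids/sp_weights components
# are looked up in tables 0 and 1, with the second feature treated as a
# sequence feature of maximal length 10. All names are hypothetical.
def _example_enqueue_sparse_tensor_batch(sp_indices, sp_ids, sp_weights):
  # Each list argument carries one entry per feature, in the same order as
  # table_ids and max_sequence_lengths; both features share the same
  # components here for brevity.
  return enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=[sp_indices, sp_indices],
      embedding_indices=[sp_ids, sp_ids],
      aggregation_weights=[sp_weights, sp_weights],
      table_ids=[0, 1],
      device_ordinal=0,
      max_sequence_lengths=[0, 10],
      mode_override="train")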