# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the subgroup
      size (`group_assignment.get_shape()[1]`).
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` formed by concatenating data from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)


@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]
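
# A minimal usage sketch for `all_to_all` (illustrative only, not executed by
# this module). With 4 replicas, splitting each replica's [4, 8] input along
# dimension 0 and concatenating the received pieces along dimension 1 yields
# a [1, 32] tensor on every replica. `x` is assumed to be a [16, 8] tensor
# supplied by the surrounding program; `tpu.shard` splits it into a [4, 8]
# slice per replica.
#
#   from tensorflow.python.tpu import tpu
#
#   def _computation(x):  # x has shape [4, 8] on each replica.
#     return all_to_all(
#         x,
#         concat_dimension=1,
#         split_dimension=0,
#         split_count=4,
#         group_assignment=[[0, 1, 2, 3]])  # One group of all four replicas.
#
#   outputs = tpu.shard(_computation, inputs=[x], num_shards=4)
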
@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)


def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear at most once in the source column and at most
  once in the target column. For any replica id not present in the target
  column, this op returns a zero tensor with the same shape and dtype as the
  input x.

  For example, suppose there are 4 TPU replicas whose inputs are
  `[A, B, C, D]`. Passing source_target_pairs=`[[0, 1], [1, 2], [2, 3]]` gets
  the outputs `[0, A, B, C]`, where the leading `0` is the zero tensor
  returned on replica 0.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int lists with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)


@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
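
# A minimal usage sketch for `cross_replica_sum` (illustrative only):
# averaging per-replica gradients inside a replicated training step. `grads`
# and `num_replicas` are assumed to come from the surrounding training loop.
#
#   summed = [cross_replica_sum(g) for g in grads]
#   averaged = [g / num_replicas for g in summed]
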
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are fed via infeed. Remove when
# both types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
                                         table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training "
        "mode. This is not expected. Consider putting your Optimizer.minimize "
        "code behind the training mode condition check. For Estimator, you "
        "can do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id "
        "{} and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.stop_gradient around tensors where "
        "variable update is not required. Previous gradients were generated "
        "by the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we will return zeros(shape=[1]). The actual gradient
      # w.r.t. the embedding activations (grad_wrt_activations) has the same
      # shape as the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)


# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types
      of each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)


# pylint: enable=redefined-outer-name
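
# A minimal usage sketch for `infeed_dequeue_tuple` (illustrative only):
# inside a TPU computation, dequeue a (features, labels) pair that a
# host-side infeed enqueue is assumed to supply with exactly matching dtypes
# and shapes. `dtypes` here is the module imported at the top of this file.
#
#   def _tpu_step():
#     features, labels = infeed_dequeue_tuple(
#         dtypes=[dtypes.float32, dtypes.int32],
#         shapes=[[128, 28, 28, 1], [128]])
#     ...
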
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as
      specified in the configuration. If the learning rates for all tables
      are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)


# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
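
# A minimal usage sketch for `enqueue_tpu_embedding_integer_batch`
# (illustrative only): enqueue one 1D tensor of ids per embedding table from
# the host input pipeline. The id tensors below are made-up stand-ins for
# real dataset output, assuming `import tensorflow.compat.v1 as tf`.
#
#   ids_table0 = tf.constant([3, 1, 4, 1], dtype=tf.int32)
#   ids_table1 = tf.constant([5, 9, 2, 6], dtype=tf.int32)
#   enqueue_op = enqueue_tpu_embedding_integer_batch(
#       batch=[ids_table0, ids_table1],
#       device_ordinal=0,
#       mode_override="train")
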
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e., per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
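
# A minimal usage sketch for `enqueue_tpu_embedding_sparse_batch`
# (illustrative only): a batch of two training examples for a single table
# with one feature (nf = 1), so sample_indices[i] reduces to the example
# index b. Assumes `import tensorflow.compat.v1 as tf`.
#
#   sample_indices = tf.constant([0, 0, 1], dtype=tf.int32)
#   embedding_indices = tf.constant([12, 7, 3], dtype=tf.int32)
#   aggregation_weights = tf.constant([1.0, 0.5, 1.0], dtype=tf.float32)
#   enqueue_op = enqueue_tpu_embedding_sparse_batch(
#       sample_indices=[sample_indices],
#       embedding_indices=[embedding_indices],
#       aggregation_weights=[aggregation_weights],
#       device_ordinal=0,
#       combiners=["sum"])
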
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_indices. If an entry is 0, the corresponding feature is
      considered a non-sequence feature; if greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to that of
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
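
# A minimal usage sketch for `enqueue_tpu_embedding_sparse_tensor_batch`
# (illustrative only): wiring a tf.SparseTensor of ids through the op using
# the sp_ids / sp_weights correspondence described above. `sp_ids` and
# `sp_weights` are assumed SparseTensors with identical indices.
#
#   enqueue_op = enqueue_tpu_embedding_sparse_tensor_batch(
#       sample_indices=[sp_ids.indices],
#       embedding_indices=[sp_ids.values],
#       aggregation_weights=[sp_weights.values],
#       table_ids=[0],
#       device_ordinal=0)
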
# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is
      a RaggedTensor. Both int32 and int64 are allowed and will be converted
      to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(),
      when ids is a RaggedTensor. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_splits. If an entry is 0, the corresponding feature is
      considered a non-sequence feature; if greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to that
      of sample_splits. If non-empty, entries in this list must be at least
      1. For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used;
      otherwise, mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
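
# A minimal usage sketch for `enqueue_tpu_embedding_ragged_tensor_batch`
# (illustrative only): feeding a tf.RaggedTensor of ids via the
# ids.row_splits / ids.values correspondence described above. Assumes
# `import tensorflow.compat.v1 as tf`.
#
#   ragged_ids = tf.ragged.constant([[3, 1, 4], [1, 5]], dtype=tf.int64)
#   weights = tf.ones_like(ragged_ids.values, dtype=tf.float32)
#   enqueue_op = enqueue_tpu_embedding_ragged_tensor_batch(
#       sample_splits=[ragged_ids.row_splits],
#       embedding_indices=[ragged_ids.values],
#       aggregation_weights=[weights],
#       table_ids=[0],
#       device_ordinal=0)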