1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15 16"""Operations for TPUs.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25# pylint: disable=wildcard-import,unused-import 26from tensorflow.python.ops import gen_tpu_ops 27from tensorflow.python.ops.gen_tpu_ops import * 28# pylint: enable=wildcard-import,unused-import 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.python.tpu import tpu_function 31from tensorflow.python.util.tf_export import tf_export 32 33 34def _create_default_group_assignment(): 35 num_shards = tpu_function.get_tpu_context().number_of_shards 36 if num_shards is None: 37 logging.warning( 38 "cross_replica_sum should be used within a tpu_shard_context, but " 39 "got unset number_of_shards. Assuming 1.") 40 num_shards = 1 41 group_assignment = [list(range(num_shards))] 42 return group_assignment 43 44 45def all_to_all(x, 46 concat_dimension, 47 split_dimension, 48 split_count, 49 group_assignment=None, 50 name=None): 51 """Exchange data across TPU replicas. 52 53 Args: 54 x: The local tensor. 55 concat_dimension: The dimension number to concatenate. 56 split_dimension: The dimension number to split. 57 split_count: The number of splits, this number must equal to the sub-group 58 size(group_assignment.get_shape()[1]) 59 group_assignment: Optional 2d int32 lists with shape [num_groups, 60 num_replicas_per_group]. `group_assignment[i]` represents the replica 61 ids in the ith subgroup. 62 name: Optional op name. 63 64 Returns: 65 A `Tensor` which is concatenated by data from different replicas. 66 """ 67 if group_assignment is None: 68 group_assignment = _create_default_group_assignment() 69 return gen_tpu_ops.all_to_all( 70 x, 71 group_assignment, 72 concat_dimension=concat_dimension, 73 split_dimension=split_dimension, 74 split_count=split_count, 75 name=name) 76 77 78@ops.RegisterGradient("AllToAll") 79def _all_to_all_grad(op, grad): 80 # The gradient of a all-to-all is also a all-to-all but the 81 # split_dimension and concat_dimension is swapped. 82 # The gradient with respect to group_assignment is None. 83 return [ 84 gen_tpu_ops.all_to_all( 85 grad, 86 op.inputs[1], 87 concat_dimension=op.get_attr("split_dimension"), 88 split_dimension=op.get_attr("concat_dimension"), 89 split_count=op.get_attr("split_count")), None 90 ] 91 92 93@tf_export(v1=["tpu.cross_replica_sum"]) 94def cross_replica_sum(x, group_assignment=None, name=None): 95 """Sum the input tensor across replicas according to group_assignment. 96 97 Args: 98 x: The local tensor to the sum. 99 group_assignment: Optional 2d int32 lists with shape [num_groups, 100 num_replicas_per_group]. `group_assignment[i]` represents the replica 101 ids in the ith subgroup. 102 name: Optional op name. 103 104 Returns: 105 A `Tensor` which is summed across replicas. 106 """ 107 if group_assignment is None: 108 group_assignment = _create_default_group_assignment() 109 110 return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) 111 112 113def collective_permute(x, source_target_pairs, name=None): 114 """Permute the input tensor across replicas given source_target_pairs. 115 116 For each source_target_pair <a, b>, we send replica a's input to replica b. 117 Each replica id must only appear once in the source column. Also it must 118 only appear once in the target column. 119 For the replica id not in the target column, this op returns a zero tensor 120 with the same shape and dtype of the input x. 121 122 For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing 123 source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: 124 `[0, A, B, C]`. 125 126 Args: 127 x: The local tensor to be permuted. 128 source_target_pairs: 2d int lists with shape [num_pairs, 2]. 129 source_target_pairs[i][0] represents the source replica id and 130 source_target_pairs[i][1] represents the target replica id. 131 name: Optional op name. 132 133 Returns: 134 A `Tensor` which is permuted. 135 """ 136 return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) 137 138 139@ops.RegisterGradient("CollectivePermute") 140def _collective_permute_grad(op, grad): 141 # The gradient of a collective permute operation is also a collective 142 # permute, but with source/target pairs reversed. The gradient with respect 143 # to input argument `source_target_pairs` is `None`. 144 source_target_pairs = op.inputs[1][:, ::-1] 145 return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] 146 147 148@ops.RegisterGradient("CrossReplicaSum") 149def _cross_replica_sum_grad(op, grad): 150 # The gradient of a cross replica sum is also a cross-replica sum. 151 # The gradient with respect to group_assignment is None. 152 return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] 153 154 155# This extra type checking exists to give a more helpful error message in 156# the common case that uint8 and int64 values are infed. Remove when both 157# types are supported. 158 159_SUPPORTED_INFEED_DTYPES = set([ 160 dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, 161 dtypes.complex64, dtypes.uint32 162]) 163 164 165@ops.RegisterGradient("TPUEmbeddingActivations") 166def _embedding_activations_grad(activations_op, grad_wrt_activations): 167 """Saves the gradient of embedding activations ops in a graph collection.""" 168 g = ops.get_default_graph() 169 table_id = activations_op.get_attr("table_id") 170 lookup_id = activations_op.get_attr("lookup_id") 171 table_gradients = g.get_collection_ref( 172 "tpu_embedding_gradients_table_%d" % table_id) 173 174 if not table_gradients: 175 raise RuntimeError( 176 "Gradients for TPUEmbedding have been generated in non-training mode." 177 "This is not expected. Consider putting your Optimizer.minimize code " 178 "behind the training mode condition check. For Estimator, you can " 179 "do \n\n" 180 " if mode == tf.estimator.ModeKeys.TRAIN:\n" 181 " train_op = opt.minimize(loss)\n" 182 "\n") 183 184 table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) 185 return [ 186 # RegisterGradient requires that value be returned for all inputs. Since 187 # the first argument (tpu_gradient_variable_{table_name}) has shape [1], 188 # we will return zeros(shape=[1]). The actual gradient w.r.t. the 189 # embedding activations (grad_wrt_activations) has the same shape as the 190 # activations returned by embedding_activations. 191 array_ops.zeros(arg.shape, dtype=dtypes.float32) 192 for arg in activations_op.inputs 193 ] 194 195 196def infeed_dequeue(dtype, shape, name=None): 197 """A placeholder op for a value that will be fed into the computation. 198 199 Args: 200 dtype: A `tf.DType`. The type of elements in the tensor. 201 shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. 202 name: A name for the operation (optional). 203 204 Returns: 205 A `Tensor` of type `dtype`. 206 A tensor that will be provided using the infeed mechanism. 207 208 Raises: 209 TypeError: If 'dtype` is not a supported infeed type. 210 """ 211 if dtype not in _SUPPORTED_INFEED_DTYPES: 212 raise TypeError( 213 "Operation '{}' has type {} which is not a supported TPU infeed type. " 214 "Supported types are: {}".format(name, dtype, 215 list(_SUPPORTED_INFEED_DTYPES))) 216 217 return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) 218 219 220# pylint: disable=redefined-outer-name 221def infeed_dequeue_tuple(dtypes, shapes, name=None): 222 """A placeholder op for values fed into the TPU simultaneously as a tuple. 223 224 Args: 225 dtypes: A list of `tf.DType`s that has length `>= 1`. 226 The element types of each element in `outputs`. 227 shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). 228 The shapes of each tensor in `outputs`. 229 name: A name for the operation (optional). 230 231 Returns: 232 A list of `Tensor` objects of type `dtypes`. 233 A list of tensors that will be provided using the infeed mechanism. 234 235 Raises: 236 TypeError: If a type in 'dtypes` is not a supported infeed type. 237 """ 238 for dtype in dtypes: 239 if dtype not in _SUPPORTED_INFEED_DTYPES: 240 raise TypeError( 241 "{} is not a supported TPU infeed type. Supported types are: " 242 "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) 243 return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) 244# pylint: enable=redefined-outer-name 245 246 247# pylint: disable=protected-access 248def send_tpu_embedding_gradients(inputs, 249 config, 250 learning_rates=None, 251 name=None): 252 """A placeholder op for feeding per-sample gradients to the embedding layer. 253 254 Args: 255 inputs: A TensorList of gradients with which to update embedding tables. 256 This argument has the same length and shapes as the return value of 257 RecvTPUEmbeddingActivations, but contains gradients of the model's 258 loss with respect to the embedding activations. The embedding tables 259 are updated from these gradients via the optimizers specified in the 260 TPU embedding configuration given to tpu.initialize_system. 261 config: Serialized TPUEmbeddingConfiguration proto. 262 learning_rates: A TensorList of float32 scalars, one for each dynamic 263 learning rate tag: see the comments in 264 //third_party/tensorflow/core/protobuf/tpu/ 265 optimization_parameters.proto. 266 Multiple tables can share the same dynamic learning rate tag as 267 specified in the configuration. If the learning rates for all tables 268 are constant, this list should be empty. 269 name: A name for the operation (optional). 270 271 Returns: 272 A SendTPUEmbeddingGradients operation. 273 """ 274 if learning_rates is None: 275 learning_rates = [] 276 return gen_tpu_ops.send_tpu_embedding_gradients( 277 inputs=inputs, learning_rates=learning_rates, config=config, name=name) 278 279 280send_tpu_embedding_gradients.__doc__ = ( 281 gen_tpu_ops.send_tpu_embedding_gradients.__doc__) 282 283 284# pylint: disable=protected-access 285def enqueue_tpu_embedding_integer_batch(batch, 286 device_ordinal, 287 mode_override=None, 288 name=None): 289 """A placeholder op for enqueueing embedding IDs to the TPU. 290 291 Args: 292 batch: A list of 1D tensors, one for each embedding table, containing the 293 indices into the tables. 294 device_ordinal: The TPU device to use. Should be >= 0 and less than the 295 number of TPU cores in the task on which the node is placed. 296 mode_override: A string input that overrides the mode specified in the 297 TPUEmbeddingConfiguration. Supported values are {'unspecified', 298 'inference', 'train', 'backward_pass_only'}. When set to 'unspecified', 299 the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override 300 is used (optional). 301 name: A name for the operation (optional). 302 303 Returns: 304 An EnqueueTPUEmbeddingIntegerBatch operation. 305 """ 306 if mode_override is None: 307 mode_override = "unspecified" 308 return gen_tpu_ops.enqueue_tpu_embedding_integer_batch( 309 batch=batch, 310 device_ordinal=device_ordinal, 311 mode_override=mode_override, 312 name=name) 313 314 315enqueue_tpu_embedding_integer_batch.__doc__ = ( 316 gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__) 317 318 319# pylint: disable=protected-access 320def enqueue_tpu_embedding_sparse_batch(sample_indices, 321 embedding_indices, 322 aggregation_weights, 323 device_ordinal, 324 combiners=None, 325 mode_override=None, 326 name=None): 327 """A placeholder op for enqueueing embedding IDs to the TPU. 328 329 Args: 330 sample_indices: A list of rank 1 Tensors specifying the training example 331 and feature to which the corresponding embedding_indices and 332 aggregation_weights values belong. sample_indices[i] must equal b * nf + 333 f, where nf is the number of features from the corresponding table, f is 334 in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed, 335 and will be converted to int32 internally. 336 embedding_indices: A list of rank 1 Tensors, indices into the embedding 337 tables. Both int32 and int64 are allowed and will be converted to int32 338 internally. 339 aggregation_weights: A list of rank 1 Tensors containing per sample -- 340 i.e. per (training example, feature) -- aggregation weights. Both float32 341 and float64 are allowed and will be converted to float32 internally. 342 device_ordinal: The TPU device to use. Should be >= 0 and less than the 343 number of TPU cores in the task on which the node is placed. 344 combiners: A list of string scalars, one for each embedding table that 345 specify how to normalize the embedding activations after weighted 346 summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is 347 invalid to have the sum of the weights be 0 for 'mean' or the sum of the 348 squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default 349 is to use 'sum' for all tables (optional). 350 mode_override: A string input that overrides the mode specified in the 351 TPUEmbeddingConfiguration. Supported values are {'unspecified', 352 'inference', 'train', 'backward_pass_only'}. When set to 'unspecified', 353 the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override 354 is used (optional). 355 name: A name for the operation (optional). 356 357 Returns: 358 An EnqueueTPUEmbeddingSparseBatch operation. 359 """ 360 if mode_override is None: 361 mode_override = "unspecified" 362 return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch( 363 sample_indices=sample_indices, 364 embedding_indices=embedding_indices, 365 aggregation_weights=aggregation_weights, 366 device_ordinal=device_ordinal, 367 combiners=combiners, 368 mode_override=mode_override, 369 name=name) 370 371 372enqueue_tpu_embedding_sparse_batch.__doc__ = ( 373 gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__) 374 375 376# pylint: disable=protected-access 377def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, 378 embedding_indices, 379 aggregation_weights, 380 table_ids, 381 device_ordinal, 382 max_sequence_lengths=None, 383 num_features=None, 384 combiners=None, 385 mode_override=None, 386 name=None): 387 """A placeholder op for enqueueing embedding IDs to the TPU. 388 389 Args: 390 sample_indices: A list of rank 2 Tensors specifying the training example 391 to which the corresponding embedding_indices and aggregation_weights 392 values belong. It corresponds to sp_ids.indices in 393 embedding_lookup_sparse(). If the size of its first dimension is 0, we 394 assume each embedding_indices belongs to a different sample. Both int32 395 and int64 are allowed and will be converted to int32 internally. 396 embedding_indices: A list of rank 1 Tensors, indices into the embedding 397 tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both 398 int32 and int64 are allowed and will be converted to int32 internally. 399 aggregation_weights: A list of rank 1 Tensors containing per training 400 example aggregation weights. It corresponds to sp_weights.values in 401 embedding_lookup_sparse(). If the size of its first dimension is 0, we 402 assume all weights are 1. Both float32 and float64 are allowed and will 403 be converted to float32 internally. 404 table_ids: A list of integers specifying the identifier of the embedding 405 table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to 406 lookup the corresponding input. The ith input is looked up using 407 table_ids[i]. The size of the table_ids list must be equal to that of 408 sample_indices, embedding_indices and aggregation_weights. 409 device_ordinal: The TPU device to use. Should be >= 0 and less than the 410 number of TPU cores in the task on which the node is placed. 411 max_sequence_lengths: A list of integers, the size of which is equal to 412 sample_indices. If equal to 0, the corresponding feature is considered to 413 be a non-sequence feature, If greater than 0, the corresponding feature is 414 a sequence feature with the given maximal length. If None, then we assume 415 a list of all zeroes. 416 num_features: A list of integers, the size of which is equal to 417 sample_indices. If non-empty, entries in this list must be at least 1. 418 For each batch element, we will take num_features rows of the input 419 tensor for embedding lookup. E.g., when sample_indices is empty, 420 the embedding indices must be of shape (batch_size*num_features). 421 combiners: A list of string scalars, one for each embedding table that 422 specify how to normalize the embedding activations after weighted 423 summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is 424 invalid to have the sum of the weights be 0 for 'mean' or the sum of the 425 squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default 426 is to use 'sum' for all tables (optional). 427 mode_override: A string input that overrides the mode specified in the 428 TPUEmbeddingConfiguration. Supported values are {'unspecified', 429 'inference', 'train', 'backward_pass_only'}. When set to 'unspecified', 430 the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override 431 is used (optional). 432 name: A name for the operation (optional). 433 434 Returns: 435 An EnqueueTPUEmbeddingSparseTensorBatch operation. 436 """ 437 if mode_override is None: 438 mode_override = "unspecified" 439 return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( 440 sample_indices=sample_indices, 441 embedding_indices=embedding_indices, 442 aggregation_weights=aggregation_weights, 443 table_ids=table_ids, 444 device_ordinal=device_ordinal, 445 max_sequence_lengths=max_sequence_lengths, 446 combiners=combiners, 447 mode_override=mode_override, 448 num_features=num_features, 449 name=name) 450 451 452enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( 453 gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__) 454 455 456# pylint: disable=protected-access 457def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits, 458 embedding_indices, 459 aggregation_weights, 460 table_ids, 461 device_ordinal, 462 max_sequence_lengths=None, 463 num_features=None, 464 combiners=None, 465 mode_override=None, 466 name=None): 467 """A placeholder op for enqueueing embedding IDs to the TPU. 468 469 Args: 470 sample_splits: A list of rank 1 Tensors specifying the break points for 471 splitting embedding_indices and aggregation_weights into rows. It 472 corresponds to ids.row_splits in embedding_lookup(), when ids is a 473 RaggedTensor. Both int32 and int64 are allowed and will be converted to 474 int32 internally. 475 embedding_indices: A list of rank 1 Tensors, indices into the embedding 476 tables. It corresponds to ids.values in embedding_lookup(), when ids is a 477 RaggedTensor. Both int32 and int64 are allowed and will be converted to 478 int32 internally. 479 aggregation_weights: A list of rank 1 Tensors containing per training 480 example aggregation weights. It corresponds to the values field of a 481 RaggedTensor with the same row_splits as ids in embedding_lookup(), when 482 ids is a RaggedTensor. Both float32 and float64 are allowed and will be 483 converted to float32 internally. 484 table_ids: A list of integers specifying the identifier of the embedding 485 table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to 486 lookup the corresponding input. The ith input is looked up using 487 table_ids[i]. The size of the table_ids list must be equal to that of 488 sample_indices, embedding_indices and aggregation_weights. 489 device_ordinal: The TPU device to use. Should be >= 0 and less than the 490 number of TPU cores in the task on which the node is placed. 491 max_sequence_lengths: A list of integers, the size of which is equal to 492 sample_indices. If equal to 0, the corresponding feature is considered to 493 be a non-sequence feature, If greater than 0, the corresponding feature is 494 a sequence feature with the given maximal length. If None, then we assume 495 a list of all zeroes. 496 num_features: A list of integers, the size of which must be equal to 497 sample_indices. If non-empty, entries in this list must be at least 1. 498 For each batch element, we will take num_features rows of the input 499 tensor for embedding lookup. E.g., when sample_indices is empty, 500 the embedding indices must be of shape (batch_size*num_features). 501 combiners: A list of string scalars, one for each embedding table that 502 specify how to normalize the embedding activations after weighted 503 summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is 504 invalid to have the sum of the weights be 0 for 'mean' or the sum of the 505 squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default 506 is to use 'sum' for all tables (optional). 507 mode_override: A string input that overrides the mode specified in the 508 TPUEmbeddingConfiguration. Supported values are {'unspecified', 509 'inference', 'training', 'backward_pass_only'}. When set to 'unspecified', 510 the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override 511 is used (optional). 512 name: A name for the operation (optional). 513 514 Returns: 515 An EnqueueTPUEmbeddingRaggedTensorBatch operation. 516 """ 517 if mode_override is None: 518 mode_override = "unspecified" 519 return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( 520 sample_splits=sample_splits, 521 embedding_indices=embedding_indices, 522 aggregation_weights=aggregation_weights, 523 table_ids=table_ids, 524 device_ordinal=device_ordinal, 525 max_sequence_lengths=max_sequence_lengths, 526 combiners=combiners, 527 mode_override=mode_override, 528 num_features=num_features, 529 name=name) 530 531 532enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = ( 533 gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__) 534