1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mid level API for TPU Embeddings.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20from __future__ import unicode_literals 21 22import functools 23from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union 24 25from absl import logging 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 29from tensorflow.python.distribute import device_util 30from tensorflow.python.distribute import distribute_utils 31from tensorflow.python.distribute import distribution_strategy_context 32from tensorflow.python.distribute import sharded_variable 33from tensorflow.python.distribute import tpu_strategy 34from tensorflow.python.eager import context 35from tensorflow.python.eager import def_function 36from tensorflow.python.framework import constant_op 37from tensorflow.python.framework import device as tf_device 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import sparse_tensor 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import embedding_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import sparse_ops 45from tensorflow.python.ops import variable_scope 46from tensorflow.python.ops import variables as tf_variables 47from tensorflow.python.ops.ragged import ragged_tensor 48from tensorflow.python.saved_model import save_context 49from tensorflow.python.tpu import tpu 50from tensorflow.python.tpu import tpu_embedding_v2_utils 51from tensorflow.python.tpu.ops import tpu_ops 52from tensorflow.python.training.saving import saveable_hook 53from tensorflow.python.training.tracking import base 54from tensorflow.python.training.tracking import tracking 55from tensorflow.python.types import core 56from tensorflow.python.types import internal as internal_types 57from tensorflow.python.util import compat 58from tensorflow.python.util import nest 59from tensorflow.python.util import tf_inspect 60from tensorflow.python.util.tf_export import tf_export 61 62 63_HOOK_KEY = "TPUEmbedding_saveable" 64_NAME_KEY = "_tpu_embedding_layer" 65 66 67# TODO(bfontain): Cleanup and remove this once there is an implementation of 68# sharded variables that can be used in the PSStrategy with optimizers. 69# We implement just enough of the of a tf.Variable so that this could be passed 70# to an optimizer. 71class TPUShardedVariable(sharded_variable.ShardedVariableMixin): 72 """A ShardedVariable class for TPU.""" 73 74 @property 75 def _in_graph_mode(self): 76 return self.variables[0]._in_graph_mode # pylint: disable=protected-access 77 78 @property 79 def _unique_id(self): 80 return self.variables[0]._unique_id # pylint: disable=protected-access 81 82 @property 83 def _distribute_strategy(self): 84 return self.variables[0]._distribute_strategy # pylint: disable=protected-access 85 86 @property 87 def _shared_name(self): 88 return self._name 89 90 91def _add_key_attr(op, name): 92 op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name))) # pylint: disable=protected-access 93 94 95@tf_export("tpu.experimental.embedding.TPUEmbedding") 96class TPUEmbedding(tracking.AutoTrackable): 97 """The TPUEmbedding mid level API. 98 99 NOTE: When instantiated under a TPUStrategy, this class can only be created 100 once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to 101 re-initialize the embedding engine you must re-initialize the tpu as well. 102 Doing this will clear any variables from TPU, so ensure you have checkpointed 103 before you do this. If a further instances of the class are needed, 104 set the `initialize_tpu_embedding` argument to `False`. 105 106 This class can be used to support training large embeddings on TPU. When 107 creating an instance of this class, you must specify the complete set of 108 tables and features you expect to lookup in those tables. See the 109 documentation of `tf.tpu.experimental.embedding.TableConfig` and 110 `tf.tpu.experimental.embedding.FeatureConfig` for more details on the complete 111 set of options. We will cover the basic usage here. 112 113 NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object, 114 allowing different features to share the same table: 115 116 ```python 117 table_config_one = tf.tpu.experimental.embedding.TableConfig( 118 vocabulary_size=..., 119 dim=...) 120 table_config_two = tf.tpu.experimental.embedding.TableConfig( 121 vocabulary_size=..., 122 dim=...) 123 feature_config = { 124 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( 125 table=table_config_one), 126 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( 127 table=table_config_one), 128 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( 129 table=table_config_two)} 130 ``` 131 132 There are two modes under which the `TPUEmbedding` class can used. This 133 depends on if the class was created under a `TPUStrategy` scope or not. 134 135 Under `TPUStrategy`, we allow access to the method `enqueue`, `dequeue` and 136 `apply_gradients`. We will show examples below of how to use these to train 137 and evaluate your model. Under CPU, we only access to the `embedding_tables` 138 property which allow access to the embedding tables so that you can use them 139 to run model evaluation/prediction on CPU. 140 141 First lets look at the `TPUStrategy` mode. Initial setup looks like: 142 143 ```python 144 strategy = tf.distribute.TPUStrategy(...) 145 with strategy.scope(): 146 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 147 feature_config=feature_config, 148 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 149 ``` 150 151 When creating a distributed dataset that is to be passed to the enqueue 152 operation a special input option must be specified: 153 154 ```python 155 distributed_dataset = ( 156 strategy.distribute_datasets_from_function( 157 dataset_fn=..., 158 options=tf.distribute.InputOptions( 159 experimental_fetch_to_device=False)) 160 dataset_iterator = iter(distributed_dataset) 161 ``` 162 163 NOTE: All batches passed to the layer must have the same batch size for each 164 input, more over once you have called the layer with one batch size all 165 subsequent calls must use the same batch_size. In the event that the batch 166 size cannot be automatically determined by the enqueue method, you must call 167 the build method with the batch size to initialize the layer. 168 169 To use this API on TPU you should use a custom training loop. Below is an 170 example of a training and evaluation step: 171 172 ```python 173 @tf.function 174 def training_step(dataset_iterator, num_steps): 175 def tpu_step(tpu_features): 176 with tf.GradientTape() as tape: 177 activations = embedding.dequeue() 178 tape.watch(activations) 179 model_output = model(activations) 180 loss = ... # some function of labels and model_output 181 182 embedding_gradients = tape.gradient(loss, activations) 183 embedding.apply_gradients(embedding_gradients) 184 # Insert your model gradient and optimizer application here 185 186 for _ in tf.range(num_steps): 187 embedding_features, tpu_features = next(dataset_iterator) 188 embedding.enqueue(embedding_features, training=True) 189 strategy.run(tpu_step, args=(embedding_features, )) 190 191 @tf.function 192 def evalution_step(dataset_iterator, num_steps): 193 def tpu_step(tpu_features): 194 activations = embedding.dequeue() 195 model_output = model(activations) 196 # Insert your evaluation code here. 197 198 for _ in tf.range(num_steps): 199 embedding_features, tpu_features = next(dataset_iterator) 200 embedding.enqueue(embedding_features, training=False) 201 strategy.run(tpu_step, args=(embedding_features, )) 202 ``` 203 204 NOTE: The calls to `enqueue` have `training` set to `True` when 205 `embedding.apply_gradients` is used and set to `False` when 206 `embedding.apply_gradients` is not present in the function. If you don't 207 follow this pattern you may cause an error to be raised or the tpu may 208 deadlock. 209 210 In the above examples, we assume that the user has a dataset which returns 211 a tuple where the first element of the tuple matches the structure of what 212 was passed as the `feature_config` argument to the object initializer. Also we 213 utilize `tf.range` to get a `tf.while_loop` in order to increase performance. 214 215 When checkpointing your model, you should include your 216 `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is a 217 trackable object and saving it will save the embedding tables and their 218 optimizer slot variables: 219 220 ```python 221 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 222 checkpoint.save(...) 223 ``` 224 225 On CPU, only the `embedding_table` property is usable. This will allow you to 226 restore a checkpoint to the object and have access to the table variables: 227 228 ```python 229 model = model_fn(...) 230 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 231 feature_config=feature_config, 232 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 233 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 234 checkpoint.restore(...) 235 236 tables = embedding.embedding_tables 237 ``` 238 239 You can now use table in functions like `tf.nn.embedding_lookup` to perform 240 your embedding lookup and pass to your model. 241 242 """ 243 244 def __init__( 245 self, 246 feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic 247 optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access 248 pipeline_execution_with_tensor_core: bool = False): 249 """Creates the TPUEmbedding mid level API object. 250 251 ```python 252 strategy = tf.distribute.TPUStrategy(...) 253 with strategy.scope(): 254 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 255 feature_config=tf.tpu.experimental.embedding.FeatureConfig( 256 table=tf.tpu.experimental.embedding.TableConfig( 257 dim=..., 258 vocabulary_size=...))) 259 ``` 260 261 Args: 262 feature_config: A nested structure of 263 `tf.tpu.experimental.embedding.FeatureConfig` configs. 264 optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`, 265 `tf.tpu.experimental.embedding.Adagrad` or 266 `tf.tpu.experimental.embedding.Adam`. When not created under 267 TPUStrategy may be set to None to avoid the creation of the optimizer 268 slot variables, useful for optimizing memory consumption when exporting 269 the model for serving where slot variables aren't needed. 270 pipeline_execution_with_tensor_core: If True, the TPU embedding 271 computations will overlap with the TensorCore computations (and hence 272 will be one step old). Set to True for improved performance. 273 274 Raises: 275 ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD, 276 Adam or Adagrad) or None when created under a TPUStrategy. 277 """ 278 self._strategy = distribution_strategy_context.get_strategy() 279 self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy, 280 tpu_strategy.TPUStrategyV2)) 281 self._pipeline_execution_with_tensor_core = ( 282 pipeline_execution_with_tensor_core) 283 284 self._feature_config = feature_config 285 286 # The TPU embedding ops are slightly inconsistent with how they refer to 287 # tables: 288 # * The enqueue op takes a parallel list of tensors for input, one of those 289 # is the table id for the feature which matches the integer index of the 290 # table in the proto created by _create_config_proto(). 291 # * The recv_tpu_embedding_activations op emits lookups per table in the 292 # order from the config proto. 293 # * The send_tpu_embedding_gradients expects input tensors to be per table 294 # in the same order as the config proto. 295 # * Per optimizer load and retrieve ops are specified per table and take the 296 # table name rather than the table id. 297 # Thus we must fix a common order to tables and ensure they have unique 298 # names. 299 300 # Set table order here to the order of the first occurence of the table in a 301 # feature provided by the user. The order of this struct must be fixed 302 # to provide the user with deterministic behavior over multiple 303 # instantiations. 304 self._table_config = [] 305 for feature in nest.flatten(feature_config): 306 if feature.table not in self._table_config: 307 self._table_config.append(feature.table) 308 309 # Ensure tables have unique names. Also error check the optimizer as we 310 # specifically don't do that in the TableConfig class to allow high level 311 # APIs that are built on this to use strings/other classes to represent 312 # optimizers (before they are passed to this class). 313 table_names = [] 314 for i, table in enumerate(self._table_config): 315 if table.optimizer is None: 316 # TODO(bfontain) Should we allow some sort of optimizer merging here? 317 table.optimizer = optimizer 318 if ((table.optimizer is not None or self._using_tpu) and 319 not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)): # pylint: disable=protected-access 320 raise ValueError("{} is an unsupported optimizer class. Please pass an " 321 "instance of one of the optimizer classes under " 322 "tf.tpu.experimental.embedding.".format( 323 type(table.optimizer))) 324 if table.name is None: 325 table.name = "table_{}".format(i) 326 if table.name in table_names: 327 raise ValueError("Tables must have a unique name. " 328 f"Multiple tables with name {table.name} found.") 329 table_names.append(table.name) 330 331 if self._using_tpu: 332 # Extract a list of callable learning rates also in fixed order. Each 333 # table in the confix proto will get a index into this list and we will 334 # pass this list in the same order after evaluation to the 335 # send_tpu_embedding_gradients op. 336 self._dynamic_learning_rates = list({ 337 table.optimizer.learning_rate for table in self._table_config if 338 callable(table.optimizer.learning_rate)}) 339 340 # We need to list of host devices for the load/retrieve operations. 341 self._hosts = get_list_of_hosts(self._strategy) 342 343 self._built = False 344 self._verify_batch_size_on_enqueue = True 345 346 def build(self, per_replica_batch_size: Optional[int] = None): 347 """Create the underlying variables and initializes the TPU for embeddings. 348 349 This method creates the underlying variables (including slot variables). If 350 created under a TPUStrategy, this will also initialize the TPU for 351 embeddings. 352 353 This function will automatically get called by enqueue, which will try to 354 determine your batch size automatically. If this fails, you must manually 355 call this method before you call enqueue. 356 357 Args: 358 per_replica_batch_size: The per replica batch size that you intend to use. 359 Note that is fixed and the same batch size must be used for both 360 training and evaluation. If you want to calculate this from the global 361 batch size, you can use `num_replicas_in_sync` property of your strategy 362 object. May be set to None if not created under a TPUStrategy. 363 364 Raises: 365 ValueError: If per_replica_batch_size is None and object was created in a 366 TPUStrategy scope. 367 """ 368 if self._built: 369 return 370 371 if self._using_tpu: 372 if per_replica_batch_size is None: 373 raise ValueError( 374 "When calling TpuShardedVariable.build under TpuStrategy you must " 375 "specify a per_replica_batch_size argument.") 376 377 self._batch_size = per_replica_batch_size 378 379 self._config_proto = self._create_config_proto() 380 381 logging.info("Initializing TPU Embedding engine.") 382 tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto) 383 384 @def_function.function 385 def load_config(): 386 tpu.initialize_system_for_tpu_embedding(self._config_proto) 387 388 load_config() 389 logging.info("Done initializing TPU Embedding engine.") 390 391 # Create and load variables and slot variables into the TPU. 392 # Note that this is a dict of dicts. Keys to the first dict are table names. 393 # We would prefer to use TableConfigs, but then these variables won't be 394 # properly tracked by the tracking API. 395 self._variables = self._create_variables_and_slots() 396 397 self._built = True 398 399 # This is internally conditioned self._built and self._using_tpu 400 self._load_variables() 401 402 def _maybe_build(self, batch_size: Optional[int]): 403 if not self._built: 404 # This can be called while tracing a function, so we wrap the 405 # initialization code with init_scope so it runs eagerly, this means that 406 # it will not be included the function graph generated by tracing so that 407 # we can be sure that we only initialize the TPU for embeddings exactly 408 # once. 409 with ops.init_scope(): 410 self.build(batch_size) 411 412 @property 413 def embedding_tables( 414 self 415 ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]: 416 """Returns a dict of embedding tables, keyed by `TableConfig`. 417 418 This property only works when the `TPUEmbedding` object is created under a 419 non-TPU strategy. This is intended to be used to for CPU based lookup when 420 creating a serving checkpoint. 421 422 Returns: 423 A dict of embedding tables, keyed by `TableConfig`. 424 425 Raises: 426 RuntimeError: If object was created under a `TPUStrategy`. 427 """ 428 # We don't support returning tables on TPU due to their sharded nature and 429 # the fact that when using a TPUStrategy: 430 # 1. Variables are stale and are only updated when a checkpoint is made. 431 # 2. Updating the variables won't affect the actual tables on the TPU. 432 if self._using_tpu: 433 if save_context.in_save_context(): 434 return {table: self._variables[table.name]["parameters"].variables[0] 435 for table in self._table_config} 436 raise RuntimeError("Unable to retrieve embedding tables when using a TPU " 437 "strategy. If you need access, save your model, " 438 "create this object under a CPU strategy and restore.") 439 440 self._maybe_build(None) 441 442 # Only return the tables and not the slot variables. On CPU this are honest 443 # tf.Variables. 444 return {table: self._variables[table.name]["parameters"] 445 for table in self._table_config} 446 447 def _create_config_proto( 448 self 449 ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration: 450 """Creates the TPUEmbeddingConfiguration proto. 451 452 This proto is used to initialize the TPU embedding engine. 453 454 Returns: 455 A TPUEmbeddingConfiguration proto. 456 """ 457 458 config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() 459 460 # There are several things that need to be computed here: 461 # 1. Each table has a num_features, which corresponds to the number of 462 # output rows per example for this table. Sequence features count for 463 # their maximum sequence length. 464 # 2. Learning rate index: the index of the dynamic learning rate for this 465 # table (if it exists) in the list we created at initialization. 466 # We don't simply create one learning rate index per table as this has 467 # extremely bad performance characteristics. The more separate 468 # optimization configurations we have, the worse the performance will be. 469 num_features = {table: 0 for table in self._table_config} 470 for feature in nest.flatten(self._feature_config): 471 num_features[feature.table] += (1 if feature.max_sequence_length == 0 472 else feature.max_sequence_length) 473 474 # Map each callable dynamic learning rate to its in index in the list. 475 learning_rate_index = {r: i for i, r in enumerate( 476 self._dynamic_learning_rates)} 477 478 for table in self._table_config: 479 table_descriptor = config_proto.table_descriptor.add() 480 table_descriptor.name = table.name 481 482 # For small tables, we pad to the number of hosts so that at least one 483 # id will be assigned to each host. 484 table_descriptor.vocabulary_size = max(table.vocabulary_size, 485 self._strategy.extended.num_hosts) 486 table_descriptor.dimension = table.dim 487 488 table_descriptor.num_features = num_features[table] 489 490 parameters = table_descriptor.optimization_parameters 491 492 # We handle the learning rate separately here and don't allow the 493 # optimization class to handle this, as it doesn't know about dynamic 494 # rates. 495 if callable(table.optimizer.learning_rate): 496 parameters.learning_rate.dynamic.tag = ( 497 learning_rate_index[table.optimizer.learning_rate]) 498 else: 499 parameters.learning_rate.constant = table.optimizer.learning_rate 500 501 # Use optimizer to handle the rest of the parameters. 502 table.optimizer._set_optimization_parameters(parameters) # pylint: disable=protected-access 503 504 # Always set mode to training, we override the mode during enqueue. 505 config_proto.mode = ( 506 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING) 507 508 config_proto.batch_size_per_tensor_core = self._batch_size 509 config_proto.num_hosts = self._strategy.extended.num_hosts 510 config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync 511 512 # TODO(bfontain): Allow users to pick MOD for the host sharding. 513 config_proto.sharding_strategy = ( 514 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT) 515 config_proto.pipeline_execution_with_tensor_core = ( 516 self._pipeline_execution_with_tensor_core) 517 518 return config_proto 519 520 def _compute_per_table_gradients( 521 self, 522 gradients 523 ) -> Dict[Text, List[core.Tensor]]: 524 """Computes a dict of lists of gradients, keyed by table name. 525 526 Args: 527 gradients: A nested structure of Tensors (and Nones) with the same 528 structure as the feature config. 529 530 Returns: 531 A dict of lists of tensors, keyed by the table names, containing the 532 gradients in the correct order with None gradients replaced by zeros. 533 """ 534 535 nest.assert_same_structure(self._feature_config, gradients) 536 537 per_table_gradients = {table: [] for table in self._table_config} 538 for (path, gradient), feature in zip( 539 nest.flatten_with_joined_string_paths(gradients), 540 nest.flatten(self._feature_config)): 541 if gradient is not None and not isinstance(gradient, ops.Tensor): 542 raise ValueError( 543 f"When computing per-table gradients, found non-tensor type: " 544 f"{type(gradient)} at path {path}.") 545 546 # Expected tensor shape differs for sequence and non-sequence features. 547 if feature.max_sequence_length > 0: 548 shape = [self._batch_size, feature.max_sequence_length, 549 feature.table.dim] 550 else: 551 shape = [self._batch_size, feature.table.dim] 552 553 if gradient is not None: 554 if gradient.shape != shape: 555 raise ValueError("Found gradient of shape {} at path {}. Expected " 556 "shape {}.".format(gradient.shape, path, shape)) 557 558 # We expand dims on non-sequence features so that all features are 559 # of rank 3 and we can concat on axis=1. 560 if len(shape) == 2: 561 gradient = array_ops.expand_dims(gradient, axis=1) 562 else: 563 # No gradient for this feature, since we must give a gradient for all 564 # features, pass in a zero tensor here. Note that this is not correct 565 # for all optimizers. 566 logging.warn("No gradient passed for feature %s, sending zero " 567 "gradient. This may not be correct behavior for certain " 568 "optimizers like Adam.", path) 569 # Create a shape to mimic the expand_dims above for non-sequence 570 # features. 571 if len(shape) == 2: 572 shape = [shape[0], 1, shape[1]] 573 gradient = array_ops.zeros(shape, dtype=dtypes.float32) 574 per_table_gradients[feature.table].append(gradient) 575 576 return per_table_gradients 577 578 def apply_gradients(self, gradients, name: Optional[Text] = None): 579 """Applies the gradient update to the embedding tables. 580 581 If a gradient of `None` is passed in any position of the nested structure, 582 then an gradient update with a zero gradient is applied for that feature. 583 For optimizers like SGD or Adagrad, this is the same as applying no update 584 at all. For lazy Adam and other sparsely applied optimizers with decay, 585 ensure you understand the effect of applying a zero gradient. 586 587 ```python 588 strategy = tf.distribute.TPUStrategy(...) 589 with strategy.scope(): 590 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 591 592 distributed_dataset = ( 593 strategy.distribute_datasets_from_function( 594 dataset_fn=..., 595 options=tf.distribute.InputOptions( 596 experimental_fetch_to_device=False)) 597 dataset_iterator = iter(distributed_dataset) 598 599 @tf.function 600 def training_step(): 601 def tpu_step(tpu_features): 602 with tf.GradientTape() as tape: 603 activations = embedding.dequeue() 604 tape.watch(activations) 605 606 loss = ... # some computation involving activations 607 608 embedding_gradients = tape.gradient(loss, activations) 609 embedding.apply_gradients(embedding_gradients) 610 611 embedding_features, tpu_features = next(dataset_iterator) 612 embedding.enqueue(embedding_features, training=True) 613 strategy.run(tpu_step, args=(embedding_features, )) 614 615 training_step() 616 ``` 617 618 Args: 619 gradients: A nested structure of gradients, with structure matching the 620 `feature_config` passed to this object. 621 name: A name for the underlying op. 622 623 Raises: 624 RuntimeError: If called when object wasn't created under a `TPUStrategy` 625 or if not built (either by manually calling build or calling enqueue). 626 ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a 627 `tf.Tensor` of the incorrect shape is passed in. Also if 628 the size of any sequence in `gradients` does not match corresponding 629 sequence in `feature_config`. 630 TypeError: If the type of any sequence in `gradients` does not match 631 corresponding sequence in `feature_config`. 632 """ 633 if not self._using_tpu: 634 raise RuntimeError("apply_gradients is not valid when TPUEmbedding " 635 "object is not created under a TPUStrategy.") 636 637 if not self._built: 638 raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding " 639 "object. Please either call enqueue first or manually " 640 "call the build method.") 641 642 # send_tpu_embedding_gradients requires per table gradient, if we only have 643 # one feature per table this isn't an issue. When multiple features share 644 # the same table, the order of the features in per table tensor returned by 645 # recv_tpu_embedding_activations matches the order in which they were passed 646 # to enqueue. 647 # In all three places, we use the fixed order given by nest.flatten to have 648 # a consistent feature order. 649 650 # First construct a dict of tensors one for each table. 651 per_table_gradients = self._compute_per_table_gradients(gradients) 652 653 # Now that we have a list of gradients we can compute a list of gradients 654 # in the fixed order of self._table_config which interleave the gradients of 655 # the individual features. We concat on axis 1 and then reshape into a 2d 656 # tensor. The send gradients op expects a tensor of shape 657 # [num_features*batch_size, dim] for each table. 658 interleaved_gradients = [] 659 for table in self._table_config: 660 interleaved_gradients.append(array_ops.reshape( 661 array_ops.concat(per_table_gradients[table], axis=1), 662 [-1, table.dim])) 663 op = tpu_ops.send_tpu_embedding_gradients( 664 inputs=interleaved_gradients, 665 learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32) 666 for fn in self._dynamic_learning_rates], 667 config=self._config_proto.SerializeToString()) 668 669 # Apply the name tag to the op. 670 if name is not None: 671 _add_key_attr(op, name) 672 673 def dequeue(self, name: Optional[Text] = None): 674 """Get the embedding results. 675 676 Returns a nested structure of `tf.Tensor` objects, matching the structure of 677 the `feature_config` argument to the `TPUEmbedding` class. The output shape 678 of the tensors is `(batch_size, dim)`, where `batch_size` is the per core 679 batch size, `dim` is the dimension of the corresponding `TableConfig`. If 680 the feature's corresponding `FeatureConfig` has `max_sequence_length` 681 greater than 0, the output will be a sequence of shape 682 `(batch_size, max_sequence_length, dim)` instead. 683 684 ```python 685 strategy = tf.distribute.TPUStrategy(...) 686 with strategy.scope(): 687 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 688 689 distributed_dataset = ( 690 strategy.distribute_datasets_from_function( 691 dataset_fn=..., 692 options=tf.distribute.InputOptions( 693 experimental_fetch_to_device=False)) 694 dataset_iterator = iter(distributed_dataset) 695 696 @tf.function 697 def training_step(): 698 def tpu_step(tpu_features): 699 with tf.GradientTape() as tape: 700 activations = embedding.dequeue() 701 tape.watch(activations) 702 703 loss = ... # some computation involving activations 704 705 embedding_gradients = tape.gradient(loss, activations) 706 embedding.apply_gradients(embedding_gradients) 707 708 embedding_features, tpu_features = next(dataset_iterator) 709 embedding.enqueue(embedding_features, training=True) 710 strategy.run(tpu_step, args=(embedding_features, )) 711 712 training_step() 713 ``` 714 715 Args: 716 name: A name for the underlying op. 717 718 Returns: 719 A nested structure of tensors, with the same structure as `feature_config` 720 passed to this instance of the `TPUEmbedding` object. 721 722 Raises: 723 RuntimeError: If called when object wasn't created under a `TPUStrategy` 724 or if not built (either by manually calling build or calling enqueue). 725 """ 726 if not self._using_tpu: 727 raise RuntimeError("dequeue is not valid when TPUEmbedding object is not " 728 "created under a TPUStrategy.") 729 730 if not self._built: 731 raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. " 732 "Please either call enqueue first or manually call " 733 "the build method.") 734 735 # The activations returned by this op are per table. So we must separate 736 # them out into per feature activations. The activations are interleaved: 737 # for each table, we expect a [num_features*batch_size, dim] tensor. 738 # E.g. we expect the slice [:num_features, :] to contain the lookups for the 739 # first example of all features using this table. 740 activations = tpu_ops.recv_tpu_embedding_activations( 741 num_outputs=len(self._table_config), 742 config=self._config_proto.SerializeToString()) 743 744 # Apply the name tag to the op. 745 if name is not None: 746 _add_key_attr(activations[0].op, name) 747 748 # Compute the number of features for this table. 749 num_features = {table: 0 for table in self._table_config} 750 for feature in nest.flatten(self._feature_config): 751 num_features[feature.table] += (1 if feature.max_sequence_length == 0 752 else feature.max_sequence_length) 753 754 # Activations are reshaped so that they are indexed by batch size and then 755 # by the 'feature' index within the batch. The final dimension should equal 756 # the dimension of the table. 757 table_to_activation = { 758 table: array_ops.reshape(activation, 759 [self._batch_size, num_features[table], -1]) 760 for table, activation in zip(self._table_config, activations)} 761 762 # We process the features in the same order we enqueued them. 763 # For each feature we take the next slice of the activations, so need to 764 # track the activations and the current position we are in. 765 table_to_position = {table: 0 for table in self._table_config} 766 767 per_feature_activations = [] 768 for feature in nest.flatten(self._feature_config): 769 activation = table_to_activation[feature.table] 770 feature_index = table_to_position[feature.table] 771 # We treat non-sequence and sequence features differently here as sequence 772 # features have rank 3 while non-sequence features have rank 2. 773 if feature.max_sequence_length == 0: 774 per_feature_activations.append( 775 activation[:, feature_index, :]) 776 table_to_position[feature.table] += 1 777 else: 778 per_feature_activations.append( 779 activation[:, feature_index:( 780 feature_index+feature.max_sequence_length), :]) 781 table_to_position[feature.table] += feature.max_sequence_length 782 783 # Pack the list back into the same nested structure as the features. 784 return nest.pack_sequence_as(self._feature_config, per_feature_activations) 785 786 def _create_variables_and_slots( 787 self 788 ) -> Dict[Text, Dict[Text, tf_variables.Variable]]: 789 """Create variables for TPU embeddings. 790 791 Note under TPUStrategy this will ensure that all creations happen within a 792 variable creation scope of the sharded variable creator. 793 794 Returns: 795 A dict of dicts. The outer dict is keyed by the table names and the inner 796 dicts are keyed by 'parameters' and the slot variable names. 797 """ 798 799 def create_variables(table): 800 """Create all variables.""" 801 variable_shape = (table.vocabulary_size, table.dim) 802 803 def getter(name, shape, dtype, initializer, trainable): 804 del shape 805 # _add_variable_with_custom_getter clears the shape sometimes, so we 806 # take the global shape from outside the getter. 807 initial_value = functools.partial(initializer, variable_shape, 808 dtype=dtype) 809 return tf_variables.Variable( 810 name=name, 811 initial_value=initial_value, 812 shape=variable_shape, 813 dtype=dtype, 814 trainable=trainable) 815 816 def variable_creator(name, initializer, trainable=True): 817 # use add_variable_with_custom_getter here so that we take advantage of 818 # the checkpoint loading to allow restore before the variables get 819 # created which avoids double initialization. 820 return self._add_variable_with_custom_getter( 821 name=name, 822 initializer=initializer, 823 shape=variable_shape, 824 dtype=dtypes.float32, 825 getter=getter, 826 trainable=trainable) 827 828 parameters = variable_creator(table.name, table.initializer, 829 trainable=not self._using_tpu) 830 831 def slot_creator(name, initializer): 832 return variable_creator(table.name + "/" + name, 833 initializer, 834 False) 835 836 if table.optimizer is not None: 837 slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access 838 else: 839 slot_vars = {} 840 slot_vars["parameters"] = parameters 841 return slot_vars 842 843 # Store tables based on name rather than TableConfig as we can't track 844 # through dicts with non-string keys, i.e. we won't be able to save. 845 variables = {} 846 for table in self._table_config: 847 if not self._using_tpu: 848 variables[table.name] = create_variables(table) 849 else: 850 with variable_scope.variable_creator_scope( 851 make_sharded_variable_creator(self._hosts)): 852 variables[table.name] = create_variables(table) 853 854 return variables 855 856 def _load_variables(self): 857 # Only load the variables if we are: 858 # 1) Using TPU 859 # 2) Variables are created 860 # 3) Not in save context (except if running eagerly) 861 if self._using_tpu and self._built and not ( 862 not context.executing_eagerly() and save_context.in_save_context()): 863 _load_variables_impl(self._config_proto.SerializeToString(), 864 self._hosts, 865 self._variables, 866 self._table_config) 867 868 def _retrieve_variables(self): 869 # Only retrieve the variables if we are: 870 # 1) Using TPU 871 # 2) Variables are created 872 # 3) Not in save context (except if running eagerly) 873 if self._using_tpu and self._built and not ( 874 not context.executing_eagerly() and save_context.in_save_context()): 875 _retrieve_variables_impl(self._config_proto.SerializeToString(), 876 self._hosts, 877 self._variables, 878 self._table_config) 879 880 def _gather_saveables_for_checkpoint( 881 self 882 ) -> Dict[Text, Callable[[Text], "TPUEmbeddingSaveable"]]: 883 """Overrides default Trackable implementation to add load/retrieve hook.""" 884 # This saveable should be here in both TPU and CPU checkpoints, so when on 885 # CPU, we add the hook with no functions. 886 # TODO(bfontain): Update restore logic in saver so that these hooks are 887 # always executed. Once that is done, we can output an empty list when on 888 # CPU. 889 890 def factory(name=_HOOK_KEY): 891 return TPUEmbeddingSaveable(name, self._load_variables, 892 self._retrieve_variables) 893 return {_HOOK_KEY: factory} 894 895 # Some helper functions for the below enqueue function. 896 def _add_data_for_tensor(self, tensor, weight, indices, values, weights, 897 int_zeros, float_zeros, path): 898 if weight is not None: 899 raise ValueError( 900 "Weight specified for dense input {}, which is not allowed. " 901 "Weight will always be 1 in this case.".format(path)) 902 # For tensors, there are no indices and no weights. 903 indices.append(int_zeros) 904 values.append(math_ops.cast(tensor, dtypes.int32)) 905 weights.append(float_zeros) 906 907 def _add_data_for_sparse_tensor(self, tensor, weight, indices, values, 908 weights, int_zeros, float_zeros, path): 909 indices.append(math_ops.cast(tensor.indices, dtypes.int32)) 910 values.append(math_ops.cast(tensor.values, dtypes.int32)) 911 # If we have weights they must be a SparseTensor. 912 if weight is not None: 913 if not isinstance(weight, sparse_tensor.SparseTensor): 914 raise ValueError("Weight for {} is type {} which does not match " 915 "type input which is SparseTensor.".format( 916 path, type(weight))) 917 weights.append(math_ops.cast(weight.values, dtypes.float32)) 918 else: 919 weights.append(float_zeros) 920 921 def _add_data_for_ragged_tensor(self, tensor, weight, indices, values, 922 weights, int_zeros, float_zeros, path): 923 indices.append(math_ops.cast(tensor.row_splits, dtypes.int32)) 924 values.append(math_ops.cast(tensor.values, dtypes.int32)) 925 # If we have weights they must be a RaggedTensor. 926 if weight is not None: 927 if not isinstance(weight, ragged_tensor.RaggedTensor): 928 raise ValueError("Weight for {} is type {} which does not match " 929 "type input which is RaggedTensor.".format( 930 path, type(weight))) 931 weights.append(math_ops.cast(weight.values, dtypes.float32)) 932 else: 933 weights.append(float_zeros) 934 935 def _generate_enqueue_op( 936 self, 937 flat_inputs: List[internal_types.NativeObject], 938 flat_weights: List[Optional[internal_types.NativeObject]], 939 flat_features: List[tpu_embedding_v2_utils.FeatureConfig], 940 device_ordinal: int, 941 mode_override: Text 942 ) -> ops.Operation: 943 """Outputs a the enqueue op given the inputs and weights. 944 945 Args: 946 flat_inputs: A list of input tensors. 947 flat_weights: A list of input weights (or None) of the same length as 948 flat_inputs. 949 flat_features: A list of FeatureConfigs of the same length as flat_inputs. 950 device_ordinal: The device to create the enqueue op for. 951 mode_override: A tensor containing the string "train" or "inference". 952 953 Returns: 954 The enqueue op. 955 """ 956 957 # First we need to understand which op to use. This depends on if sparse 958 # or ragged tensors are in the flat_inputs. 959 sparse = False 960 ragged = False 961 for inp in flat_inputs: 962 if isinstance(inp, sparse_tensor.SparseTensor): 963 sparse = True 964 elif isinstance(inp, ragged_tensor.RaggedTensor): 965 ragged = True 966 if sparse and ragged: 967 raise ValueError( 968 "Found both SparseTensors and RaggedTensors in the input to the " 969 "enqueue operation. Please ensure that your data does not include " 970 "both SparseTensors and RaggedTensors. It is ok to have Tensors in " 971 "combination with one of the previous types.") 972 973 # Combiners are per table, list in the same order as the table order. 974 combiners = [table.combiner for table in self._table_config] 975 976 # Reverse mapping of self._table_config, so that we can lookup the table 977 # index. 978 table_to_id = {table: i for i, table in enumerate(self._table_config)} 979 980 # These parallel arrays will be the inputs to the enqueue op. 981 indices = [] # sample_indices for sparse, sample_splits for ragged. 982 values = [] 983 weights = [] 984 table_ids = [] 985 max_sequence_lengths = [] 986 987 # We have to supply a empty/zero tensor in a list position where we don't 988 # have data (e.g. indices for standard Tensor input, weight when no weight 989 # is specified). We create one op here per call, so that we reduce the 990 # graph size. 991 int_zeros = array_ops.zeros((0,), dtype=dtypes.int32) 992 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 993 994 # In the following loop we insert casts so that everything is either int32 995 # or float32. This is because op inputs which are lists of tensors must be 996 # of the same type within the list. Moreover the CPU implementations of 997 # these ops cast to these types anyway, so we don't lose any data by casting 998 # early. 999 for inp, weight, (path, feature) in zip( 1000 flat_inputs, flat_weights, flat_features): 1001 table_ids.append(table_to_id[feature.table]) 1002 max_sequence_lengths.append(feature.max_sequence_length) 1003 if isinstance(inp, ops.Tensor): 1004 self._add_data_for_tensor(inp, weight, indices, values, weights, 1005 int_zeros, float_zeros, path) 1006 elif isinstance(inp, sparse_tensor.SparseTensor): 1007 self._add_data_for_sparse_tensor(inp, weight, indices, values, weights, 1008 int_zeros, float_zeros, path) 1009 elif isinstance(inp, ragged_tensor.RaggedTensor): 1010 self._add_data_for_ragged_tensor(inp, weight, indices, values, weights, 1011 int_zeros, float_zeros, path) 1012 else: 1013 raise ValueError("Input {} is of unknown type {}. Please only pass " 1014 "Tensor, SparseTensor or RaggedTensor as input to " 1015 "enqueue.".format(path, type(inp))) 1016 1017 if ragged: 1018 return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( 1019 sample_splits=indices, 1020 embedding_indices=values, 1021 aggregation_weights=weights, 1022 mode_override=mode_override, 1023 device_ordinal=device_ordinal, 1024 combiners=combiners, 1025 table_ids=table_ids, 1026 max_sequence_lengths=max_sequence_lengths) 1027 return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( 1028 sample_indices=indices, 1029 embedding_indices=values, 1030 aggregation_weights=weights, 1031 mode_override=mode_override, 1032 device_ordinal=device_ordinal, 1033 combiners=combiners, 1034 table_ids=table_ids, 1035 max_sequence_lengths=max_sequence_lengths) 1036 1037 def _raise_error_for_incorrect_control_flow_context(self): 1038 """Raises an error if we are not in the TPUReplicateContext.""" 1039 # Do not allow any XLA control flow (i.e. control flow in between a 1040 # TPUStrategy's run call and the call to this function), as we can't 1041 # extract the enqueue from the head when in XLA control flow. 1042 graph = ops.get_default_graph() 1043 in_tpu_ctx = False 1044 while graph is not None: 1045 ctx = graph._get_control_flow_context() # pylint: disable=protected-access 1046 while ctx is not None: 1047 if isinstance(ctx, tpu.TPUReplicateContext): 1048 in_tpu_ctx = True 1049 break 1050 ctx = ctx.outer_context 1051 if in_tpu_ctx: 1052 break 1053 graph = getattr(graph, "outer_graph", None) 1054 if graph != ops.get_default_graph() and in_tpu_ctx: 1055 raise RuntimeError( 1056 "Current graph {} does not match graph which contains " 1057 "TPUReplicateContext {}. This is most likely due to the fact that " 1058 "enqueueing embedding data is called inside control flow or a " 1059 "nested function inside `strategy.run`. This is not supported " 1060 "because outside compilation fails to extract the enqueue ops as " 1061 "head of computation.".format(ops.get_default_graph(), graph)) 1062 return in_tpu_ctx 1063 1064 def _raise_error_for_non_direct_inputs(self, features): 1065 """Checks all tensors in features to see if they are a direct input.""" 1066 1067 # expand_composites here is important: as composite tensors pass through 1068 # tpu.replicate, they get 'flattened' into their component tensors and then 1069 # repacked before being passed to the tpu function. In means that it is the 1070 # component tensors which are produced by an op with the 1071 # "_tpu_input_identity" attribute. 1072 for path, input_tensor in nest.flatten_with_joined_string_paths( 1073 features, expand_composites=True): 1074 if input_tensor.op.type == "Placeholder": 1075 continue 1076 try: 1077 is_input = input_tensor.op.get_attr("_tpu_input_identity") 1078 except ValueError: 1079 is_input = False 1080 if not is_input: 1081 raise ValueError( 1082 "Received input tensor {} which is the output of op {} (type {}) " 1083 "which does not have the `_tpu_input_identity` attr. Please " 1084 "ensure that the inputs to this layer are taken directly from " 1085 "the arguments of the function called by " 1086 "strategy.run. Two possible causes are: dynamic batch size " 1087 "support or you are using a keras layer and are not passing " 1088 "tensors which match the dtype of the `tf.keras.Input`s." 1089 "If you are triggering dynamic batch size support, you can " 1090 "disable it by passing tf.distribute.RunOptions(" 1091 "experimental_enable_dynamic_batch_size=False) to the options " 1092 "argument of strategy.run().".format(path, 1093 input_tensor.op.name, 1094 input_tensor.op.type)) 1095 1096 def _raise_error_for_inputs_not_on_cpu(self, flat_inputs, flat_paths): 1097 """Checks all tensors in features to see are placed on the CPU.""" 1098 1099 def check_device(path, device_string): 1100 spec = tf_device.DeviceSpec.from_string(device_string) 1101 if spec.device_type == "TPU": 1102 raise ValueError( 1103 "Received input tensor {} which is on a TPU input device {}. Input " 1104 "tensors for TPU embeddings must be placed on the CPU. Please " 1105 "ensure that your dataset is prefetching tensors to the host by " 1106 "setting the 'experimental_fetch_to_device' option of the " 1107 "dataset distribution function. See the documentation of the " 1108 "enqueue method for an example.".format(path, device_string)) 1109 1110 # expand_composites here is important, we need to check the device of each 1111 # underlying tensor. 1112 for input_tensor, input_path in zip(flat_inputs, flat_paths): 1113 if nest.is_sequence_or_composite(input_tensor): 1114 input_tensors = nest.flatten(input_tensor, expand_composites=True) 1115 else: 1116 input_tensors = [input_tensor] 1117 for t in input_tensors: 1118 if (t.op.type == "Identity" and 1119 t.op.inputs[0].op.type == "TPUReplicatedInput"): 1120 for tensor in t.op.inputs[0].op.inputs: 1121 check_device(input_path, tensor.device) 1122 else: 1123 check_device(input_path, t.device) 1124 1125 def enqueue( 1126 self, 1127 features, 1128 weights=None, 1129 training: bool = True, 1130 name: Optional[Text] = None, 1131 device: Optional[Text] = None): 1132 """Enqueues id tensors for embedding lookup. 1133 1134 This function enqueues a structure of features to be looked up in the 1135 embedding tables. We expect that the batch size of each of the tensors in 1136 features matches the per core batch size. This will automatically happen if 1137 your input dataset is batched to the global batch size and you use 1138 `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset` 1139 or if you use `distribute_datasets_from_function` and batch 1140 to the per core batch size computed by the context passed to your input 1141 function. 1142 1143 ```python 1144 strategy = tf.distribute.TPUStrategy(...) 1145 with strategy.scope(): 1146 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 1147 1148 distributed_dataset = ( 1149 strategy.distribute_datasets_from_function( 1150 dataset_fn=..., 1151 options=tf.distribute.InputOptions( 1152 experimental_fetch_to_device=False)) 1153 dataset_iterator = iter(distributed_dataset) 1154 1155 @tf.function 1156 def training_step(): 1157 def tpu_step(tpu_features): 1158 with tf.GradientTape() as tape: 1159 activations = embedding.dequeue() 1160 tape.watch(activations) 1161 1162 loss = ... # some computation involving activations 1163 1164 embedding_gradients = tape.gradient(loss, activations) 1165 embedding.apply_gradients(embedding_gradients) 1166 1167 embedding_features, tpu_features = next(dataset_iterator) 1168 embedding.enqueue(embedding_features, training=True) 1169 strategy.run(tpu_step, args=(embedding_features,)) 1170 1171 training_step() 1172 ``` 1173 1174 NOTE: You should specify `training=True` when using 1175 `embedding.apply_gradients` as above and `training=False` when not using 1176 `embedding.apply_gradients` (e.g. for frozen embeddings or when doing 1177 evaluation). 1178 1179 For finer grained control, in the above example the line 1180 1181 ``` 1182 embedding.enqueue(embedding_features, training=True) 1183 ``` 1184 1185 may be replaced with 1186 1187 ``` 1188 per_core_embedding_features = self.strategy.experimental_local_results( 1189 embedding_features) 1190 1191 def per_core_enqueue(ctx): 1192 core_id = ctx.replica_id_in_sync_group 1193 device = strategy.extended.worker_devices[core_id] 1194 embedding.enqueue(per_core_embedding_features[core_id], 1195 device=device) 1196 1197 strategy.experimental_distribute_values_from_function( 1198 per_core_queue_inputs) 1199 ``` 1200 1201 Args: 1202 features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or 1203 `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs 1204 will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` 1205 or `tf.RaggedTensor` is supported per call. 1206 weights: If not `None`, a nested structure of `tf.Tensor`s, 1207 `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except 1208 that the tensors should be of float type (and they will be downcast to 1209 `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the 1210 same for the parallel entries from `features` and similarly for 1211 `tf.RaggedTensor`s we assume the row_splits are the same. 1212 training: Defaults to `True`. If `False`, enqueue the batch as inference 1213 batch (forward pass only). Do not call `apply_gradients` when this is 1214 `False` as this may lead to a deadlock. 1215 name: A name for the underlying op. 1216 device: The device name (e.g. '/task:0/device:TPU:2') where this batch 1217 should be enqueued. This should be set if and only if features is not a 1218 `tf.distribute.DistributedValues` and enqueue is not being called 1219 inside a TPU context (e.g. inside `TPUStrategy.run`). 1220 1221 Raises: 1222 ValueError: When called inside a strategy.run call and input is not 1223 directly taken from the args of the `strategy.run` call. Also if 1224 the size of any sequence in `features` does not match corresponding 1225 sequence in `feature_config`. Similarly for `weights`, if not `None`. 1226 If batch size of features is unequal or different from a previous call. 1227 RuntimeError: When called inside a strategy.run call and inside XLA 1228 control flow. If batch_size is not able to be determined and build was 1229 not called. 1230 TypeError: If the type of any sequence in `features` does not match 1231 corresponding sequence in `feature_config`. Similarly for `weights`, if 1232 not `None`. 1233 """ 1234 if not self._using_tpu: 1235 raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " 1236 "created under a TPUStrategy.") 1237 1238 in_tpu_context = self._raise_error_for_incorrect_control_flow_context() 1239 1240 if not self._verify_batch_size_on_enqueue: 1241 if not self._batch_size or not self._built: 1242 raise ValueError( 1243 "Configured not to check batch size on each enqueue() call; please " 1244 "ensure build() was called with global batch size to initialize " 1245 "the TPU for embeddings.") 1246 else: 1247 # Should we also get batch_size from weights if they exist? 1248 # Since features is assumed to be batched at the per replica batch size 1249 # the returned batch size here is per replica an not global. 1250 batch_size = self._get_batch_size(features, in_tpu_context) 1251 if batch_size is None and not self._built: 1252 raise RuntimeError("Unable to determine batch size from input features." 1253 "Please call build() with global batch size to " 1254 "initialize the TPU for embeddings.") 1255 if batch_size is not None: 1256 self._maybe_build(batch_size) 1257 if self._batch_size != batch_size: 1258 raise ValueError("Multiple calls to enqueue with different batch " 1259 "sizes {} and {}.".format(self._batch_size, 1260 batch_size)) 1261 1262 nest.assert_same_structure(self._feature_config, features) 1263 1264 flat_inputs = nest.flatten(features) 1265 flat_weights = [None] * len(flat_inputs) 1266 if weights is not None: 1267 nest.assert_same_structure(self._feature_config, weights) 1268 flat_weights = nest.flatten(weights) 1269 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 1270 flat_paths, _ = zip(*flat_features) 1271 1272 self._raise_error_for_inputs_not_on_cpu(flat_inputs, flat_paths) 1273 # If we are in a tpu_context, automatically apply outside compilation. 1274 if in_tpu_context: 1275 self._raise_error_for_non_direct_inputs(features) 1276 1277 def generate_enqueue_ops(): 1278 """Generate enqueue ops for outside compilation.""" 1279 # Note that we put array_ops.where_v2 rather than a python if so that 1280 # the op is explicitly create and the constant ops are both in the graph 1281 # even though we don't expect training to be a tensor (and thus generate 1282 # control flow automatically). This need to make it easier to re-write 1283 # the graph later if we need to fix which mode needs to be used. 1284 mode_override = array_ops.where_v2(training, 1285 constant_op.constant("train"), 1286 constant_op.constant("inference")) 1287 1288 # Device ordinal is -1 here, a later rewrite will fix this once the op 1289 # is expanded by outside compilation. 1290 enqueue_op = self._generate_enqueue_op( 1291 flat_inputs, flat_weights, flat_features, device_ordinal=-1, 1292 mode_override=mode_override) 1293 1294 # Apply the name tag to the op. 1295 if name is not None: 1296 _add_key_attr(enqueue_op, name) 1297 1298 # Ensure that this op has outbound control flow, otherwise it won't be 1299 # executed. 1300 ops.get_default_graph().control_outputs.append(enqueue_op) 1301 1302 tpu.outside_compilation(generate_enqueue_ops) 1303 1304 elif device is None: 1305 mode_override = "train" if training else "inference" 1306 # We generate enqueue ops per device, so we need to gather the all 1307 # features for a single device in to a dict. 1308 # We rely here on the fact that the devices in the PerReplica value occur 1309 # in the same (standard) order as self._strategy.extended.worker_devices. 1310 enqueue_ops = [] 1311 for replica_id in range(self._strategy.num_replicas_in_sync): 1312 replica_inputs = distribute_utils.select_replica(replica_id, 1313 flat_inputs) 1314 replica_weights = distribute_utils.select_replica(replica_id, 1315 flat_weights) 1316 tpu_device = self._strategy.extended.worker_devices[replica_id] 1317 # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0 1318 # the device ordinal is the last number 1319 device_ordinal = ( 1320 tf_device.DeviceSpec.from_string(tpu_device).device_index) 1321 with ops.device(device_util.get_host_for_device(tpu_device)): 1322 enqueue_op = self._generate_enqueue_op( 1323 replica_inputs, replica_weights, flat_features, 1324 device_ordinal=device_ordinal, mode_override=mode_override) 1325 1326 # Apply the name tag to the op. 1327 if name is not None: 1328 _add_key_attr(enqueue_op, name) 1329 enqueue_ops.append(enqueue_op) 1330 ops.get_default_graph().control_outputs.extend(enqueue_ops) 1331 else: 1332 mode_override = "train" if training else "inference" 1333 device_spec = tf_device.DeviceSpec.from_string(device) 1334 if device_spec.device_type != "TPU": 1335 raise ValueError( 1336 "Non-TPU device {} passed to enqueue.".format(device)) 1337 with ops.device(device_util.get_host_for_device(device)): 1338 enqueue_op = self._generate_enqueue_op( 1339 flat_inputs, flat_weights, flat_features, 1340 device_ordinal=device_spec.device_index, 1341 mode_override=mode_override) 1342 1343 # Apply the name tag to the op. 1344 if name is not None: 1345 _add_key_attr(enqueue_op, name) 1346 ops.get_default_graph().control_outputs.append(enqueue_op) 1347 1348 def _get_batch_size(self, tensors, in_tpu_context: bool): 1349 """Gets the batch size from a nested structure of features.""" 1350 batch_size = None 1351 for path, maybe_tensor in nest.flatten_with_joined_string_paths(tensors): 1352 tensor_list = [] 1353 if not in_tpu_context: 1354 # if we are not in a context, then this is PerReplica and we need to 1355 # check each replica's batch size. 1356 for replica_id in range(self._strategy.num_replicas_in_sync): 1357 tensor_list.append(distribute_utils.select_replica(replica_id, 1358 maybe_tensor)) 1359 else: 1360 tensor_list = [maybe_tensor] 1361 1362 for tensor in tensor_list: 1363 if tensor.shape.rank < 1: 1364 raise ValueError( 1365 "Input {} has rank 0, rank must be at least 1.".format(path)) 1366 shape = tensor.shape.as_list() 1367 if shape[0] is not None: 1368 if batch_size is None: 1369 batch_size = shape[0] 1370 elif batch_size != shape[0]: 1371 raise ValueError("Found multiple batch sizes {} and {}. All inputs " 1372 "must have the same batch dimensions size.".format( 1373 batch_size, shape[0])) 1374 return batch_size 1375 1376 1377@def_function.function 1378def _load_variables_impl( 1379 config: Text, 1380 hosts: List[Tuple[int, Text]], 1381 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1382 table_config: tpu_embedding_v2_utils.TableConfig): 1383 """Load embedding tables to onto TPU for each table and host. 1384 1385 Args: 1386 config: A serialized TPUEmbeddingConfiguration proto. 1387 hosts: A list of CPU devices, on per host. 1388 variables: A dictionary of dictionaries of TPUShardedVariables. First key is 1389 the table name, second key is 'parameters' or the optimizer slot name. 1390 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1391 """ 1392 def select_fn(host_id): 1393 1394 def select_or_zeros(x): 1395 if host_id >= len(x.variables): 1396 # In the edge case where we have more hosts than variables, due to using 1397 # a small number of rows, we load zeros for the later hosts. We copy 1398 # the shape of the first host's variables, which we assume is defined 1399 # because TableConfig guarantees at least one row. 1400 return array_ops.zeros_like(x.variables[0]) 1401 return x.variables[host_id] 1402 1403 return select_or_zeros 1404 1405 for host_id, host in enumerate(hosts): 1406 with ops.device(host): 1407 host_variables = nest.map_structure(select_fn(host_id), variables) 1408 for table in table_config: 1409 table.optimizer._load()( # pylint: disable=protected-access 1410 table_name=table.name, 1411 num_shards=len(hosts), 1412 shard_id=host_id, 1413 config=config, 1414 **host_variables[table.name]) 1415 # Ensure that only the first table/first host gets a config so that we 1416 # don't bloat graph by attaching this large string to each op. 1417 # We have num tables * num hosts of these so for models with a large 1418 # number of tables training on a large slice, this can be an issue. 1419 config = None 1420 1421 1422@def_function.function 1423def _retrieve_variables_impl( 1424 config: Text, 1425 hosts: List[Tuple[int, Text]], 1426 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1427 table_config: tpu_embedding_v2_utils.TableConfig): 1428 """Retrieve embedding tables from TPU to host memory. 1429 1430 Args: 1431 config: A serialized TPUEmbeddingConfiguration proto. 1432 hosts: A list of all the host CPU devices. 1433 variables: A dictionary of dictionaries of TPUShardedVariables. First key is 1434 the table name, second key is 'parameters' or the optimizer slot name. 1435 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1436 """ 1437 for host_id, host in enumerate(hosts): 1438 with ops.device(host): 1439 for table in table_config: 1440 retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access 1441 table_name=table.name, 1442 num_shards=len(hosts), 1443 shard_id=host_id, 1444 config=config) 1445 # When there are no slot variables (e.g with SGD) this returns a 1446 # single tensor rather than a tuple. In this case we put the tensor in 1447 # a list to make the following code easier to write. 1448 if not isinstance(retrieved, tuple): 1449 retrieved = (retrieved,) 1450 1451 for i, slot in enumerate(["parameters"] + 1452 table.optimizer._slot_names()): # pylint: disable=protected-access 1453 # We must assign the CPU variables the values of tensors that were 1454 # returned from the TPU. 1455 sharded_var = variables[table.name][slot] 1456 if host_id < len(sharded_var.variables): 1457 # In the edge case where we have more hosts than variables, due to 1458 # using a small number of rows, we skip the later hosts. 1459 sharded_var.variables[host_id].assign(retrieved[i]) 1460 # Ensure that only the first table/first host gets a config so that we 1461 # don't bloat graph by attaching this large string to each op. 1462 # We have num tables * num hosts of these so for models with a large 1463 # number of tables training on a large slice, this can be an issue. 1464 config = None 1465 1466 1467class TPUEmbeddingSaveable(saveable_hook.SaveableHook): 1468 """Save/Restore hook to Retrieve/Load TPUEmbedding variables.""" 1469 1470 def __init__( 1471 self, 1472 name: Text, 1473 load: Callable[[], Any], 1474 retrieve: Callable[[], Any]): 1475 self._load = load 1476 self._retrieve = retrieve 1477 super(TPUEmbeddingSaveable, self).__init__(name=name) 1478 1479 def before_save(self): 1480 if self._retrieve is not None: 1481 self._retrieve() 1482 1483 def after_restore(self): 1484 if self._load is not None: 1485 self._load() 1486 1487 1488def _ragged_embedding_lookup_with_reduce( 1489 table: tf_variables.Variable, 1490 ragged: ragged_tensor.RaggedTensor, 1491 weights: ragged_tensor.RaggedTensor, 1492 combiner: Text) -> core.Tensor: 1493 """Compute a ragged lookup followed by a reduce on axis 1. 1494 1495 Args: 1496 table: The embedding table. 1497 ragged: A RaggedTensor of ids to look up. 1498 weights: A RaggedTensor of weights (or None). 1499 combiner: One of "mean", "sum", "sqrtn". 1500 1501 Returns: 1502 A Tensor. 1503 """ 1504 if weights is None: 1505 weights = array_ops.ones_like(ragged, dtype=table.dtype) 1506 weights = array_ops.expand_dims(weights, axis=2) 1507 ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged) 1508 ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1) 1509 if combiner == "mean": 1510 ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1) 1511 elif combiner == "sqrtn": 1512 ragged_result = ragged_result, math_ops.sqrt(math_ops.reduce_sum( 1513 weights*weights, axis=1)) 1514 return ragged_result 1515 1516 1517@tf_export("tpu.experimental.embedding.serving_embedding_lookup") 1518def cpu_embedding_lookup(inputs, weights, tables, feature_config): 1519 """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs. 1520 1521 This function is a utility which allows using the 1522 `tf.tpu.experimental.embedding` config objects with standard lookup functions. 1523 This can be used when exporting a model which uses 1524 `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In particular 1525 `tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups on TPUs and 1526 should not be part of your serving graph. 1527 1528 Note that TPU specific options (such as `max_sequence_length`) in the 1529 configuration objects will be ignored. 1530 1531 In the following example we take a trained model (see the documentation for 1532 `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a 1533 saved model with a serving function that will perform the embedding lookup and 1534 pass the results to your model: 1535 1536 ```python 1537 model = model_fn(...) 1538 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 1539 feature_config=feature_config, 1540 batch_size=1024, 1541 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 1542 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 1543 checkpoint.restore(...) 1544 1545 @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...), 1546 'feature_two': tf.TensorSpec(...), 1547 'feature_three': tf.TensorSpec(...)}]) 1548 def serve_tensors(embedding_featurese): 1549 embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup( 1550 embedding_features, None, embedding.embedding_tables, 1551 feature_config) 1552 return model(embedded_features) 1553 1554 model.embedding_api = embedding 1555 tf.saved_model.save(model, 1556 export_dir=..., 1557 signatures={'serving_default': serve_tensors}) 1558 1559 ``` 1560 1561 NOTE: Its important to assign the embedding api object to a member of your 1562 model as `tf.saved_model.save` only supports saving variables one `Trackable` 1563 object. Since the model's weights are in `model` and the embedding table are 1564 managed by `embedding`, we assign `embedding` to and attribute of `model` so 1565 that tf.saved_model.save can find the embedding variables. 1566 1567 NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will 1568 work directly from training. 1569 1570 Args: 1571 inputs: a nested structure of Tensors, SparseTensors or RaggedTensors. 1572 weights: a nested structure of Tensors, SparseTensors or RaggedTensors or 1573 None for no weights. If not None, structure must match that of inputs, but 1574 entries are allowed to be None. 1575 tables: a dict of mapping TableConfig objects to Variables. 1576 feature_config: a nested structure of FeatureConfig objects with the same 1577 structure as inputs. 1578 1579 Returns: 1580 A nested structure of Tensors with the same structure as inputs. 1581 """ 1582 1583 nest.assert_same_structure(inputs, feature_config) 1584 1585 flat_inputs = nest.flatten(inputs) 1586 flat_weights = [None] * len(flat_inputs) 1587 if weights is not None: 1588 nest.assert_same_structure(inputs, weights) 1589 flat_weights = nest.flatten(weights) 1590 flat_features = nest.flatten_with_joined_string_paths(feature_config) 1591 1592 outputs = [] 1593 for inp, weight, (path, feature) in zip( 1594 flat_inputs, flat_weights, flat_features): 1595 table = tables[feature.table] 1596 1597 if weight is not None: 1598 if isinstance(inp, ops.Tensor): 1599 raise ValueError( 1600 "Weight specified for {}, but input is dense.".format(path)) 1601 elif type(weight) is not type(inp): 1602 raise ValueError( 1603 "Weight for {} is of type {} but it does not match type of the " 1604 "input which is {}.".format(path, type(weight), type(inp))) 1605 elif feature.max_sequence_length > 0: 1606 raise ValueError("Weight specified for {}, but this is a sequence " 1607 "feature.".format(path)) 1608 1609 if isinstance(inp, ops.Tensor): 1610 if feature.max_sequence_length > 0: 1611 raise ValueError("Feature {} is a sequence feature but a dense tensor " 1612 "was passed.".format(path)) 1613 outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) 1614 1615 elif isinstance(inp, sparse_tensor.SparseTensor): 1616 if feature.max_sequence_length > 0: 1617 batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64) 1618 sparse_shape = array_ops.stack( 1619 [batch_size, feature.max_sequence_length], axis=0) 1620 # TPU Embedding truncates sequences to max_sequence_length, and if we 1621 # don't truncate, scatter_nd will error out if the index was out of 1622 # bounds. 1623 truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0], 1624 size=sparse_shape) 1625 1626 dense_output_shape = array_ops.stack( 1627 [batch_size, feature.max_sequence_length, feature.table.dim], 1628 axis=0) 1629 outputs.append( 1630 array_ops.scatter_nd( 1631 truncated_inp.indices, 1632 array_ops.gather(table.read_value(), truncated_inp.values), 1633 dense_output_shape)) 1634 else: 1635 inp_rank = inp.dense_shape.get_shape()[0] 1636 if (not feature.validate_weights_and_indices and 1637 inp_rank is not None and inp_rank <= 2): 1638 outputs.append( 1639 embedding_ops.embedding_lookup_sparse_v2( 1640 table, 1641 inp, 1642 sp_weights=weight, 1643 combiner=feature.table.combiner)) 1644 else: 1645 outputs.append( 1646 embedding_ops.safe_embedding_lookup_sparse_v2( 1647 table, 1648 inp, 1649 sparse_weights=weight, 1650 combiner=feature.table.combiner)) 1651 1652 elif isinstance(inp, ragged_tensor.RaggedTensor): 1653 if feature.max_sequence_length > 0: 1654 batch_size = inp.shape[0] 1655 dense_output_shape = [ 1656 batch_size, feature.max_sequence_length, feature.table.dim] 1657 ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp) 1658 # Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given 1659 # shape. 1660 outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape)) 1661 else: 1662 outputs.append(_ragged_embedding_lookup_with_reduce( 1663 table, inp, weight, feature.table.combiner)) 1664 1665 else: 1666 raise ValueError("Input {} is type {}. Tensor, SparseTensor or " 1667 "RaggedTensor expected.".format(path, type(inp))) 1668 return nest.pack_sequence_as(feature_config, outputs) 1669 1670 1671def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]: 1672 """Returns a sorted list of CPU devices for the remote jobs. 1673 1674 Args: 1675 strategy: A TPUStrategy object. 1676 1677 Returns: 1678 A sort list of device strings. 1679 """ 1680 list_of_hosts = [] 1681 # Assume this is sorted by task 1682 for tpu_device in strategy.extended.worker_devices: 1683 host = device_util.get_host_for_device(tpu_device) 1684 if host not in list_of_hosts: 1685 list_of_hosts.append(host) 1686 assert len(list_of_hosts) == strategy.extended.num_hosts 1687 return list_of_hosts 1688 1689 1690def extract_variable_info( 1691 kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]: 1692 """Extracts the variable creation attributes from the kwargs. 1693 1694 Args: 1695 kwargs: a dict of keyword arguments that were passed to a variable creator 1696 scope. 1697 1698 Returns: 1699 A tuple of variable name, shape, dtype, initialization function. 1700 """ 1701 if (isinstance(kwargs["initial_value"], functools.partial) and ( 1702 "shape" in kwargs["initial_value"].keywords or 1703 kwargs["initial_value"].args)): 1704 # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. 1705 if "shape" in kwargs["initial_value"].keywords: 1706 shape = kwargs["initial_value"].keywords["shape"] 1707 else: 1708 shape = kwargs["initial_value"].args[0] 1709 return (kwargs["name"], shape, 1710 kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), 1711 kwargs["initial_value"].func) 1712 elif "shape" not in kwargs or kwargs["shape"] is None or not callable( 1713 kwargs["initial_value"]): 1714 raise ValueError( 1715 "Unable to extract initializer function and shape from {}. Please " 1716 "either pass a function that expects a shape and dtype as the " 1717 "initial value for your variable or functools.partial object with " 1718 "the shape and dtype kwargs set. This is needed so that we can " 1719 "initialize the shards of the ShardedVariable locally.".format( 1720 kwargs["initial_value"])) 1721 else: 1722 return (kwargs["name"], kwargs["shape"], kwargs["dtype"], 1723 kwargs["initial_value"]) 1724 1725 1726def make_sharded_variable_creator( 1727 hosts: List[Text]) -> Callable[..., TPUShardedVariable]: 1728 """Makes a sharded variable creator given a list of hosts. 1729 1730 Args: 1731 hosts: a list of tensorflow devices on which to shard the tensors. 1732 1733 Returns: 1734 A variable creator function. 1735 """ 1736 1737 def sharded_variable_creator( 1738 next_creator: Callable[..., tf_variables.Variable], *args, **kwargs): 1739 """The sharded variable creator.""" 1740 kwargs["skip_mirrored_creator"] = True 1741 1742 num_hosts = len(hosts) 1743 name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs) 1744 initial_value = kwargs["initial_value"] 1745 rows = shape[0] 1746 cols = shape[1] 1747 partial_partition = rows % num_hosts 1748 full_rows_per_host = rows // num_hosts 1749 # We partition as if we were using MOD sharding: at least 1750 # `full_rows_per_host` rows to `num_hosts` hosts, where the first 1751 # `partial_partition` hosts get an additional row when the number of rows 1752 # is not cleanly divisible. Note that `full_rows_per_host` may be zero. 1753 partitions = ( 1754 [full_rows_per_host + 1] * partial_partition 1755 + [full_rows_per_host] * (num_hosts - partial_partition)) 1756 variables = [] 1757 sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args 1758 1759 # Keep track of offset for sharding aware initializers. 1760 offset = 0 1761 kwargs["dtype"] = dtype 1762 for i, p in enumerate(partitions): 1763 if p == 0: 1764 # Skip variable creation for empty partitions, resulting from the edge 1765 # case of 'rows < num_hosts'. This is safe because both load/restore 1766 # can handle the missing values. 1767 continue 1768 with ops.device(hosts[i]): 1769 kwargs["name"] = "{}_{}".format(name, i) 1770 kwargs["shape"] = (p, cols) 1771 if sharding_aware: 1772 shard_info = base.ShardInfo(kwargs["shape"], (offset, 0)) 1773 kwargs["initial_value"] = functools.partial( 1774 initial_value, shard_info=shard_info) 1775 offset += p 1776 else: 1777 kwargs["initial_value"] = functools.partial( 1778 unwrapped_initial_value, kwargs["shape"], dtype=dtype) 1779 variables.append(next_creator(*args, **kwargs)) 1780 return TPUShardedVariable(variables, name=name) 1781 return sharded_variable_creator 1782