1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mid level API for TPU Embeddings.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20from __future__ import unicode_literals 21 22import functools 23from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union 24 25from absl import logging 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 29from tensorflow.python.distribute import device_util 30from tensorflow.python.distribute import distribute_utils 31from tensorflow.python.distribute import distribution_strategy_context 32from tensorflow.python.distribute import sharded_variable 33from tensorflow.python.distribute import tpu_strategy 34from tensorflow.python.eager import context 35from tensorflow.python.eager import def_function 36from tensorflow.python.framework import constant_op 37from tensorflow.python.framework import device as tf_device 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import sparse_tensor 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import embedding_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import sparse_ops 45from tensorflow.python.ops import variable_scope 46from tensorflow.python.ops import variables as tf_variables 47from tensorflow.python.ops.ragged import ragged_tensor 48from tensorflow.python.saved_model import save_context 49from tensorflow.python.tpu import tpu 50from tensorflow.python.tpu import tpu_embedding_v2_utils 51from tensorflow.python.tpu.ops import tpu_ops 52from tensorflow.python.training.saving import saveable_hook 53from tensorflow.python.training.tracking import base 54from tensorflow.python.training.tracking import tracking 55from tensorflow.python.types import core 56from tensorflow.python.types import internal as internal_types 57from tensorflow.python.util import compat 58from tensorflow.python.util import nest 59from tensorflow.python.util import tf_inspect 60from tensorflow.python.util.tf_export import tf_export 61 62 63_HOOK_KEY = "TPUEmbedding_saveable" 64_NAME_KEY = "_tpu_embedding_layer" 65 66 67# TODO(bfontain): Cleanup and remove this once there is an implementation of 68# sharded variables that can be used in the PSStrategy with optimizers. 69# We implement just enough of the of a tf.Variable so that this could be passed 70# to an optimizer. 71class TPUShardedVariable(sharded_variable.ShardedVariableMixin): 72 """A ShardedVariable class for TPU.""" 73 74 @property 75 def _in_graph_mode(self): 76 return self.variables[0]._in_graph_mode # pylint: disable=protected-access 77 78 @property 79 def _unique_id(self): 80 return self.variables[0]._unique_id # pylint: disable=protected-access 81 82 @property 83 def _distribute_strategy(self): 84 return self.variables[0]._distribute_strategy # pylint: disable=protected-access 85 86 @property 87 def _shared_name(self): 88 return self._name 89 90 91def _add_key_attr(op, name): 92 op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name))) # pylint: disable=protected-access 93 94 95@tf_export("tpu.experimental.embedding.TPUEmbedding") 96class TPUEmbedding(tracking.AutoTrackable): 97 """The TPUEmbedding mid level API. 98 99 NOTE: When instantiated under a TPUStrategy, this class can only be created 100 once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to 101 re-initialize the embedding engine you must re-initialize the tpu as well. 102 Doing this will clear any variables from TPU, so ensure you have checkpointed 103 before you do this. If a further instances of the class are needed, 104 set the `initialize_tpu_embedding` argument to `False`. 105 106 This class can be used to support training large embeddings on TPU. When 107 creating an instance of this class, you must specify the complete set of 108 tables and features you expect to lookup in those tables. See the 109 documentation of `tf.tpu.experimental.embedding.TableConfig` and 110 `tf.tpu.experimental.embedding.FeatureConfig` for more details on the complete 111 set of options. We will cover the basic usage here. 112 113 NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object, 114 allowing different features to share the same table: 115 116 ```python 117 table_config_one = tf.tpu.experimental.embedding.TableConfig( 118 vocabulary_size=..., 119 dim=...) 120 table_config_two = tf.tpu.experimental.embedding.TableConfig( 121 vocabulary_size=..., 122 dim=...) 123 feature_config = { 124 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( 125 table=table_config_one), 126 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( 127 table=table_config_one), 128 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( 129 table=table_config_two)} 130 ``` 131 132 There are two modes under which the `TPUEmbedding` class can used. This 133 depends on if the class was created under a `TPUStrategy` scope or not. 134 135 Under `TPUStrategy`, we allow access to the method `enqueue`, `dequeue` and 136 `apply_gradients`. We will show examples below of how to use these to train 137 and evaluate your model. Under CPU, we only access to the `embedding_tables` 138 property which allow access to the embedding tables so that you can use them 139 to run model evaluation/prediction on CPU. 140 141 First lets look at the `TPUStrategy` mode. Initial setup looks like: 142 143 ```python 144 strategy = tf.distribute.TPUStrategy(...) 145 with strategy.scope(): 146 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 147 feature_config=feature_config, 148 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 149 ``` 150 151 When creating a distributed dataset that is to be passed to the enqueue 152 operation a special input option must be specified: 153 154 ```python 155 distributed_dataset = ( 156 strategy.distribute_datasets_from_function( 157 dataset_fn=..., 158 options=tf.distribute.InputOptions( 159 experimental_prefetch_to_device=False)) 160 dataset_iterator = iter(distributed_dataset) 161 ``` 162 163 NOTE: All batches passed to the layer must have the same batch size for each 164 input, more over once you have called the layer with one batch size all 165 subsequent calls must use the same batch_size. In the event that the batch 166 size cannot be automatically determined by the enqueue method, you must call 167 the build method with the batch size to initialize the layer. 168 169 To use this API on TPU you should use a custom training loop. Below is an 170 example of a training and evaluation step: 171 172 ```python 173 @tf.function 174 def training_step(dataset_iterator, num_steps): 175 def tpu_step(tpu_features): 176 with tf.GradientTape() as tape: 177 activations = embedding.dequeue() 178 tape.watch(activations) 179 model_output = model(activations) 180 loss = ... # some function of labels and model_output 181 182 embedding_gradients = tape.gradient(loss, activations) 183 embedding.apply_gradients(embedding_gradients) 184 # Insert your model gradient and optimizer application here 185 186 for _ in tf.range(num_steps): 187 embedding_features, tpu_features = next(dataset_iterator) 188 embedding.enqueue(embedding_features, training=True) 189 strategy.run(tpu_step, args=(embedding_features, )) 190 191 @tf.function 192 def evalution_step(dataset_iterator, num_steps): 193 def tpu_step(tpu_features): 194 activations = embedding.dequeue() 195 model_output = model(activations) 196 # Insert your evaluation code here. 197 198 for _ in tf.range(num_steps): 199 embedding_features, tpu_features = next(dataset_iterator) 200 embedding.enqueue(embedding_features, training=False) 201 strategy.run(tpu_step, args=(embedding_features, )) 202 ``` 203 204 NOTE: The calls to `enqueue` have `training` set to `True` when 205 `embedding.apply_gradients` is used and set to `False` when 206 `embedding.apply_gradients` is not present in the function. If you don't 207 follow this pattern you may cause an error to be raised or the tpu may 208 deadlock. 209 210 In the above examples, we assume that the user has a dataset which returns 211 a tuple where the first element of the tuple matches the structure of what 212 was passed as the `feature_config` argument to the object initializer. Also we 213 utilize `tf.range` to get a `tf.while_loop` in order to increase performance. 214 215 When checkpointing your model, you should include your 216 `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is a 217 trackable object and saving it will save the embedding tables and their 218 optimizer slot variables: 219 220 ```python 221 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 222 checkpoint.save(...) 223 ``` 224 225 On CPU, only the `embedding_table` property is usable. This will allow you to 226 restore a checkpoint to the object and have access to the table variables: 227 228 ```python 229 model = model_fn(...) 230 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 231 feature_config=feature_config, 232 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 233 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 234 checkpoint.restore(...) 235 236 tables = embedding.embedding_tables 237 ``` 238 239 You can now use table in functions like `tf.nn.embedding_lookup` to perform 240 your embedding lookup and pass to your model. 241 242 """ 243 244 def __init__( 245 self, 246 feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic 247 optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access 248 pipeline_execution_with_tensor_core: bool = False): 249 """Creates the TPUEmbedding mid level API object. 250 251 ```python 252 strategy = tf.distribute.TPUStrategy(...) 253 with strategy.scope(): 254 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 255 feature_config=tf.tpu.experimental.embedding.FeatureConfig( 256 table=tf.tpu.experimental.embedding.TableConfig( 257 dim=..., 258 vocabulary_size=...))) 259 ``` 260 261 Args: 262 feature_config: A nested structure of 263 `tf.tpu.experimental.embedding.FeatureConfig` configs. 264 optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`, 265 `tf.tpu.experimental.embedding.Adagrad` or 266 `tf.tpu.experimental.embedding.Adam`. When not created under 267 TPUStrategy may be set to None to avoid the creation of the optimizer 268 slot variables, useful for optimizing memory consumption when exporting 269 the model for serving where slot variables aren't needed. 270 pipeline_execution_with_tensor_core: If True, the TPU embedding 271 computations will overlap with the TensorCore computations (and hence 272 will be one step old). Set to True for improved performance. 273 274 Raises: 275 ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD, 276 Adam or Adagrad) or None when created under a TPUStrategy. 277 """ 278 self._strategy = distribution_strategy_context.get_strategy() 279 self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy, 280 tpu_strategy.TPUStrategyV2)) 281 self._pipeline_execution_with_tensor_core = ( 282 pipeline_execution_with_tensor_core) 283 284 self._feature_config = feature_config 285 286 # The TPU embedding ops are slightly inconsistent with how they refer to 287 # tables: 288 # * The enqueue op takes a parallel list of tensors for input, one of those 289 # is the table id for the feature which matches the integer index of the 290 # table in the proto created by _create_config_proto(). 291 # * The recv_tpu_embedding_activations op emits lookups per table in the 292 # order from the config proto. 293 # * The send_tpu_embedding_gradients expects input tensors to be per table 294 # in the same order as the config proto. 295 # * Per optimizer load and retrieve ops are specified per table and take the 296 # table name rather than the table id. 297 # Thus we must fix a common order to tables and ensure they have unique 298 # names. 299 300 # Set table order here to the order of the first occurence of the table in a 301 # feature provided by the user. The order of this struct must be fixed 302 # to provide the user with deterministic behavior over multiple 303 # instantiations. 304 self._table_config = [] 305 for feature in nest.flatten(feature_config): 306 if feature.table not in self._table_config: 307 self._table_config.append(feature.table) 308 309 # Ensure tables have unique names. Also error check the optimizer as we 310 # specifically don't do that in the TableConfig class to allow high level 311 # APIs that are built on this to use strings/other classes to represent 312 # optimizers (before they are passed to this class). 313 table_names = [] 314 for i, table in enumerate(self._table_config): 315 if table.optimizer is None: 316 # TODO(bfontain) Should we allow some sort of optimizer merging here? 317 table.optimizer = optimizer 318 if ((table.optimizer is not None or self._using_tpu) and 319 not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)): # pylint: disable=protected-access 320 raise ValueError("{} is an unsupported optimizer class. Please pass an " 321 "instance of one of the optimizer classes under " 322 "tf.tpu.experimental.embedding.".format( 323 type(table.optimizer))) 324 if table.name is None: 325 table.name = "table_{}".format(i) 326 if table.name in table_names: 327 raise ValueError("Multiple tables with name {} found.".format( 328 table.name)) 329 table_names.append(table.name) 330 331 if self._using_tpu: 332 # Extract a list of callable learning rates also in fixed order. Each 333 # table in the confix proto will get a index into this list and we will 334 # pass this list in the same order after evaluation to the 335 # send_tpu_embedding_gradients op. 336 self._dynamic_learning_rates = list({ 337 table.optimizer.learning_rate for table in self._table_config if 338 callable(table.optimizer.learning_rate)}) 339 340 # We need to list of host devices for the load/retrieve operations. 341 self._hosts = get_list_of_hosts(self._strategy) 342 343 self._built = False 344 345 def build(self, per_replica_batch_size: Optional[int] = None): 346 """Create the underlying variables and initializes the TPU for embeddings. 347 348 This method creates the underlying variables (including slot variables). If 349 created under a TPUStrategy, this will also initialize the TPU for 350 embeddings. 351 352 This function will automatically get called by enqueue, which will try to 353 determine your batch size automatically. If this fails, you must manually 354 call this method before you call enqueue. 355 356 Args: 357 per_replica_batch_size: The per replica batch size that you intend to use. 358 Note that is fixed and the same batch size must be used for both 359 training and evaluation. If you want to calculate this from the global 360 batch size, you can use `num_replicas_in_sync` property of your strategy 361 object. May be set to None if not created under a TPUStrategy. 362 363 Raises: 364 ValueError: If per_replica_batch_size is None and object was created in a 365 TPUStrategy scope. 366 """ 367 if self._built: 368 return 369 370 if self._using_tpu: 371 if per_replica_batch_size is None: 372 raise ValueError("You must specify a per_replica_batch_size when " 373 "calling build if object is created under a " 374 "TPUStrategy.") 375 376 self._batch_size = per_replica_batch_size 377 378 self._config_proto = self._create_config_proto() 379 380 logging.info("Initializing TPU Embedding engine.") 381 tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto) 382 383 @def_function.function 384 def load_config(): 385 tpu.initialize_system_for_tpu_embedding(self._config_proto) 386 387 load_config() 388 logging.info("Done initializing TPU Embedding engine.") 389 390 # Create and load variables and slot variables into the TPU. 391 # Note that this is a dict of dicts. Keys to the first dict are table names. 392 # We would prefer to use TableConfigs, but then these variables won't be 393 # properly tracked by the tracking API. 394 self._variables = self._create_variables_and_slots() 395 396 self._built = True 397 398 # This is internally conditioned self._built and self._using_tpu 399 self._load_variables() 400 401 def _maybe_build(self, batch_size: Optional[int]): 402 if not self._built: 403 # This can be called while tracing a function, so we wrap the 404 # initialization code with init_scope so it runs eagerly, this means that 405 # it will not be included the function graph generated by tracing so that 406 # we can be sure that we only initialize the TPU for embeddings exactly 407 # once. 408 with ops.init_scope(): 409 self.build(batch_size) 410 411 @property 412 def embedding_tables( 413 self 414 ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]: 415 """Returns a dict of embedding tables, keyed by `TableConfig`. 416 417 This property only works when the `TPUEmbedding` object is created under a 418 non-TPU strategy. This is intended to be used to for CPU based lookup when 419 creating a serving checkpoint. 420 421 Returns: 422 A dict of embedding tables, keyed by `TableConfig`. 423 424 Raises: 425 RuntimeError: If object was created under a `TPUStrategy`. 426 """ 427 # We don't support returning tables on TPU due to their sharded nature and 428 # the fact that when using a TPUStrategy: 429 # 1. Variables are stale and are only updated when a checkpoint is made. 430 # 2. Updating the variables won't affect the actual tables on the TPU. 431 if self._using_tpu: 432 if save_context.in_save_context(): 433 return {table: self._variables[table.name]["parameters"].variables[0] 434 for table in self._table_config} 435 raise RuntimeError("Unable to retrieve embedding tables when using a TPU " 436 "strategy. If you need access, save your model, " 437 "create this object under a CPU strategy and restore.") 438 439 self._maybe_build(None) 440 441 # Only return the tables and not the slot variables. On CPU this are honest 442 # tf.Variables. 443 return {table: self._variables[table.name]["parameters"] 444 for table in self._table_config} 445 446 def _create_config_proto( 447 self 448 ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration: 449 """Creates the TPUEmbeddingConfiguration proto. 450 451 This proto is used to initialize the TPU embedding engine. 452 453 Returns: 454 A TPUEmbeddingConfiguration proto. 455 """ 456 457 config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() 458 459 # There are several things that need to be computed here: 460 # 1. Each table has a num_features, which corresponds to the number of 461 # output rows per example for this table. Sequence features count for 462 # their maximum sequence length. 463 # 2. Learning rate index: the index of the dynamic learning rate for this 464 # table (if it exists) in the list we created at initialization. 465 # We don't simply create one learning rate index per table as this has 466 # extremely bad performance characteristics. The more separate 467 # optimization configurations we have, the worse the performance will be. 468 num_features = {table: 0 for table in self._table_config} 469 for feature in nest.flatten(self._feature_config): 470 num_features[feature.table] += (1 if feature.max_sequence_length == 0 471 else feature.max_sequence_length) 472 473 # Map each callable dynamic learning rate to its in index in the list. 474 learning_rate_index = {r: i for i, r in enumerate( 475 self._dynamic_learning_rates)} 476 477 for table in self._table_config: 478 table_descriptor = config_proto.table_descriptor.add() 479 table_descriptor.name = table.name 480 481 # For small tables, we pad to the number of hosts so that at least one 482 # id will be assigned to each host. 483 table_descriptor.vocabulary_size = max(table.vocabulary_size, 484 self._strategy.extended.num_hosts) 485 table_descriptor.dimension = table.dim 486 487 table_descriptor.num_features = num_features[table] 488 489 parameters = table_descriptor.optimization_parameters 490 491 # We handle the learning rate separately here and don't allow the 492 # optimization class to handle this, as it doesn't know about dynamic 493 # rates. 494 if callable(table.optimizer.learning_rate): 495 parameters.learning_rate.dynamic.tag = ( 496 learning_rate_index[table.optimizer.learning_rate]) 497 else: 498 parameters.learning_rate.constant = table.optimizer.learning_rate 499 500 # Use optimizer to handle the rest of the parameters. 501 table.optimizer._set_optimization_parameters(parameters) # pylint: disable=protected-access 502 503 # Always set mode to training, we override the mode during enqueue. 504 config_proto.mode = ( 505 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING) 506 507 config_proto.batch_size_per_tensor_core = self._batch_size 508 config_proto.num_hosts = self._strategy.extended.num_hosts 509 config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync 510 511 # TODO(bfontain): Allow users to pick MOD for the host sharding. 512 config_proto.sharding_strategy = ( 513 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT) 514 config_proto.pipeline_execution_with_tensor_core = ( 515 self._pipeline_execution_with_tensor_core) 516 517 return config_proto 518 519 def _compute_per_table_gradients( 520 self, 521 gradients 522 ) -> Dict[Text, List[core.Tensor]]: 523 """Computes a dict of lists of gradients, keyed by table name. 524 525 Args: 526 gradients: A nested structure of Tensors (and Nones) with the same 527 structure as the feature config. 528 529 Returns: 530 A dict of lists of tensors, keyed by the table names, containing the 531 gradients in the correct order with None gradients replaced by zeros. 532 """ 533 534 nest.assert_same_structure(self._feature_config, gradients) 535 536 per_table_gradients = {table: [] for table in self._table_config} 537 for (path, gradient), feature in zip( 538 nest.flatten_with_joined_string_paths(gradients), 539 nest.flatten(self._feature_config)): 540 if gradient is not None and not isinstance(gradient, ops.Tensor): 541 raise ValueError( 542 "Found {} at path {} in gradients. Expected Tensor.".format( 543 type(gradient), path)) 544 545 # Expected tensor shape differs for sequence and non-sequence features. 546 if feature.max_sequence_length > 0: 547 shape = [self._batch_size, feature.max_sequence_length, 548 feature.table.dim] 549 else: 550 shape = [self._batch_size, feature.table.dim] 551 552 if gradient is not None: 553 if gradient.shape != shape: 554 raise ValueError("Found gradient of shape {} at path {}. Expected " 555 "shape {}.".format(gradient.shape, path, shape)) 556 557 # We expand dims on non-sequence features so that all features are 558 # of rank 3 and we can concat on axis=1. 559 if len(shape) == 2: 560 gradient = array_ops.expand_dims(gradient, axis=1) 561 else: 562 # No gradient for this feature, since we must give a gradient for all 563 # features, pass in a zero tensor here. Note that this is not correct 564 # for all optimizers. 565 logging.warn("No gradient passed for feature %s, sending zero " 566 "gradient. This may not be correct behavior for certain " 567 "optimizers like Adam.", path) 568 # Create a shape to mimic the expand_dims above for non-sequence 569 # features. 570 if len(shape) == 2: 571 shape = [shape[0], 1, shape[1]] 572 gradient = array_ops.zeros(shape, dtype=dtypes.float32) 573 per_table_gradients[feature.table].append(gradient) 574 575 return per_table_gradients 576 577 def apply_gradients(self, gradients, name: Text = None): 578 """Applies the gradient update to the embedding tables. 579 580 If a gradient of `None` is passed in any position of the nested structure, 581 then an gradient update with a zero gradient is applied for that feature. 582 For optimizers like SGD or Adagrad, this is the same as applying no update 583 at all. For lazy Adam and other sparsely applied optimizers with decay, 584 ensure you understand the effect of applying a zero gradient. 585 586 ```python 587 strategy = tf.distribute.TPUStrategy(...) 588 with strategy.scope(): 589 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 590 591 distributed_dataset = ( 592 strategy.distribute_datasets_from_function( 593 dataset_fn=..., 594 options=tf.distribute.InputOptions( 595 experimental_prefetch_to_device=False)) 596 dataset_iterator = iter(distributed_dataset) 597 598 @tf.function 599 def training_step(): 600 def tpu_step(tpu_features): 601 with tf.GradientTape() as tape: 602 activations = embedding.dequeue() 603 tape.watch(activations) 604 605 loss = ... # some computation involving activations 606 607 embedding_gradients = tape.gradient(loss, activations) 608 embedding.apply_gradients(embedding_gradients) 609 610 embedding_features, tpu_features = next(dataset_iterator) 611 embedding.enqueue(embedding_features, training=True) 612 strategy.run(tpu_step, args=(embedding_features, )) 613 614 training_step() 615 ``` 616 617 Args: 618 gradients: A nested structure of gradients, with structure matching the 619 `feature_config` passed to this object. 620 name: A name for the underlying op. 621 622 Raises: 623 RuntimeError: If called when object wasn't created under a `TPUStrategy` 624 or if not built (either by manually calling build or calling enqueue). 625 ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a 626 `tf.Tensor` of the incorrect shape is passed in. Also if 627 the size of any sequence in `gradients` does not match corresponding 628 sequence in `feature_config`. 629 TypeError: If the type of any sequence in `gradients` does not match 630 corresponding sequence in `feature_config`. 631 """ 632 if not self._using_tpu: 633 raise RuntimeError("apply_gradients is not valid when TPUEmbedding " 634 "object is not created under a TPUStrategy.") 635 636 if not self._built: 637 raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding " 638 "object. Please either call enqueue first or manually " 639 "call the build method.") 640 641 # send_tpu_embedding_gradients requires per table gradient, if we only have 642 # one feature per table this isn't an issue. When multiple features share 643 # the same table, the order of the features in per table tensor returned by 644 # recv_tpu_embedding_activations matches the order in which they were passed 645 # to enqueue. 646 # In all three places, we use the fixed order given by nest.flatten to have 647 # a consistent feature order. 648 649 # First construct a dict of tensors one for each table. 650 per_table_gradients = self._compute_per_table_gradients(gradients) 651 652 # Now that we have a list of gradients we can compute a list of gradients 653 # in the fixed order of self._table_config which interleave the gradients of 654 # the individual features. We concat on axis 1 and then reshape into a 2d 655 # tensor. The send gradients op expects a tensor of shape 656 # [num_features*batch_size, dim] for each table. 657 interleaved_gradients = [] 658 for table in self._table_config: 659 interleaved_gradients.append(array_ops.reshape( 660 array_ops.concat(per_table_gradients[table], axis=1), 661 [-1, table.dim])) 662 op = tpu_ops.send_tpu_embedding_gradients( 663 inputs=interleaved_gradients, 664 learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32) 665 for fn in self._dynamic_learning_rates], 666 config=self._config_proto.SerializeToString()) 667 668 # Apply the name tag to the op. 669 if name is not None: 670 _add_key_attr(op, name) 671 672 def dequeue(self, name: Text = None): 673 """Get the embedding results. 674 675 Returns a nested structure of `tf.Tensor` objects, matching the structure of 676 the `feature_config` argument to the `TPUEmbedding` class. The output shape 677 of the tensors is `(batch_size, dim)`, where `batch_size` is the per core 678 batch size, `dim` is the dimension of the corresponding `TableConfig`. If 679 the feature's corresponding `FeatureConfig` has `max_sequence_length` 680 greater than 0, the output will be a sequence of shape 681 `(batch_size, max_sequence_length, dim)` instead. 682 683 ```python 684 strategy = tf.distribute.TPUStrategy(...) 685 with strategy.scope(): 686 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 687 688 distributed_dataset = ( 689 strategy.distribute_datasets_from_function( 690 dataset_fn=..., 691 options=tf.distribute.InputOptions( 692 experimental_prefetch_to_device=False)) 693 dataset_iterator = iter(distributed_dataset) 694 695 @tf.function 696 def training_step(): 697 def tpu_step(tpu_features): 698 with tf.GradientTape() as tape: 699 activations = embedding.dequeue() 700 tape.watch(activations) 701 702 loss = ... # some computation involving activations 703 704 embedding_gradients = tape.gradient(loss, activations) 705 embedding.apply_gradients(embedding_gradients) 706 707 embedding_features, tpu_features = next(dataset_iterator) 708 embedding.enqueue(embedding_features, training=True) 709 strategy.run(tpu_step, args=(embedding_features, )) 710 711 training_step() 712 ``` 713 714 Args: 715 name: A name for the underlying op. 716 717 Returns: 718 A nested structure of tensors, with the same structure as `feature_config` 719 passed to this instance of the `TPUEmbedding` object. 720 721 Raises: 722 RuntimeError: If called when object wasn't created under a `TPUStrategy` 723 or if not built (either by manually calling build or calling enqueue). 724 """ 725 if not self._using_tpu: 726 raise RuntimeError("dequeue is not valid when TPUEmbedding object is not " 727 "created under a TPUStrategy.") 728 729 if not self._built: 730 raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. " 731 "Please either call enqueue first or manually call " 732 "the build method.") 733 734 # The activations returned by this op are per table. So we must separate 735 # them out into per feature activations. The activations are interleaved: 736 # for each table, we expect a [num_features*batch_size, dim] tensor. 737 # E.g. we expect the slice [:num_features, :] to contain the lookups for the 738 # first example of all features using this table. 739 activations = tpu_ops.recv_tpu_embedding_activations( 740 num_outputs=len(self._table_config), 741 config=self._config_proto.SerializeToString()) 742 743 # Apply the name tag to the op. 744 if name is not None: 745 _add_key_attr(activations[0].op, name) 746 747 # Compute the number of features for this table. 748 num_features = {table: 0 for table in self._table_config} 749 for feature in nest.flatten(self._feature_config): 750 num_features[feature.table] += (1 if feature.max_sequence_length == 0 751 else feature.max_sequence_length) 752 753 # Activations are reshaped so that they are indexed by batch size and then 754 # by the 'feature' index within the batch. The final dimension should equal 755 # the dimension of the table. 756 table_to_activation = { 757 table: array_ops.reshape(activation, 758 [self._batch_size, num_features[table], -1]) 759 for table, activation in zip(self._table_config, activations)} 760 761 # We process the features in the same order we enqueued them. 762 # For each feature we take the next slice of the activations, so need to 763 # track the activations and the current position we are in. 764 table_to_position = {table: 0 for table in self._table_config} 765 766 per_feature_activations = [] 767 for feature in nest.flatten(self._feature_config): 768 activation = table_to_activation[feature.table] 769 feature_index = table_to_position[feature.table] 770 # We treat non-sequence and sequence features differently here as sequence 771 # features have rank 3 while non-sequence features have rank 2. 772 if feature.max_sequence_length == 0: 773 per_feature_activations.append( 774 activation[:, feature_index, :]) 775 table_to_position[feature.table] += 1 776 else: 777 per_feature_activations.append( 778 activation[:, feature_index:( 779 feature_index+feature.max_sequence_length), :]) 780 table_to_position[feature.table] += feature.max_sequence_length 781 782 # Pack the list back into the same nested structure as the features. 783 return nest.pack_sequence_as(self._feature_config, per_feature_activations) 784 785 def _create_variables_and_slots( 786 self 787 ) -> Dict[Text, Dict[Text, tf_variables.Variable]]: 788 """Create variables for TPU embeddings. 789 790 Note under TPUStrategy this will ensure that all creations happen within a 791 variable creation scope of the sharded variable creator. 792 793 Returns: 794 A dict of dicts. The outer dict is keyed by the table names and the inner 795 dicts are keyed by 'parameters' and the slot variable names. 796 """ 797 798 def create_variables(table): 799 """Create all variables.""" 800 variable_shape = (table.vocabulary_size, table.dim) 801 802 def getter(name, shape, dtype, initializer, trainable): 803 del shape 804 # _add_variable_with_custom_getter clears the shape sometimes, so we 805 # take the global shape from outside the getter. 806 initial_value = functools.partial(initializer, variable_shape, 807 dtype=dtype) 808 return tf_variables.Variable( 809 name=name, 810 initial_value=initial_value, 811 shape=variable_shape, 812 dtype=dtype, 813 trainable=trainable) 814 815 def variable_creator(name, initializer, trainable=True): 816 # use add_variable_with_custom_getter here so that we take advantage of 817 # the checkpoint loading to allow restore before the variables get 818 # created which avoids double initialization. 819 return self._add_variable_with_custom_getter( 820 name=name, 821 initializer=initializer, 822 shape=variable_shape, 823 dtype=dtypes.float32, 824 getter=getter, 825 trainable=trainable) 826 827 parameters = variable_creator(table.name, table.initializer, 828 trainable=not self._using_tpu) 829 830 def slot_creator(name, initializer): 831 return variable_creator(table.name + "/" + name, 832 initializer, 833 False) 834 835 if table.optimizer is not None: 836 slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access 837 else: 838 slot_vars = {} 839 slot_vars["parameters"] = parameters 840 return slot_vars 841 842 # Store tables based on name rather than TableConfig as we can't track 843 # through dicts with non-string keys, i.e. we won't be able to save. 844 variables = {} 845 for table in self._table_config: 846 if not self._using_tpu: 847 variables[table.name] = create_variables(table) 848 else: 849 with variable_scope.variable_creator_scope( 850 make_sharded_variable_creator(self._hosts)): 851 variables[table.name] = create_variables(table) 852 853 return variables 854 855 def _load_variables(self): 856 # Only load the variables if we are: 857 # 1) Using TPU 858 # 2) Variables are created 859 # 3) Not in save context (except if running eagerly) 860 if self._using_tpu and self._built and not ( 861 not context.executing_eagerly() and save_context.in_save_context()): 862 _load_variables_impl(self._config_proto.SerializeToString(), 863 self._hosts, 864 self._variables, 865 self._table_config) 866 867 def _retrieve_variables(self): 868 # Only retrieve the variables if we are: 869 # 1) Using TPU 870 # 2) Variables are created 871 # 3) Not in save context (except if running eagerly) 872 if self._using_tpu and self._built and not ( 873 not context.executing_eagerly() and save_context.in_save_context()): 874 _retrieve_variables_impl(self._config_proto.SerializeToString(), 875 self._hosts, 876 self._variables, 877 self._table_config) 878 879 def _gather_saveables_for_checkpoint( 880 self 881 ) -> Dict[Text, Callable[[Text], "TPUEmbeddingSaveable"]]: 882 """Overrides default Trackable implementation to add load/retrieve hook.""" 883 # This saveable should be here in both TPU and CPU checkpoints, so when on 884 # CPU, we add the hook with no functions. 885 # TODO(bfontain): Update restore logic in saver so that these hooks are 886 # always executed. Once that is done, we can output an empty list when on 887 # CPU. 888 889 def factory(name=_HOOK_KEY): 890 return TPUEmbeddingSaveable(name, self._load_variables, 891 self._retrieve_variables) 892 return {_HOOK_KEY: factory} 893 894 # Some helper functions for the below enqueue function. 895 def _add_data_for_tensor(self, tensor, weight, indices, values, weights, 896 int_zeros, float_zeros, path): 897 if weight is not None: 898 raise ValueError( 899 "Weight specified for dense input {}, which is not allowed. " 900 "Weight will always be 1 in this case.".format(path)) 901 # For tensors, there are no indices and no weights. 902 indices.append(int_zeros) 903 values.append(math_ops.cast(tensor, dtypes.int32)) 904 weights.append(float_zeros) 905 906 def _add_data_for_sparse_tensor(self, tensor, weight, indices, values, 907 weights, int_zeros, float_zeros, path): 908 indices.append(math_ops.cast(tensor.indices, dtypes.int32)) 909 values.append(math_ops.cast(tensor.values, dtypes.int32)) 910 # If we have weights they must be a SparseTensor. 911 if weight is not None: 912 if not isinstance(weight, sparse_tensor.SparseTensor): 913 raise ValueError("Weight for {} is type {} which does not match " 914 "type input which is SparseTensor.".format( 915 path, type(weight))) 916 weights.append(math_ops.cast(weight.values, dtypes.float32)) 917 else: 918 weights.append(float_zeros) 919 920 def _add_data_for_ragged_tensor(self, tensor, weight, indices, values, 921 weights, int_zeros, float_zeros, path): 922 indices.append(math_ops.cast(tensor.row_splits, dtypes.int32)) 923 values.append(math_ops.cast(tensor.values, dtypes.int32)) 924 # If we have weights they must be a RaggedTensor. 925 if weight is not None: 926 if not isinstance(weight, ragged_tensor.RaggedTensor): 927 raise ValueError("Weight for {} is type {} which does not match " 928 "type input which is RaggedTensor.".format( 929 path, type(weight))) 930 weights.append(math_ops.cast(weight.values, dtypes.float32)) 931 else: 932 weights.append(float_zeros) 933 934 def _generate_enqueue_op( 935 self, 936 flat_inputs: List[internal_types.NativeObject], 937 flat_weights: List[Optional[internal_types.NativeObject]], 938 flat_features: List[tpu_embedding_v2_utils.FeatureConfig], 939 device_ordinal: int, 940 mode_override: Text 941 ) -> ops.Operation: 942 """Outputs a the enqueue op given the inputs and weights. 943 944 Args: 945 flat_inputs: A list of input tensors. 946 flat_weights: A list of input weights (or None) of the same length as 947 flat_inputs. 948 flat_features: A list of FeatureConfigs of the same length as flat_inputs. 949 device_ordinal: The device to create the enqueue op for. 950 mode_override: A tensor containing the string "train" or "inference". 951 952 Returns: 953 The enqueue op. 954 """ 955 956 # First we need to understand which op to use. This depends on if sparse 957 # or ragged tensors are in the flat_inputs. 958 sparse = False 959 ragged = False 960 for inp in flat_inputs: 961 if isinstance(inp, sparse_tensor.SparseTensor): 962 sparse = True 963 elif isinstance(inp, ragged_tensor.RaggedTensor): 964 ragged = True 965 if sparse and ragged: 966 raise ValueError( 967 "Found both SparseTensors and RaggedTensors in the input to the " 968 "enqueue operation. Please ensure that your data does not include " 969 "both SparseTensors and RaggedTensors. It is ok to have Tensors in " 970 "combination with one of the previous types.") 971 972 # Combiners are per table, list in the same order as the table order. 973 combiners = [table.combiner for table in self._table_config] 974 975 # Reverse mapping of self._table_config, so that we can lookup the table 976 # index. 977 table_to_id = {table: i for i, table in enumerate(self._table_config)} 978 979 # These parallel arrays will be the inputs to the enqueue op. 980 indices = [] # sample_indices for sparse, sample_splits for ragged. 981 values = [] 982 weights = [] 983 table_ids = [] 984 max_sequence_lengths = [] 985 986 # We have to supply a empty/zero tensor in a list position where we don't 987 # have data (e.g. indices for standard Tensor input, weight when no weight 988 # is specified). We create one op here per call, so that we reduce the 989 # graph size. 990 int_zeros = array_ops.zeros((0,), dtype=dtypes.int32) 991 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 992 993 # In the following loop we insert casts so that everything is either int32 994 # or float32. This is because op inputs which are lists of tensors must be 995 # of the same type within the list. Moreover the CPU implementations of 996 # these ops cast to these types anyway, so we don't lose any data by casting 997 # early. 998 for inp, weight, (path, feature) in zip( 999 flat_inputs, flat_weights, flat_features): 1000 table_ids.append(table_to_id[feature.table]) 1001 max_sequence_lengths.append(feature.max_sequence_length) 1002 if isinstance(inp, ops.Tensor): 1003 self._add_data_for_tensor(inp, weight, indices, values, weights, 1004 int_zeros, float_zeros, path) 1005 elif isinstance(inp, sparse_tensor.SparseTensor): 1006 self._add_data_for_sparse_tensor(inp, weight, indices, values, weights, 1007 int_zeros, float_zeros, path) 1008 elif isinstance(inp, ragged_tensor.RaggedTensor): 1009 self._add_data_for_ragged_tensor(inp, weight, indices, values, weights, 1010 int_zeros, float_zeros, path) 1011 else: 1012 raise ValueError("Input {} is of unknown type {}. Please only pass " 1013 "Tensor, SparseTensor or RaggedTensor as input to " 1014 "enqueue.".format(path, type(inp))) 1015 1016 if ragged: 1017 return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( 1018 sample_splits=indices, 1019 embedding_indices=values, 1020 aggregation_weights=weights, 1021 mode_override=mode_override, 1022 device_ordinal=device_ordinal, 1023 combiners=combiners, 1024 table_ids=table_ids, 1025 max_sequence_lengths=max_sequence_lengths) 1026 return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( 1027 sample_indices=indices, 1028 embedding_indices=values, 1029 aggregation_weights=weights, 1030 mode_override=mode_override, 1031 device_ordinal=device_ordinal, 1032 combiners=combiners, 1033 table_ids=table_ids, 1034 max_sequence_lengths=max_sequence_lengths) 1035 1036 def _raise_error_for_incorrect_control_flow_context(self): 1037 """Raises an error if we are not in the TPUReplicateContext.""" 1038 # Do not allow any XLA control flow (i.e. control flow in between a 1039 # TPUStrategy's run call and the call to this function), as we can't 1040 # extract the enqueue from the head when in XLA control flow. 1041 graph = ops.get_default_graph() 1042 in_tpu_ctx = False 1043 while graph is not None: 1044 ctx = graph._get_control_flow_context() # pylint: disable=protected-access 1045 while ctx is not None: 1046 if isinstance(ctx, tpu.TPUReplicateContext): 1047 in_tpu_ctx = True 1048 break 1049 ctx = ctx.outer_context 1050 if in_tpu_ctx: 1051 break 1052 graph = getattr(graph, "outer_graph", None) 1053 if graph != ops.get_default_graph() and in_tpu_ctx: 1054 raise RuntimeError( 1055 "Current graph {} does not match graph which contains " 1056 "TPUReplicateContext {}. This is most likely due to the fact that " 1057 "enqueueing embedding data is called inside control flow or a " 1058 "nested function inside `strategy.run`. This is not supported " 1059 "because outside compilation fails to extract the enqueue ops as " 1060 "head of computation.".format(ops.get_default_graph(), graph)) 1061 return in_tpu_ctx 1062 1063 def _raise_error_for_non_direct_inputs(self, features): 1064 """Checks all tensors in features to see if they are a direct input.""" 1065 1066 # expand_composites here is important: as composite tensors pass through 1067 # tpu.replicate, they get 'flattened' into their component tensors and then 1068 # repacked before being passed to the tpu function. In means that it is the 1069 # component tensors which are produced by an op with the 1070 # "_tpu_input_identity" attribute. 1071 for path, input_tensor in nest.flatten_with_joined_string_paths( 1072 features, expand_composites=True): 1073 if input_tensor.op.type == "Placeholder": 1074 continue 1075 try: 1076 is_input = input_tensor.op.get_attr("_tpu_input_identity") 1077 except ValueError: 1078 is_input = False 1079 if not is_input: 1080 raise ValueError( 1081 "Received input tensor {} which is the output of op {} (type {}) " 1082 "which does not have the `_tpu_input_identity` attr. Please " 1083 "ensure that the inputs to this layer are taken directly from " 1084 "the arguments of the function called by " 1085 "strategy.run. Two possible causes are: dynamic batch size " 1086 "support or you are using a keras layer and are not passing " 1087 "tensors which match the dtype of the `tf.keras.Input`s." 1088 "If you are triggering dynamic batch size support, you can " 1089 "disable it by passing tf.distribute.RunOptions(" 1090 "experimental_enable_dynamic_batch_size=False) to the options " 1091 "argument of strategy.run().".format(path, 1092 input_tensor.op.name, 1093 input_tensor.op.type)) 1094 1095 def _raise_error_for_inputs_not_on_cpu(self, features): 1096 """Checks all tensors in features to see are placed on the CPU.""" 1097 1098 def check_device(path, device_string): 1099 spec = tf_device.DeviceSpec.from_string(device_string) 1100 if spec.device_type == "TPU": 1101 raise ValueError( 1102 "Received input tensor {} which is on a TPU input device {}. Input " 1103 "tensors for TPU embeddings must be placed on the CPU. Please " 1104 "ensure that your dataset is prefetching tensors to the host by " 1105 "setting the 'experimental_prefetch_to_device' option of the " 1106 "dataset distribution function. See the documentation of the " 1107 "enqueue method for an example.".format( 1108 path, device_string)) 1109 1110 # expand_composites here is important, we need to check the device of each 1111 # underlying tensor. 1112 for path, input_tensor in nest.flatten_with_joined_string_paths( 1113 features, expand_composites=True): 1114 if (input_tensor.op.type == "Identity" and 1115 input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"): 1116 for tensor in input_tensor.op.inputs[0].op.inputs: 1117 check_device(path, tensor.device) 1118 else: 1119 check_device(path, input_tensor.device) 1120 1121 def enqueue( 1122 self, 1123 features, 1124 weights=None, 1125 training: bool = True, 1126 name: Optional[Text] = None, 1127 device: Optional[Text] = None): 1128 """Enqueues id tensors for embedding lookup. 1129 1130 This function enqueues a structure of features to be looked up in the 1131 embedding tables. We expect that the batch size of each of the tensors in 1132 features matches the per core batch size. This will automatically happen if 1133 your input dataset is batched to the global batch size and you use 1134 `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset` 1135 or if you use `distribute_datasets_from_function` and batch 1136 to the per core batch size computed by the context passed to your input 1137 function. 1138 1139 ```python 1140 strategy = tf.distribute.TPUStrategy(...) 1141 with strategy.scope(): 1142 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 1143 1144 distributed_dataset = ( 1145 strategy.distribute_datasets_from_function( 1146 dataset_fn=..., 1147 options=tf.distribute.InputOptions( 1148 experimental_prefetch_to_device=False)) 1149 dataset_iterator = iter(distributed_dataset) 1150 1151 @tf.function 1152 def training_step(): 1153 def tpu_step(tpu_features): 1154 with tf.GradientTape() as tape: 1155 activations = embedding.dequeue() 1156 tape.watch(activations) 1157 1158 loss = ... # some computation involving activations 1159 1160 embedding_gradients = tape.gradient(loss, activations) 1161 embedding.apply_gradients(embedding_gradients) 1162 1163 embedding_features, tpu_features = next(dataset_iterator) 1164 embedding.enqueue(embedding_features, training=True) 1165 strategy.run(tpu_step, args=(embedding_features,)) 1166 1167 training_step() 1168 ``` 1169 1170 NOTE: You should specify `training=True` when using 1171 `embedding.apply_gradients` as above and `training=False` when not using 1172 `embedding.apply_gradients` (e.g. for frozen embeddings or when doing 1173 evaluation). 1174 1175 For finer grained control, in the above example the line 1176 1177 ``` 1178 embedding.enqueue(embedding_features, training=True) 1179 ``` 1180 1181 may be replaced with 1182 1183 ``` 1184 per_core_embedding_features = self.strategy.experimental_local_results( 1185 embedding_features) 1186 1187 def per_core_enqueue(ctx): 1188 core_id = ctx.replica_id_in_sync_group 1189 device = strategy.extended.worker_devices[core_id] 1190 embedding.enqueue(per_core_embedding_features[core_id], 1191 device=device) 1192 1193 strategy.experimental_distribute_values_from_function( 1194 per_core_queue_inputs) 1195 ``` 1196 1197 Args: 1198 features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or 1199 `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs 1200 will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` 1201 or `tf.RaggedTensor` is supported per call. 1202 weights: If not `None`, a nested structure of `tf.Tensor`s, 1203 `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except 1204 that the tensors should be of float type (and they will be downcast to 1205 `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the 1206 same for the parallel entries from `features` and similarly for 1207 `tf.RaggedTensor`s we assume the row_splits are the same. 1208 training: Defaults to `True`. If `False`, enqueue the batch as inference 1209 batch (forward pass only). Do not call `apply_gradients` when this is 1210 `False` as this may lead to a deadlock. 1211 name: A name for the underlying op. 1212 device: The device name (e.g. '/task:0/device:TPU:2') where this batch 1213 should be enqueued. This should be set if and only if features is not a 1214 `tf.distribute.DistributedValues` and enqueue is not being called 1215 inside a TPU context (e.g. inside `TPUStrategy.run`). 1216 1217 Raises: 1218 ValueError: When called inside a strategy.run call and input is not 1219 directly taken from the args of the `strategy.run` call. Also if 1220 the size of any sequence in `features` does not match corresponding 1221 sequence in `feature_config`. Similarly for `weights`, if not `None`. 1222 If batch size of features is unequal or different from a previous call. 1223 RuntimeError: When called inside a strategy.run call and inside XLA 1224 control flow. If batch_size is not able to be determined and build was 1225 not called. 1226 TypeError: If the type of any sequence in `features` does not match 1227 corresponding sequence in `feature_config`. Similarly for `weights`, if 1228 not `None`. 1229 """ 1230 if not self._using_tpu: 1231 raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " 1232 "created under a TPUStrategy.") 1233 1234 in_tpu_context = self._raise_error_for_incorrect_control_flow_context() 1235 1236 # Should we also get batch_size from weights if they exist? 1237 # Since features is assumed to be batched at the per replica batch size 1238 # the returned batch size here is per replica an not global. 1239 batch_size = self._get_batch_size(features, in_tpu_context) 1240 if batch_size is None and not self._built: 1241 raise RuntimeError("Unable to determine batch size from input features." 1242 "Please call build() with global batch size to " 1243 "initialize the TPU for embeddings.") 1244 if batch_size is not None: 1245 self._maybe_build(batch_size) 1246 if self._batch_size != batch_size: 1247 raise ValueError("Multiple calls to enqueue with different batch sizes " 1248 "{} and {}.".format(self._batch_size, 1249 batch_size)) 1250 1251 nest.assert_same_structure(self._feature_config, features) 1252 1253 flat_inputs = nest.flatten(features) 1254 flat_weights = [None] * len(flat_inputs) 1255 if weights is not None: 1256 nest.assert_same_structure(self._feature_config, weights) 1257 flat_weights = nest.flatten(weights) 1258 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 1259 1260 self._raise_error_for_inputs_not_on_cpu(features) 1261 # If we are in a tpu_context, automatically apply outside compilation. 1262 if in_tpu_context: 1263 self._raise_error_for_non_direct_inputs(features) 1264 1265 def generate_enqueue_ops(): 1266 """Generate enqueue ops for outside compilation.""" 1267 # Note that we put array_ops.where_v2 rather than a python if so that 1268 # the op is explicitly create and the constant ops are both in the graph 1269 # even though we don't expect training to be a tensor (and thus generate 1270 # control flow automatically). This need to make it easier to re-write 1271 # the graph later if we need to fix which mode needs to be used. 1272 mode_override = array_ops.where_v2(training, 1273 constant_op.constant("train"), 1274 constant_op.constant("inference")) 1275 1276 # Device ordinal is -1 here, a later rewrite will fix this once the op 1277 # is expanded by outside compilation. 1278 enqueue_op = self._generate_enqueue_op( 1279 flat_inputs, flat_weights, flat_features, device_ordinal=-1, 1280 mode_override=mode_override) 1281 1282 # Apply the name tag to the op. 1283 if name is not None: 1284 _add_key_attr(enqueue_op, name) 1285 1286 # Ensure that this op has outbound control flow, otherwise it won't be 1287 # executed. 1288 ops.get_default_graph().control_outputs.append(enqueue_op) 1289 1290 tpu.outside_compilation(generate_enqueue_ops) 1291 1292 elif device is None: 1293 mode_override = "train" if training else "inference" 1294 # We generate enqueue ops per device, so we need to gather the all 1295 # features for a single device in to a dict. 1296 # We rely here on the fact that the devices in the PerReplica value occur 1297 # in the same (standard) order as self._strategy.extended.worker_devices. 1298 enqueue_ops = [] 1299 for replica_id in range(self._strategy.num_replicas_in_sync): 1300 replica_inputs = distribute_utils.select_replica(replica_id, 1301 flat_inputs) 1302 replica_weights = distribute_utils.select_replica(replica_id, 1303 flat_weights) 1304 tpu_device = self._strategy.extended.worker_devices[replica_id] 1305 # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0 1306 # the device ordinal is the last number 1307 device_ordinal = ( 1308 tf_device.DeviceSpec.from_string(tpu_device).device_index) 1309 with ops.device(device_util.get_host_for_device(tpu_device)): 1310 enqueue_op = self._generate_enqueue_op( 1311 replica_inputs, replica_weights, flat_features, 1312 device_ordinal=device_ordinal, mode_override=mode_override) 1313 1314 # Apply the name tag to the op. 1315 if name is not None: 1316 _add_key_attr(enqueue_op, name) 1317 enqueue_ops.append(enqueue_op) 1318 ops.get_default_graph().control_outputs.extend(enqueue_ops) 1319 else: 1320 mode_override = "train" if training else "inference" 1321 device_spec = tf_device.DeviceSpec.from_string(device) 1322 if device_spec.device_type != "TPU": 1323 raise ValueError( 1324 "Non-TPU device {} passed to enqueue.".format(device)) 1325 with ops.device(device_util.get_host_for_device(device)): 1326 enqueue_op = self._generate_enqueue_op( 1327 flat_inputs, flat_weights, flat_features, 1328 device_ordinal=device_spec.device_index, 1329 mode_override=mode_override) 1330 1331 # Apply the name tag to the op. 1332 if name is not None: 1333 _add_key_attr(enqueue_op, name) 1334 ops.get_default_graph().control_outputs.append(enqueue_op) 1335 1336 def _get_batch_size(self, tensors, in_tpu_context: bool): 1337 """Gets the batch size from a nested structure of features.""" 1338 batch_size = None 1339 for path, maybe_tensor in nest.flatten_with_joined_string_paths(tensors): 1340 tensor_list = [] 1341 if not in_tpu_context: 1342 # if we are not in a context, then this is PerReplica and we need to 1343 # check each replica's batch size. 1344 for replica_id in range(self._strategy.num_replicas_in_sync): 1345 tensor_list.append(distribute_utils.select_replica(replica_id, 1346 maybe_tensor)) 1347 else: 1348 tensor_list = [maybe_tensor] 1349 1350 for tensor in tensor_list: 1351 if tensor.shape.rank < 1: 1352 raise ValueError( 1353 "Input {} has rank 0, rank must be at least 1.".format(path)) 1354 shape = tensor.shape.as_list() 1355 if shape[0] is not None: 1356 if batch_size is None: 1357 batch_size = shape[0] 1358 elif batch_size != shape[0]: 1359 raise ValueError("Found multiple batch sizes {} and {}. All inputs " 1360 "must have the same batch dimensions size.".format( 1361 batch_size, shape[0])) 1362 return batch_size 1363 1364 1365@def_function.function 1366def _load_variables_impl( 1367 config: Text, 1368 hosts: List[Tuple[int, Text]], 1369 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1370 table_config: tpu_embedding_v2_utils.TableConfig): 1371 """Load embedding tables to onto TPU for each table and host. 1372 1373 Args: 1374 config: A serialized TPUEmbeddingConfiguration proto. 1375 hosts: A list of CPU devices, on per host. 1376 variables: A dictionary of dictionaries of TPUShardedVariables. First key is 1377 the table name, second key is 'parameters' or the optimizer slot name. 1378 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1379 """ 1380 def select_fn(host_id): 1381 1382 def select_or_zeros(x): 1383 if host_id >= len(x.variables): 1384 # In the edge case where we have more hosts than variables, due to using 1385 # a small number of rows, we load zeros for the later hosts. We copy 1386 # the shape of the first host's variables, which we assume is defined 1387 # because TableConfig guarantees at least one row. 1388 return array_ops.zeros_like(x.variables[0]) 1389 return x.variables[host_id] 1390 1391 return select_or_zeros 1392 1393 for host_id, host in enumerate(hosts): 1394 with ops.device(host): 1395 host_variables = nest.map_structure(select_fn(host_id), variables) 1396 for table in table_config: 1397 table.optimizer._load()( # pylint: disable=protected-access 1398 table_name=table.name, 1399 num_shards=len(hosts), 1400 shard_id=host_id, 1401 config=config, 1402 **host_variables[table.name]) 1403 # Ensure that only the first table/first host gets a config so that we 1404 # don't bloat graph by attaching this large string to each op. 1405 # We have num tables * num hosts of these so for models with a large 1406 # number of tables training on a large slice, this can be an issue. 1407 config = None 1408 1409 1410@def_function.function 1411def _retrieve_variables_impl( 1412 config: Text, 1413 hosts: List[Tuple[int, Text]], 1414 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1415 table_config: tpu_embedding_v2_utils.TableConfig): 1416 """Retrieve embedding tables from TPU to host memory. 1417 1418 Args: 1419 config: A serialized TPUEmbeddingConfiguration proto. 1420 hosts: A list of all the host CPU devices. 1421 variables: A dictionary of dictionaries of TPUShardedVariables. First key is 1422 the table name, second key is 'parameters' or the optimizer slot name. 1423 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1424 """ 1425 for host_id, host in enumerate(hosts): 1426 with ops.device(host): 1427 for table in table_config: 1428 retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access 1429 table_name=table.name, 1430 num_shards=len(hosts), 1431 shard_id=host_id, 1432 config=config) 1433 # When there are no slot variables (e.g with SGD) this returns a 1434 # single tensor rather than a tuple. In this case we put the tensor in 1435 # a list to make the following code easier to write. 1436 if not isinstance(retrieved, tuple): 1437 retrieved = (retrieved,) 1438 1439 for i, slot in enumerate(["parameters"] + 1440 table.optimizer._slot_names()): # pylint: disable=protected-access 1441 # We must assign the CPU variables the values of tensors that were 1442 # returned from the TPU. 1443 sharded_var = variables[table.name][slot] 1444 if host_id < len(sharded_var.variables): 1445 # In the edge case where we have more hosts than variables, due to 1446 # using a small number of rows, we skip the later hosts. 1447 sharded_var.variables[host_id].assign(retrieved[i]) 1448 # Ensure that only the first table/first host gets a config so that we 1449 # don't bloat graph by attaching this large string to each op. 1450 # We have num tables * num hosts of these so for models with a large 1451 # number of tables training on a large slice, this can be an issue. 1452 config = None 1453 1454 1455class TPUEmbeddingSaveable(saveable_hook.SaveableHook): 1456 """Save/Restore hook to Retrieve/Load TPUEmbedding variables.""" 1457 1458 def __init__( 1459 self, 1460 name: Text, 1461 load: Callable[[], Any], 1462 retrieve: Callable[[], Any]): 1463 self._load = load 1464 self._retrieve = retrieve 1465 super(TPUEmbeddingSaveable, self).__init__(name=name) 1466 1467 def before_save(self): 1468 if self._retrieve is not None: 1469 self._retrieve() 1470 1471 def after_restore(self): 1472 if self._load is not None: 1473 self._load() 1474 1475 1476def _ragged_embedding_lookup_with_reduce( 1477 table: tf_variables.Variable, 1478 ragged: ragged_tensor.RaggedTensor, 1479 weights: ragged_tensor.RaggedTensor, 1480 combiner: Text) -> core.Tensor: 1481 """Compute a ragged lookup followed by a reduce on axis 1. 1482 1483 Args: 1484 table: The embedding table. 1485 ragged: A RaggedTensor of ids to look up. 1486 weights: A RaggedTensor of weights (or None). 1487 combiner: One of "mean", "sum", "sqrtn". 1488 1489 Returns: 1490 A Tensor. 1491 """ 1492 if weights is None: 1493 weights = array_ops.ones_like(ragged, dtype=table.dtype) 1494 weights = array_ops.expand_dims(weights, axis=2) 1495 ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged) 1496 ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1) 1497 if combiner == "mean": 1498 ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1) 1499 elif combiner == "sqrtn": 1500 ragged_result = ragged_result, math_ops.sqrt(math_ops.reduce_sum( 1501 weights*weights, axis=1)) 1502 return ragged_result 1503 1504 1505@tf_export("tpu.experimental.embedding.serving_embedding_lookup") 1506def cpu_embedding_lookup(inputs, weights, tables, feature_config): 1507 """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs. 1508 1509 This function is a utility which allows using the 1510 `tf.tpu.experimental.embedding` config objects with standard lookup functions. 1511 This can be used when exporting a model which uses 1512 `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In particular 1513 `tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups on TPUs and 1514 should not be part of your serving graph. 1515 1516 Note that TPU specific options (such as `max_sequence_length`) in the 1517 configuration objects will be ignored. 1518 1519 In the following example we take a trained model (see the documentation for 1520 `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a 1521 saved model with a serving function that will perform the embedding lookup and 1522 pass the results to your model: 1523 1524 ```python 1525 model = model_fn(...) 1526 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 1527 feature_config=feature_config, 1528 batch_size=1024, 1529 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 1530 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 1531 checkpoint.restore(...) 1532 1533 @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...), 1534 'feature_two': tf.TensorSpec(...), 1535 'feature_three': tf.TensorSpec(...)}]) 1536 def serve_tensors(embedding_featurese): 1537 embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup( 1538 embedding_features, None, embedding.embedding_tables, 1539 feature_config) 1540 return model(embedded_features) 1541 1542 model.embedding_api = embedding 1543 tf.saved_model.save(model, 1544 export_dir=..., 1545 signatures={'serving_default': serve_tensors}) 1546 1547 ``` 1548 1549 NOTE: Its important to assign the embedding api object to a member of your 1550 model as `tf.saved_model.save` only supports saving variables one `Trackable` 1551 object. Since the model's weights are in `model` and the embedding table are 1552 managed by `embedding`, we assign `embedding` to and attribute of `model` so 1553 that tf.saved_model.save can find the embedding variables. 1554 1555 NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will 1556 work directly from training. 1557 1558 Args: 1559 inputs: a nested structure of Tensors, SparseTensors or RaggedTensors. 1560 weights: a nested structure of Tensors, SparseTensors or RaggedTensors or 1561 None for no weights. If not None, structure must match that of inputs, but 1562 entries are allowed to be None. 1563 tables: a dict of mapping TableConfig objects to Variables. 1564 feature_config: a nested structure of FeatureConfig objects with the same 1565 structure as inputs. 1566 1567 Returns: 1568 A nested structure of Tensors with the same structure as inputs. 1569 """ 1570 1571 nest.assert_same_structure(inputs, feature_config) 1572 1573 flat_inputs = nest.flatten(inputs) 1574 flat_weights = [None] * len(flat_inputs) 1575 if weights is not None: 1576 nest.assert_same_structure(inputs, weights) 1577 flat_weights = nest.flatten(weights) 1578 flat_features = nest.flatten_with_joined_string_paths(feature_config) 1579 1580 outputs = [] 1581 for inp, weight, (path, feature) in zip( 1582 flat_inputs, flat_weights, flat_features): 1583 table = tables[feature.table] 1584 1585 if weight is not None: 1586 if isinstance(inp, ops.Tensor): 1587 raise ValueError( 1588 "Weight specified for {}, but input is dense.".format(path)) 1589 elif type(weight) is not type(inp): 1590 raise ValueError( 1591 "Weight for {} is of type {} but it does not match type of the " 1592 "input which is {}.".format(path, type(weight), type(inp))) 1593 elif feature.max_sequence_length > 0: 1594 raise ValueError("Weight specified for {}, but this is a sequence " 1595 "feature.".format(path)) 1596 1597 if isinstance(inp, ops.Tensor): 1598 if feature.max_sequence_length > 0: 1599 raise ValueError("Feature {} is a sequence feature but a dense tensor " 1600 "was passed.".format(path)) 1601 outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) 1602 1603 elif isinstance(inp, sparse_tensor.SparseTensor): 1604 if feature.max_sequence_length > 0: 1605 batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64) 1606 sparse_shape = array_ops.concat( 1607 [batch_size, feature.max_sequence_length], axis=0) 1608 # TPU Embedding truncates sequences to max_sequence_length, and if we 1609 # don't truncate, scatter_nd will error out if the index was out of 1610 # bounds. 1611 truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0], 1612 size=sparse_shape) 1613 1614 dense_output_shape = array_ops.concat( 1615 [batch_size, feature.max_sequence_length, feature.table.dim], 1616 axis=0) 1617 outputs.append( 1618 array_ops.scatter_nd( 1619 inp.indices, array_ops.gather(table, truncated_inp.values), 1620 dense_output_shape)) 1621 else: 1622 outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2( 1623 table, inp, sparse_weights=weight, combiner=feature.table.combiner)) 1624 1625 elif isinstance(inp, ragged_tensor.RaggedTensor): 1626 if feature.max_sequence_length > 0: 1627 batch_size = inp.shape[0] 1628 dense_output_shape = [ 1629 batch_size, feature.max_sequence_length, feature.table.dim] 1630 ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp) 1631 # Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given 1632 # shape. 1633 outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape)) 1634 else: 1635 outputs.append(_ragged_embedding_lookup_with_reduce( 1636 table, inp, weight, feature.table.combiner)) 1637 1638 else: 1639 raise ValueError("Input {} is type {}. Tensor, SparseTensor or " 1640 "RaggedTensor expected.".format(path, type(inp))) 1641 return nest.pack_sequence_as(feature_config, outputs) 1642 1643 1644def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]: 1645 """Returns a sorted list of CPU devices for the remote jobs. 1646 1647 Args: 1648 strategy: A TPUStrategy object. 1649 1650 Returns: 1651 A sort list of device strings. 1652 """ 1653 list_of_hosts = [] 1654 # Assume this is sorted by task 1655 for tpu_device in strategy.extended.worker_devices: 1656 host = device_util.get_host_for_device(tpu_device) 1657 if host not in list_of_hosts: 1658 list_of_hosts.append(host) 1659 assert len(list_of_hosts) == strategy.extended.num_hosts 1660 return list_of_hosts 1661 1662 1663def extract_variable_info( 1664 kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]: 1665 """Extracts the variable creation attributes from the kwargs. 1666 1667 Args: 1668 kwargs: a dict of keyword arguments that were passed to a variable creator 1669 scope. 1670 1671 Returns: 1672 A tuple of variable name, shape, dtype, initialization function. 1673 """ 1674 if (isinstance(kwargs["initial_value"], functools.partial) and ( 1675 "shape" in kwargs["initial_value"].keywords or 1676 kwargs["initial_value"].args)): 1677 # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. 1678 if "shape" in kwargs["initial_value"].keywords: 1679 shape = kwargs["initial_value"].keywords["shape"] 1680 else: 1681 shape = kwargs["initial_value"].args[0] 1682 return (kwargs["name"], shape, 1683 kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), 1684 kwargs["initial_value"].func) 1685 elif "shape" not in kwargs or kwargs["shape"] is None or not callable( 1686 kwargs["initial_value"]): 1687 raise ValueError( 1688 "Unable to extract initializer function and shape from {}. Please " 1689 "either pass a function that expects a shape and dtype as the " 1690 "initial value for your variable or functools.partial object with " 1691 "the shape and dtype kwargs set. This is needed so that we can " 1692 "initialize the shards of the ShardedVariable locally.".format( 1693 kwargs["initial_value"])) 1694 else: 1695 return (kwargs["name"], kwargs["shape"], kwargs["dtype"], 1696 kwargs["initial_value"]) 1697 1698 1699def make_sharded_variable_creator( 1700 hosts: List[Text]) -> Callable[..., TPUShardedVariable]: 1701 """Makes a sharded variable creator given a list of hosts. 1702 1703 Args: 1704 hosts: a list of tensorflow devices on which to shard the tensors. 1705 1706 Returns: 1707 A variable creator function. 1708 """ 1709 1710 def sharded_variable_creator( 1711 next_creator: Callable[..., tf_variables.Variable], *args, **kwargs): 1712 """The sharded variable creator.""" 1713 kwargs["skip_mirrored_creator"] = True 1714 1715 num_hosts = len(hosts) 1716 name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs) 1717 initial_value = kwargs["initial_value"] 1718 rows = shape[0] 1719 cols = shape[1] 1720 partial_partition = rows % num_hosts 1721 full_rows_per_host = rows // num_hosts 1722 # We partition as if we were using MOD sharding: at least 1723 # `full_rows_per_host` rows to `num_hosts` hosts, where the first 1724 # `partial_partition` hosts get an additional row when the number of rows 1725 # is not cleanly divisible. Note that `full_rows_per_host` may be zero. 1726 partitions = ( 1727 [full_rows_per_host + 1] * partial_partition 1728 + [full_rows_per_host] * (num_hosts - partial_partition)) 1729 variables = [] 1730 sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args 1731 1732 # Keep track of offset for sharding aware initializers. 1733 offset = 0 1734 kwargs["dtype"] = dtype 1735 for i, p in enumerate(partitions): 1736 if p == 0: 1737 # Skip variable creation for empty partitions, resulting from the edge 1738 # case of 'rows < num_hosts'. This is safe because both load/restore 1739 # can handle the missing values. 1740 continue 1741 with ops.device(hosts[i]): 1742 kwargs["name"] = "{}_{}".format(name, i) 1743 kwargs["shape"] = (p, cols) 1744 if sharding_aware: 1745 shard_info = base.ShardInfo(kwargs["shape"], (offset, 0)) 1746 kwargs["initial_value"] = functools.partial( 1747 initial_value, shard_info=shard_info) 1748 offset += p 1749 else: 1750 kwargs["initial_value"] = functools.partial( 1751 unwrapped_initial_value, kwargs["shape"], dtype=dtype) 1752 variables.append(next_creator(*args, **kwargs)) 1753 return TPUShardedVariable(variables, name=name) 1754 return sharded_variable_creator 1755