# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import math
import re

import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
# as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size', 'dimension', 'initializer', 'combiner',
        'hot_id_replication', 'learning_rate', 'learning_rate_fn'
    ])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean',
              hot_id_replication=False,
              learning_rate=None,
              learning_rate_fn=None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
        standard deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more
        information, see `tf.nn.embedding_lookup_sparse`. None is only valid
        for dense rather than sparse tensors.
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table.
        If `learning_rate` and `learning_rate_fn` are both `None`, the global
        static learning rate as specified in `optimization_parameters` in the
        `TPUEmbedding` constructor will be used. `learning_rate_fn` must be
        `None` if `learning_rate` is not `None`.
      learning_rate_fn: a function that takes the current global step and
        returns the dynamic learning rate for this table. If `learning_rate`
        and `learning_rate_fn` are both `None`, the global static learning
        rate as specified in `optimization_parameters` in the `TPUEmbedding`
        constructor will be used. `learning_rate` must be `None` if
        `learning_rate_fn` is not `None`.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not positive integer.
      ValueError: if `dimension` is not positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError('Invalid dimension {}.'.format(dimension))

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError('initializer must be callable if specified.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set; got {} and {}'
                       .format(learning_rate, learning_rate_fn))

    return super(TableConfig, cls).__new__(
        cls, vocabulary_size, dimension, initializer, combiner,
        hot_id_replication, learning_rate, learning_rate_fn)


class FeatureConfig(
    collections.namedtuple(
        'FeatureConfig',
        ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls,
              table_id,
              max_sequence_length=0,
              weight_key=None):
    """Feature configuration.

    Args:
      table_id: Which table the feature uses for embedding lookups.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      weight_key: If using weights for the combiner, this key specifies which
        input feature contains the weights.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `max_sequence_length` is not an integer or is negative.
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError('Invalid max_sequence_length {}.'.format(
          max_sequence_length))

    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
                                             weight_key)


class EnqueueData(
    collections.namedtuple(
        'EnqueueData',
        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
  """Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_indices=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to int32
        internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights.
        It corresponds to sp_weights.values in embedding_lookup_sparse(). If it
        is None, we assume all weights are 1. Both float32 and float64 are
        allowed and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.

    """
    return super(EnqueueData, cls).__new__(cls, embedding_indices,
                                           sample_indices, aggregation_weights)

  @staticmethod
  def from_sparse_tensor(sp_tensor, weights=None):
    return EnqueueData(
        sp_tensor.values,
        sp_tensor.indices,
        aggregation_weights=weights.values if weights is not None else None)


def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    sp_tensors_list: a list of dictionaries mapping from string of feature
      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from string
      of feature names to EnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.

  """
  enqueue_datas_list = []
  for sp_tensors in sp_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, EnqueueData.from_sparse_tensor(v))
        for k, v in six.iteritems(sp_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list


AdamSlotVariableNames = collections.namedtuple(
    'AdamSlotVariableNames', ['m', 'v'])

AdagradSlotVariableName = collections.namedtuple(
    'AdagradSlotVariableName', ['accumulator'])

FtrlSlotVariableName = collections.namedtuple(
    'FtrlSlotVariableName', ['accumulator', 'linear'])

AdamSlotVariables = collections.namedtuple(
    'AdamSlotVariables', ['m', 'v'])

AdagradSlotVariable = collections.namedtuple(
    'AdagradSlotVariable', ['accumulator'])

FtrlSlotVariable = collections.namedtuple(
    'FtrlSlotVariable', ['accumulator', 'linear'])

VariablesAndOps = collections.namedtuple(
    'VariablesAndOps',
    ['embedding_variables_by_table', 'slot_variables_by_table',
     'load_ops', 'retrieve_ops']
)


class _OptimizationParameters(object):
  """Parameters common to all optimizations."""

  def __init__(self, learning_rate, use_gradient_accumulation,
               clip_weight_min, clip_weight_max):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max


@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               initial_accumulator=0.1,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(AdagradParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if initial_accumulator <= 0:
      raise ValueError('Adagrad initial_accumulator must be positive')
    self.initial_accumulator = initial_accumulator


@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-08,
               lazy_adam=True,
               sum_inside_sqrt=True,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
        Please see `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(AdamParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError(
          'beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


@tf_export(v1=['tpu.experimental.FtrlParameters'])
class FtrlParameters(_OptimizationParameters):
  """Optimization parameters for Ftrl with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
          ...))
  ```

  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_gradient_accumulation=True,
               clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for Ftrl.

    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(FtrlParameters,
          self).__init__(learning_rate, use_gradient_accumulation,
                         clip_weight_min, clip_weight_max)
    if learning_rate_power > 0.:
      raise ValueError('learning_rate_power must be less than or equal to 0; '
                       'got {}.'.format(learning_rate_power))

    if initial_accumulator_value < 0.:
      raise ValueError('initial_accumulator_value must be greater than or '
                       'equal to 0; got {}.'.format(initial_accumulator_value))

    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0; got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0; got {}.'
                       .format(l2_regularization_strength))

    self.learning_rate_power = learning_rate_power
    self.initial_accumulator_value = initial_accumulator_value
    self.initial_linear_value = 0.0
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength


@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=(
              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
  ```

  """

  def __init__(self, learning_rate, clip_weight_min=None,
               clip_weight_max=None):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: a floating point value. The learning rate.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
    """
    super(StochasticGradientDescentParameters,
          self).__init__(learning_rate, False, clip_weight_min, clip_weight_max)


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])


class TPUEmbedding(object):
  """API for using TPU for embedding.

  Example:
  ```
  table_config_video = tpu_embedding.TableConfig(
      vocabulary_size=8, dimension=2,
      initializer=initializer, combiner='mean')
  table_config_user = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
  table_to_config_dict = {'video': table_config_video,
                          'user': table_config_user}
  feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
                            'favorited': tpu_embedding.FeatureConfig('video'),
                            'friends': tpu_embedding.FeatureConfig('user')}
  batch_size = 4
  optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
  mode = tpu_embedding.TRAINING
  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict, feature_to_config_dict,
      batch_size, mode,
      optimization_parameters=optimization_parameters)

  batch_size_per_core = embedding.batch_size_per_core
  sparse_features_list = []
  for host in hosts:
    with ops.device(host):
      for _ in range(embedding.num_cores_per_host):
        sparse_features = {}
        sparse_features['watched'] = sparse_tensor.SparseTensor(...)
        sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
        sparse_features['friends'] = sparse_tensor.SparseTensor(...)
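        # One dict per TPU core; dicts for the same host must be contiguous
        # in sparse_features_list (see generate_enqueue_ops below).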
        sparse_features_list.append(sparse_features)

  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
          sparse_features_list))
  enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
  embedding_variables_and_ops = embedding.create_variables_and_ops()

  def computation():
    activations = embedding.get_activations()
    loss = compute_loss(activations)

    base_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=1)
    cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
        base_optimizer)

    train_op = cross_shard_optimizer.minimize(loss)
    gradients = (
        tpu_embedding_gradient.get_gradients_through_compute_gradients(
            cross_shard_optimizer, loss, activations))
    send_gradients_op = embedding.generate_send_gradients_op(gradients)
    with ops.control_dependencies([train_op, send_gradients_op]):
      loss = array_ops.identity(loss)

  loss = tpu.shard(computation,
                   num_shards=embedding.num_cores)

  with self.test_session() as sess:
    sess.run(tpu.initialize_system(embedding_config=
                                   embedding.config_proto))
    sess.run(variables.global_variables_initializer())
    sess.run(embedding_variables_and_ops.load_ops())
    sess.run(enqueue_ops)
    loss_val = sess.run(loss)
  ```
  """

  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
  # the feature should not be used to update embedding table (cr/204852758,
  # cr/204940540). Also, this can support different combiners for different
  # features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
  # for-loops around construction of inputs.

  # `optimization_parameter` applies to all tables. If the need arises,
  # we can add `optimization_parameters` to `TableConfig` to override this
  # global setting.
  def __init__(self,
               table_to_config_dict,
               feature_to_config_dict,
               batch_size,
               mode,
               master=None,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_config_dict: A dictionary mapping from string of feature name
        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `FtrlParameters` or `StochasticGradientDescentParameters`. Must be set
        in training and must be `None` in inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes training
        faster, but trained model will be different if step N and step N+1
        involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
      partition_strategy: A string, either 'mod' or 'div', specifying how to map
        the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse`.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: if set, overrides the master job name used to schedule
        embedding ops.

    Raises:
      ValueError: if any input is invalid.
    """
    if partition_strategy not in ('div', 'mod'):
      raise ValueError(
          'Invalid partition_strategy {}'.format(partition_strategy))
    self._partition_strategy = partition_strategy

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict)
    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
    self._table_to_features_dict, self._table_to_num_features_dict = (
        _create_table_to_features_and_num_features_dicts(
            self._feature_to_config_dict))
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    if master is None and cluster_def is None:
      if device_config is None:
        raise ValueError('When master and cluster_def are both None, '
                         'device_config must be set but is not.')
      if device_config.num_cores % device_config.num_hosts:
        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
                         'but does not.'.format(device_config.num_hosts,
                                                device_config.num_cores))
      self._num_hosts = device_config.num_hosts
      self._num_cores = device_config.num_cores
      self._num_cores_per_host = self._num_cores // self._num_hosts
      self._hosts = [
          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
          for i in range(self._num_hosts)
      ]
    else:
      tpu_system_metadata = (
          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
              master,
              cluster_def=cluster_def))
      if tpu_system_metadata.num_cores == 0:
        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                         'TPUs.'.format(master))
      self._num_hosts = tpu_system_metadata.num_hosts
      if master_job_name is None:
        try:
          master_job_name = tpu_system_metadata_lib.master_job(master,
                                                               cluster_def)
        except ValueError as e:
          raise ValueError(str(e) + ' Please specify a master_job_name.')
      self._hosts = []
      for device in tpu_system_metadata.devices:
        if 'device:CPU:' in device.name and (
            master_job_name is None or master_job_name in device.name):
          self._hosts.append(device.name)
      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
      self._num_cores = tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores

    # TODO(shizhiw): remove `mode`?
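    # Note: in INFERENCE mode the embedding tables are never updated, so the
    # StochasticGradientDescentParameters(1.) installed below only serves as a
    # placeholder to satisfy the optimizer plumbing.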
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError('`optimization_parameters` should be `None` '
                         'for inference mode.')
      self._optimization_parameters = (
          StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'
                       .format(TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
    self._optimizer_handler = _get_optimization_handler(
        self._optimization_parameters)
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
    self._learning_rate_fn = list(set(
        c.learning_rate_fn for c in self._table_to_config_dict.values()
        if c.learning_rate_fn is not None))
    self._learning_rate_fn_to_tag = {
        fn: tag for tag, fn in enumerate(self._learning_rate_fn)}

    self._config_proto = self._create_config_proto()

  @property
  def hosts(self):
    """A list of device names for CPU hosts.

    Returns:
      A list of device names for CPU hosts.
    """
    return copy.copy(self._hosts)

  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
  # to be consistent with `tpu_embedding_configuration.proto`.
  @property
  def num_cores_per_host(self):
    """Number of TPU cores on a CPU host.

    Returns:
      Number of TPU cores on a CPU host.
    """
    return self._num_cores_per_host

  @property
  def num_cores(self):
    """Total number of TPU cores on all hosts.

    Returns:
      Total number of TPU cores on all hosts.
    """
    return self._num_cores

  @property
  def batch_size_per_core(self):
    """Batch size for each TPU core.

    The sparse tensors in `sparse_features_list` to `generate_enqueue_ops`
    must have batch dimension equal to this.

    Returns:
      Batch size for each TPU core.
    """
    return self._batch_size_per_core

  @property
  def config_proto(self):
    """Create embedding config proto for `tpu.initialize_system()`.

    Returns:
      A `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables, which
      is passed to `tpu.initialize_system()`.
    """
    return self._config_proto

  @property
  def table_to_config_dict(self):
    return copy.copy(self._table_to_config_dict)

  @property
  def feature_to_config_dict(self):
    return copy.copy(self._feature_to_config_dict)

  @property
  def table_to_features_dict(self):
    return copy.copy(self._table_to_features_dict)

  @property
  def optimization_parameters(self):
    return self._optimization_parameters

  def _create_config_proto(self):
    """Create `TPUEmbeddingConfiguration`."""
    config_proto = elc.TPUEmbeddingConfiguration()
    for table in self._table_to_config_dict:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table

      table_config = self._table_to_config_dict[table]
      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
                                             len(self.hosts))
      table_descriptor.dimension = table_config.dimension

      table_descriptor.num_features = self._table_to_num_features_dict[table]

      parameters = table_descriptor.optimization_parameters
      if table_config.learning_rate:
        parameters.learning_rate.constant = table_config.learning_rate
      elif table_config.learning_rate_fn:
        parameters.learning_rate.dynamic.tag = (
            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
      else:
        parameters.learning_rate.constant = (
            self._optimization_parameters.learning_rate)
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
          if self._optimization_parameters.use_gradient_accumulation else
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
      if self._optimization_parameters.clip_weight_min is not None:
        parameters.clipping_limits.lower.value = (
            self._optimization_parameters.clip_weight_min)
      if self._optimization_parameters.clip_weight_max is not None:
        parameters.clipping_limits.upper.value = (
            self._optimization_parameters.clip_weight_max)
      if table_config.hot_id_replication:
        parameters.hot_id_replication_configuration.status = (
            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
      self._optimizer_handler.set_optimization_parameters(table_descriptor)

    config_proto.mode = self._mode
    config_proto.batch_size_per_tensor_core = self._batch_size_per_core
    config_proto.num_hosts = self._num_hosts
    config_proto.num_tensor_cores = self._num_cores
    config_proto.sharding_strategy = (
        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
        if self._partition_strategy == 'div' else
        elc.TPUEmbeddingConfiguration.MOD)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

    N.B.: the retrieve embedding variables (including slot variables) ops are
    returned as lambda fn, as the call side might want to impose control
    dependencies between the TPU computation and retrieving actions. For
    example, the following code snippet ensures the TPU computation finishes
    first, and then we pull the variables back from TPU to CPU.

    ```
    update_ops = []
    with ops.control_dependencies([loss]):
      for op_fn in retrieve_parameters_op_fns:
        update_ops.append(op_fn())
    ```

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the table
        name will be used.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.

    Returns:
      `tpu_embedding.VariablesAndOps` with:
        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to AdagradSlotVariable,
         AdamSlotVariables etc with slot variables,
        A function which returns a list of ops to load embedding and slot
         variables from CPU to TPU.
        A function which returns a list of ops to retrieve embedding and slot
         variables from TPU to CPU.
    """
    embedding_variables_by_table = {}
    slot_variables_by_table = {}
    load_op_fns = []
    retrieve_op_fns = []

    for i, table in enumerate(self._table_to_config_dict):
      if embedding_variable_name_by_table:
        embedding_variable_name = embedding_variable_name_by_table[table]
      else:
        embedding_variable_name = table
      if slot_variable_names_by_table:
        slot_variable_names = slot_variable_names_by_table[table]
      else:
        slot_variable_names = (
            self._optimizer_handler.get_default_slot_variable_names(table))

      # TODO(b/139144091): Multi-host support for mid-level API in
      # eager context (TF 2.0)
      # Workaround below allows single-host use case in TF 2.0
      if context.executing_eagerly():
        device = ''
      else:
        device = _create_device_fn(self._hosts)

      with ops.device(device):
        table_variables = _create_partitioned_variables(
            name=embedding_variable_name,
            num_hosts=self._num_hosts,
            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
            embedding_dimension=self._table_to_config_dict[table].dimension,
            initializer=self._table_to_config_dict[table].initializer,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        embedding_variables_by_table[table] = table_variables

        # Only load the embedding config into the load/retrieve nodes for the
        # first table on the first host; the other nodes reuse the config from
        # that first node.
        config = None if i else self.config_proto.SerializeToString()
        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
            self._optimizer_handler.create_variables_and_ops(
                table, slot_variable_names, self._num_hosts,
                self._table_to_config_dict[table], table_variables, config))
        slot_variables_by_table[table] = slot_variables_for_table
        load_op_fns.append(load_ops_fn)
        retrieve_op_fns.append(retrieve_ops_fn)

    def load_ops():
      """Calls and returns the load ops for each embedding table.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_ops_list = []
      for load_op_fn in load_op_fns:
        load_ops_list.extend(load_op_fn())
      return load_ops_list

    def retrieve_ops():
      """Calls and returns the retrieve ops for each embedding table.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_ops_list = []
      for retrieve_op_fn in retrieve_op_fns:
        retrieve_ops_list.extend(retrieve_op_fn())
      return retrieve_ops_list

    return VariablesAndOps(embedding_variables_by_table,
                           slot_variables_by_table,
                           load_ops, retrieve_ops)

  def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
    """Generate enqueue ops.

    Args:
      enqueue_datas_list: a list of dictionaries mapping from string
        of feature names to EnqueueData. Each dictionary is for one
        TPU core. Dictionaries for the same host should be contiguous
        in the list.
      mode_override: A string input that overrides the mode specified in the
        TPUEmbeddingConfiguration. Supported values are {'unspecified',
        'inference', 'training', 'backward_pass_only'}. When set to
        'unspecified', the mode set in TPUEmbeddingConfiguration is used,
        otherwise mode_override is used (optional).

    Returns:
      Ops to enqueue to TPU for embedding.
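
    Example (a minimal sketch; `embedding` is a `TPUEmbedding` instance and
    `sparse_features_list` holds one dict of `SparseTensor`s per TPU core, as
    in the class docstring):

    ```
    enqueue_datas_list = (
        tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
            sparse_features_list))
    enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
    ```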
    """
    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
    return [
        self._generate_enqueue_op(  # pylint: disable=g-complex-comprehension
            enqueue_datas,
            device_ordinal=i % self._num_cores_per_host,
            mode_override=mode_override,
        ) for i, enqueue_datas in enumerate(enqueue_datas_list)
    ]

  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
                                                        enqueue_datas_list):
    """Validate `enqueue_datas_list`."""
    feature_set = set(self._feature_to_config_dict.keys())
    contiguous_device = None
    for i, enqueue_datas in enumerate(enqueue_datas_list):
      used_feature_set = set(enqueue_datas.keys())

      # Check features are valid.
      missing_feature_set = feature_set - used_feature_set
      if missing_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, missing_feature_set))

      extra_feature_set = used_feature_set - feature_set
      if extra_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, extra_feature_set))

      device = None
      device_feature = None
      for feature, enqueue_data in six.iteritems(enqueue_datas):
        combiner = self._table_to_config_dict[
            self._feature_to_config_dict[feature].table_id].combiner
        if not isinstance(enqueue_data, EnqueueData):
          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
                           'not mapped to `EnqueueData`. `feature`: {}'.format(
                               i, feature))

        if enqueue_data.sample_indices is None and combiner:
          logging.warn('No sample indices set for features %s table %s but '
                       'combiner is set to %s.', feature,
                       self._feature_to_config_dict[feature].table_id, combiner)

        if (enqueue_data.sample_indices is not None and
            enqueue_data.sample_indices.device !=
            enqueue_data.embedding_indices.device):
          raise ValueError(
              'Device of sample_indices does not agree with '
              'that of embedding_indices for feature {}.'.format(feature))
        if (enqueue_data.aggregation_weights is not None and
            enqueue_data.aggregation_weights.device !=
            enqueue_data.embedding_indices.device):
          raise ValueError(
              'Device of aggregation_weights does not agree with '
              'that of embedding_indices for feature {}.'.format(feature))
        # Check all features are on the same device.
        if device is None:
          device = enqueue_data.embedding_indices.device
          device_feature = feature
        else:
          if device != enqueue_data.embedding_indices.device:
            raise ValueError('Devices are different between features in '
                             '`enqueue_datas_list[{}]`; '
                             'devices: {}, {}; features: {}, {}.'.format(
                                 i, device,
                                 enqueue_data.embedding_indices.device, feature,
                                 device_feature))

      if i % self._num_cores_per_host:
        if device != contiguous_device:
          raise ValueError('We expect the `enqueue_datas` which are on the '
                           'same host to be contiguous in '
                           '`enqueue_datas_list`, '
                           '`enqueue_datas_list[{}]` is on device {}, '
                           'but is expected to be on device {}.'.format(
                               i, device, contiguous_device))
      else:
        contiguous_device = device

  def _generate_enqueue_op(
      self, enqueue_datas, device_ordinal, mode_override=None):
    enqueue_data0 = list(enqueue_datas.values())[0]
    with ops.colocate_with(enqueue_data0.embedding_indices):
      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
          device_ordinal=device_ordinal,
          combiners=self._combiners,
          mode_override=mode_override,
          **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
      )

  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
    """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.

    Args:
      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
        dense.

    Returns:
      Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
    """
    kwargs = {
        'sample_indices': [],
        'embedding_indices': [],
        'aggregation_weights': [],
        'table_ids': [],
        'max_sequence_lengths': [],
    }
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      for feature in features:
        enqueue_data = enqueue_datas[feature]

        kwargs['sample_indices'].append(
            enqueue_data.sample_indices
            if enqueue_data.sample_indices is not None else array_ops.zeros(
                (0,), dtype=dtypes.int64))

        kwargs['aggregation_weights'].append(
            enqueue_data.aggregation_weights if
            enqueue_data.aggregation_weights is not None else array_ops.zeros(
                (0,), dtype=dtypes.float32))

        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)

        kwargs['table_ids'].append(table_id)
        kwargs['max_sequence_lengths'].append(
            self._feature_to_config_dict[feature].max_sequence_length)

    return kwargs

  def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
    `tpu.replicate` and friends.

    Returns:
      A dictionary mapping from `String` of feature name to `Tensor`
      of activation.
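
    For a non-sequence feature the `Tensor` has shape
    `[batch_size_per_core, dimension]`; for a sequence feature it has shape
    `[batch_size_per_core, max_sequence_length, dimension]`.

    Example (a minimal sketch; `embedding` is a `TPUEmbedding` instance and
    'watched' is a feature from the class docstring example):

    ```
    def computation():
      activations = embedding.get_activations()
      watched_embedding = activations['watched']
      ...
    ```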
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      num_features = self._table_to_num_features_dict[table]
      feature_index = 0
      table_activations = array_ops.reshape(
          recv_activations[table_id],
          [self.batch_size_per_core, num_features, -1])
      for feature in features:
        seq_length = self._feature_to_config_dict[feature].max_sequence_length
        if not seq_length:
          activations[feature] = table_activations[:, feature_index, :]
          feature_index = feature_index + 1
        else:
          activations[feature] = (
              table_activations[:, feature_index:(feature_index+seq_length), :])
          feature_index = feature_index + seq_length

    return activations

  def generate_send_gradients_op(self,
                                 feature_to_gradient_dict,
                                 step=None):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.
      step: the current global step, used for dynamic learning rate.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
      ValueError: If `step` is `None` and dynamic learning rates are in use.
    """
    if self._mode != TRAINING:
      raise RuntimeError('Gradients should only be sent to TPU embedding in '
                         'training mode; got mode {}.'
                         .format(self._mode))
    if step is None and self._learning_rate_fn:
      raise ValueError('There are dynamic learning rates but step is None.')

    gradients = []
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      table_gradients = []
      for feature in features:
        gradient = feature_to_gradient_dict[feature]
        # Expand dims for non-sequence feature to match sequence features.
        if gradient.shape.ndims == 2:
          gradient = array_ops.expand_dims(gradient, 1)
        table_gradients.append(gradient)
      interleaved_table_grads = array_ops.reshape(
          array_ops.concat(table_gradients, axis=1),
          [-1, array_ops.shape(table_gradients[0])[-1]])
      gradients.append(interleaved_table_grads)

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
                        for fn in self._learning_rate_fn],
        config=self.config_proto.SerializeToString())


def _validate_table_to_config_dict(table_to_config_dict):
  """Validate `table_to_config_dict`."""
  for k, v in six.iteritems(table_to_config_dict):
    if not isinstance(v, TableConfig):
      raise ValueError('Value of `table_to_config_dict` must be of type '
                       '`TableConfig`, got {} for {}.'.format(type(v), k))


def _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict):
  """Validate `feature_to_config_dict`."""
  used_table_set = set([feature.table_id
                        for feature in feature_to_config_dict.values()])
  table_set = set(table_to_config_dict.keys())

  unused_table_set = table_set - used_table_set
  if unused_table_set:
    raise ValueError('`table_to_config_dict` specifies a table that is not '
                     'used in `feature_to_config_dict`: {}.'
                     .format(unused_table_set))

  extra_table_set = used_table_set - table_set
  if extra_table_set:
    raise ValueError('`feature_to_config_dict` refers to a table that is not '
                     'specified in `table_to_config_dict`: {}.'
                     .format(extra_table_set))


def _validate_batch_size(batch_size, num_cores):
  if batch_size % num_cores:
    raise ValueError('`batch_size` is not a multiple of number of '
                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
                         batch_size, num_cores))


def _validate_optimization_parameters(optimization_parameters):
  if not isinstance(optimization_parameters, _OptimizationParameters):
    raise ValueError('`optimization_parameters` must inherit from '
                     '`_OptimizationParameters`. '
                     '`type(optimization_parameters)`={}'.format(
                         type(optimization_parameters)))


class _OptimizerHandler(object):
  """Interface class for handling optimizer specific logic."""

  def __init__(self, optimization_parameters):
    self._optimization_parameters = optimization_parameters

  def set_optimization_parameters(self, table_descriptor):
    raise NotImplementedError()

  def get_default_slot_variable_names(self, table):
    raise NotImplementedError()

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    raise NotImplementedError()


class _AdagradHandler(_OptimizerHandler):
  """Handles Adagrad specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdagradHandler, self).__init__(optimization_parameters)
    self._table_to_accumulator_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adagrad.SetInParent()

  def get_default_slot_variable_names(self, table):
    return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    slot_variables = AdagradSlotVariable(accumulator_variables)

    def load_ops_fn():
      """Returns the load ops for AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable in zip(
          range(num_hosts), table_variables, accumulator_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          config = None
          load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable in (zip(
          range(num_hosts), table_variables, accumulator_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator = (
              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator))
          config = None
          retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _AdamHandler(_OptimizerHandler):
  """Handles Adam specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdamHandler, self).__init__(optimization_parameters)
    self._table_to_m_variables_dict = {}
    self._table_to_v_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adam.beta1 = (
        self._optimization_parameters.beta1)
    table_descriptor.optimization_parameters.adam.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.adam.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
        not self._optimization_parameters.lazy_adam)
    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
        self._optimization_parameters.sum_inside_sqrt)

  def get_default_slot_variable_names(self, table):
    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
                                 '{}/{}/v'.format(table, 'Adam'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    m_initializer = init_ops.zeros_initializer()
    m_variables = _create_partitioned_variables(
        name=slot_variable_names.m,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=m_initializer)
    v_initializer = init_ops.zeros_initializer()
    v_variables = _create_partitioned_variables(
        name=slot_variable_names.v,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=v_initializer)
    slot_variables = AdamSlotVariables(m_variables, v_variables)

    def load_ops_fn():
      """Returns the load ops for Adam embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_variable,
                  momenta=m_variable,
                  velocities=v_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          # Set config to None to enforce that config is only loaded to the
          # first table.
          config = None
          load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Adam embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_m, retrieved_v = (
              tpu_ops.retrieve_tpu_embedding_adam_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(m_variable, retrieved_m),
              state_ops.assign(v_variable, retrieved_v))
          config = None
          retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _FtrlHandler(_OptimizerHandler):
  """Handles Ftrl specific logic."""

  def __init__(self, optimization_parameters):
    super(_FtrlHandler, self).__init__(optimization_parameters)
    self._table_to_accumulator_variables_dict = {}
    self._table_to_linear_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.ftrl.lr_power = (
        self._optimization_parameters.learning_rate_power)
    table_descriptor.optimization_parameters.ftrl.l1 = (
        self._optimization_parameters.l1_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.l2 = (
        self._optimization_parameters.l2_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.initial_accum = (
        self._optimization_parameters.initial_accumulator_value)
    table_descriptor.optimization_parameters.ftrl.initial_linear = (
        self._optimization_parameters.initial_linear_value)

  def get_default_slot_variable_names(self, table):
    # These match the default slot variable names created by
    # tf.train.FtrlOptimizer.
    return FtrlSlotVariableName('{}/{}'.format(table, 'Ftrl'),  # accumulator
                                '{}/{}'.format(table, 'Ftrl_1'))  # linear

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator_value)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    linear_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_linear_value)
    linear_variables = _create_partitioned_variables(
        name=slot_variable_names.linear,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=linear_initializer)
    slot_variables = FtrlSlotVariable(accumulator_variables,
                                      linear_variables)

    def load_ops_fn():
      """Returns the load ops for Ftrl embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  linears=linear_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          config = None
          load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Ftrl embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator),
              state_ops.assign(linear_variable, retrieved_linear))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles stochastic gradient descent specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.stochastic_gradient_descent
     .SetInParent())

  def get_default_slot_variable_names(self, table):
    return None

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    del table_config

    def load_ops_fn():
      """Returns the load ops for SGD embedding tables.

      Returns:
        A list of ops to load embedding variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for SGD embedding tables.

      Returns:
        A list of ops to retrieve embedding variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table = (
              tpu_ops
              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return None, load_ops_fn, retrieve_ops_fn


def _get_optimization_handler(optimization_parameters):
  """Gets the optimization handler given the parameter type."""
  if isinstance(optimization_parameters, AdagradParameters):
    return _AdagradHandler(optimization_parameters)
  elif isinstance(optimization_parameters, AdamParameters):
    return _AdamHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FtrlParameters):
    return _FtrlHandler(optimization_parameters)
  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
    return _StochasticGradientDescentHandler(optimization_parameters)
  else:
    raise NotImplementedError(
        'Unsupported optimization parameters: {}.'.format(
            optimization_parameters))


def _create_ordered_dict(d):
  """Create an OrderedDict from a dict, with keys in sorted order."""
  return collections.OrderedDict((k, d[k]) for k in sorted(d))


def _create_combiners(table_to_config_dict, table_to_features_dict):
  """Create a per feature list of combiners, ordered by table."""
  combiners = []
  for table in table_to_config_dict:
    combiner = table_to_config_dict[table].combiner or 'sum'
    combiners.extend([combiner] * len(table_to_features_dict[table]))
  return combiners


def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
  """Create mappings from each table to its features and its feature count."""
  table_to_features_dict_tmp = {}
  table_to_num_features_dict_tmp = {}
  for feature, feature_config in six.iteritems(feature_to_config_dict):
    if feature_config.table_id in table_to_features_dict_tmp:
      table_to_features_dict_tmp[feature_config.table_id].append(feature)
    else:
      table_to_features_dict_tmp[feature_config.table_id] = [feature]
      table_to_num_features_dict_tmp[feature_config.table_id] = 0
    if feature_config.max_sequence_length == 0:
      table_to_num_features_dict_tmp[feature_config.table_id] = (
          table_to_num_features_dict_tmp[feature_config.table_id] + 1)
    else:
      table_to_num_features_dict_tmp[feature_config.table_id] = (
          table_to_num_features_dict_tmp[feature_config.table_id] +
          feature_config.max_sequence_length)

  table_to_features_dict = collections.OrderedDict()
  table_to_num_features_dict = collections.OrderedDict()
  for table in sorted(table_to_features_dict_tmp):
    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
    table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
  return table_to_features_dict, table_to_num_features_dict


def _create_device_fn(hosts):
  """Create device_fn() to use with _create_partitioned_variables()."""

  def device_fn(op):
    """Returns the `device` for `op`."""
    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
    dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
    if not part_match and not dummy_match:
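      # Shards created by _create_partitioned_variables() are named
      # '<name>/part_<idx>', and padding variables are named
      # 'dummy_<idx>_<name>', so any other op name here is unexpected.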
      raise RuntimeError(
          'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
              op.name))

    if part_match:
      idx = int(part_match.group(1))
    else:
      idx = int(dummy_match.group(1))

    device = hosts[idx]
    logging.debug('assigning %s to %s.', op, device)
    return device

  return device_fn


def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates PartitionedVariables based on `num_hosts` for the table `name`."""

  num_slices = min(vocabulary_size, num_hosts)

  var_list = list(
      variable_scope.get_variable(
          name,
          shape=(vocabulary_size, embedding_dimension),
          partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
          dtype=dtypes.float32,
          initializer=initializer,
          collections=collections,
          trainable=False))

  if vocabulary_size >= num_hosts:
    return var_list

  # For the padded parts, define dummy variables to be loaded into the TPU
  # system.
  for idx in range(num_hosts - vocabulary_size):
    var_list.append(
        variable_scope.get_variable(
            'dummy_{}_{}'.format(vocabulary_size + idx, name),
            shape=(1, embedding_dimension),
            dtype=dtypes.float32,
            initializer=initializer,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            trainable=False))

  return var_list
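

# The sketch below is illustrative only and is not part of this module's API.
# It shows, with hypothetical table and feature names, how the helpers above
# group features by table: a non-sequence feature counts as one feature slot,
# a sequence feature counts as `max_sequence_length` slots, and the combiner
# list repeats each table's combiner once per feature of that table.
def _example_feature_grouping_sketch():
  """Minimal sketch of _create_combiners and the table-to-features helpers."""
  table_to_config = {
      'vocab_table': TableConfig(vocabulary_size=100, dimension=16),
  }
  feature_to_config = {
      'query_tokens': FeatureConfig('vocab_table', max_sequence_length=8),
      'doc_title': FeatureConfig('vocab_table'),
  }
  table_to_features, table_to_num_features = (
      _create_table_to_features_and_num_features_dicts(feature_to_config))
  # table_to_features ==
  #     OrderedDict([('vocab_table', ['doc_title', 'query_tokens'])])
  # table_to_num_features == OrderedDict([('vocab_table', 9)])  # 8 + 1
  combiners = _create_combiners(table_to_config, table_to_features)
  # combiners == ['mean', 'mean']  # default combiner, one entry per feature.
  return table_to_features, table_to_num_features, combiners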