# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import math
import re
import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


class TableConfig(
    collections.namedtuple(
        'TableConfig',
        ['vocabulary_size', 'dimension', 'initializer', 'combiner'])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean'):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.truncated_normal_initializer` with mean `0.0` and standard
        deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more
        information, see `tf.nn.embedding_lookup_sparse`. None is only valid
        for dense rather than sparse tensors.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError('Invalid dimension {}.'.format(dimension))

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError('initializer must be callable if specified.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))

    return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension,
                                           initializer, combiner)
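

# Illustrative sketch (not part of the original API surface): constructing
# `TableConfig` objects. The table sizes and dimensions below are hypothetical
# placeholders.
#
#   video_table = TableConfig(vocabulary_size=1000, dimension=16)
#   user_table = TableConfig(
#       vocabulary_size=500,
#       dimension=8,
#       initializer=init_ops.truncated_normal_initializer(stddev=0.05),
#       combiner='sqrtn')
#
# When `initializer` is omitted, `__new__` above falls back to a truncated
# normal initializer with standard deviation `1/sqrt(dimension)`.
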
AdamSlotVariableNames = collections.namedtuple(
    'AdamSlotVariableNames', ['m', 'v'])

AdagradSlotVariableName = collections.namedtuple(
    'AdagradSlotVariableName', ['accumulator'])

AdamSlotVariables = collections.namedtuple(
    'AdamSlotVariables', ['m', 'v'])

AdagradSlotVariable = collections.namedtuple(
    'AdagradSlotVariable', ['accumulator'])

VariablesAndOps = collections.namedtuple(
    'VariablesAndOps',
    ['embedding_variables_by_table', 'slot_variables_by_table',
     'load_ops', 'retrieve_ops']
)


class _OptimizationParameters(object):
  """Parameters common to all optimizations."""

  def __init__(self, learning_rate, use_gradient_accumulation):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation


class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad."""

  def __init__(self, learning_rate, initial_accumulator=0.1,
               use_gradient_accumulation=True):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
    """
    super(AdagradParameters, self).__init__(learning_rate,
                                            use_gradient_accumulation)
    if initial_accumulator <= 0:
      raise ValueError('Adagrad initial_accumulator must be positive')
    self.initial_accumulator = initial_accumulator


class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam."""

  def __init__(self, learning_rate,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-08,
               lazy_adam=True,
               sum_inside_sqrt=True,
               use_gradient_accumulation=True):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
        Please see `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
    """
    super(AdamParameters, self).__init__(learning_rate,
                                         use_gradient_accumulation)
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling Lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent.

  Args:
    learning_rate: a floating point value. The learning rate.
  """

  def __init__(self, learning_rate):
    super(StochasticGradientDescentParameters, self).__init__(
        learning_rate, False)


class TPUEmbedding(object):
  """API for using TPU for embedding.

  Example:
  ```
  table_config_video = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
  table_config_user = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
  table_to_config_dict = {'video': table_config_video,
                          'user': table_config_user}
  feature_to_table_dict = {'watched': 'video',
                           'favorited': 'video',
                           'friends': 'user'}
  batch_size = 4
  optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
  mode = tpu_embedding.TRAINING
  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict, feature_to_table_dict,
      batch_size, mode, master, optimization_parameters)

  batch_size_per_core = embedding.batch_size_per_core
  sparse_features_list = []
  for host in embedding.hosts:
    with ops.device(host):
      for _ in range(embedding.num_cores_per_host):
        sparse_features = {}
        sparse_features['watched'] = sparse_tensor.SparseTensor(...)
        sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
        sparse_features['friends'] = sparse_tensor.SparseTensor(...)
        sparse_features_list.append(sparse_features)

  enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
  embedding_variables_and_ops = embedding.create_variables_and_ops()

  def computation():
    activations = embedding.get_activations()
    loss = compute_loss(activations)

    base_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=1)
    cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
        base_optimizer)

    train_op = cross_shard_optimizer.minimize(loss)
    gradients = (
        tpu_embedding_gradient.get_gradients_through_compute_gradients(
            cross_shard_optimizer, loss, activations))
    send_gradients_op = embedding.generate_send_gradients_op(gradients)
    with ops.control_dependencies([train_op, send_gradients_op]):
      loss = array_ops.identity(loss)
    return loss

  loss = tpu.shard(computation,
                   num_shards=embedding.num_cores)

  with self.test_session() as sess:
    sess.run(tpu.initialize_system(embedding_config=
                                   embedding.config_proto))
    sess.run(variables.global_variables_initializer())
    sess.run(embedding_variables_and_ops.load_ops())
    sess.run(enqueue_ops)
    loss_val = sess.run(loss)
  ```
  """

  # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table
  # name, consider `feature_to_config_dict` which maps to `FeatureConfig`.
  # `FeatureConfig` could have fields other than table name. For example, it
  # could have a field to indicate that the feature should not be used to
  # update embedding table (cr/204852758, cr/204940540). Also, this can support
  # different combiners for different features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove
  # boilerplate for-loops around construction of inputs.

  # `optimization_parameters` applies to all tables. If the need arises,
  # we can add `optimization_parameters` to `TableConfig` to override this
  # global setting.
  def __init__(self,
               table_to_config_dict,
               feature_to_table_dict,
               batch_size,
               mode,
               master,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=True):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_table_dict: A dictionary mapping from string of feature name
        to string of table name. Feature refers to ids to lookup in embedding
        table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Must be set in training and
        must be `None` in inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes
        training faster, but the trained model will be different if step N and
        step N+1 involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.

    Raises:
      ValueError: if any input is invalid.
    """
    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_table_dict(table_to_config_dict,
                                    feature_to_table_dict)
    self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict)
    self._table_to_features_dict = _create_table_to_features_dict(
        self._feature_to_table_dict)
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    self._master = master
    self._cluster_def = cluster_def
    self._tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self._master, cluster_def=self._cluster_def))
    if self._tpu_system_metadata.num_cores == 0:
      raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                       'TPUs.'.format(self._master))
    self._num_hosts = self._tpu_system_metadata.num_hosts
    master_job_name = tpu_system_metadata_lib.master_job(self._master,
                                                         self._cluster_def)
    self._hosts = sorted([
        device.name for device in self._tpu_system_metadata.devices
        if 'device:CPU:' in device.name and (master_job_name is None or
                                             master_job_name in device.name)])
    self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host
    self._num_cores = self._tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores

    # TODO(shizhiw): remove `mode`?
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError('`optimization_parameters` should be `None` '
                         'for inference mode.')
      self._optimization_parameters = (
          StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'
                       .format(TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
    self._optimizer_handler = _get_optimization_handler(
        self._optimization_parameters)
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._config_proto = self._create_config_proto()

  @property
  def hosts(self):
    """A list of device names for CPU hosts.

    Returns:
      A list of device names for CPU hosts.
    """
    return copy.copy(self._hosts)

  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
  # to be consistent with `tpu_embedding_configuration.proto`.
  @property
  def num_cores_per_host(self):
    """Number of TPU cores on a CPU host.

    Returns:
      Number of TPU cores on a CPU host.
    """
    return self._num_cores_per_host

  @property
  def num_cores(self):
    """Total number of TPU cores on all hosts.

    Returns:
      Total number of TPU cores on all hosts.
    """
    return self._num_cores

  @property
  def batch_size_per_core(self):
    """Batch size for each TPU core.

    The sparse tensors in `sparse_features_list` to `generate_enqueue_ops`
    must have batch dimension equal to this.

    Returns:
      Batch size for each TPU core.
    """
    return self._batch_size_per_core

  @property
  def config_proto(self):
    """Create embedding config proto for `tpu.initialize_system()`.

    Returns:
      A `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables, which
      is passed to `tpu.initialize_system()`.
    """
    return self._config_proto

  @property
  def table_to_config_dict(self):
    return copy.copy(self._table_to_config_dict)

  @property
  def feature_to_table_dict(self):
    return copy.copy(self._feature_to_table_dict)

  @property
  def table_to_features_dict(self):
    return copy.copy(self._table_to_features_dict)

  @property
  def optimization_parameters(self):
    return self._optimization_parameters

  def _create_config_proto(self):
    """Create `TPUEmbeddingConfiguration`."""
    config_proto = elc.TPUEmbeddingConfiguration()
    for table in self._table_to_config_dict:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table

      table_config = self._table_to_config_dict[table]
      table_descriptor.vocabulary_size = table_config.vocabulary_size
      table_descriptor.dimension = table_config.dimension

      features_for_table = self._table_to_features_dict[table]
      table_descriptor.num_features = len(features_for_table)

      table_descriptor.optimization_parameters.learning_rate.constant = (
          self._optimization_parameters.learning_rate)
      table_descriptor.optimization_parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
          if self._optimization_parameters.use_gradient_accumulation else
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
      self._optimizer_handler.set_optimization_parameters(table_descriptor)

    config_proto.mode = self._mode
    config_proto.batch_size_per_tensor_core = self._batch_size_per_core
    config_proto.num_hosts = self._num_hosts
    config_proto.num_tensor_cores = self._num_cores
    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto
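
  # Illustrative sketch (hypothetical single-host setup): for one
  # Adagrad-optimized table 'video' (vocabulary_size=4, dimension=2) backing
  # two features, the proto built above looks roughly like, in text format:
  #
  #   table_descriptor {
  #     name: "video"
  #     vocabulary_size: 4
  #     dimension: 2
  #     num_features: 2
  #     optimization_parameters {
  #       learning_rate { constant: 1.0 }
  #       gradient_accumulation_status: ENABLED
  #       adagrad {}
  #     }
  #   }
  #   mode: TRAINING
  #   batch_size_per_tensor_core: <batch_size // num_cores>
  #   num_hosts: 1
  #   num_tensor_cores: <num_cores>
  #   sharding_strategy: DIV_DEFAULT
  #   pipeline_execution_with_tensor_core: true
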
  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the table
        name will be used.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.

    Returns:
      `tpu_embedding.VariablesAndOps` with:
        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to AdagradSlotVariable,
          AdamSlotVariables etc with slot variables,
        A function which returns a list of ops to load embedding and slot
          variables from CPU to TPU,
        A function which returns a list of ops to retrieve embedding and slot
          variables from TPU to CPU.
    """
    embedding_variables_by_table = {}
    slot_variables_by_table = {}
    load_op_fns = []
    retrieve_op_fns = []
    for table in self._table_to_config_dict:
      if embedding_variable_name_by_table:
        embedding_variable_name = embedding_variable_name_by_table[table]
      else:
        embedding_variable_name = table
      if slot_variable_names_by_table:
        slot_variable_names = slot_variable_names_by_table[table]
      else:
        slot_variable_names = (
            self._optimizer_handler.get_default_slot_variable_names(table))

      device_fn = _create_device_fn(self._hosts)
      with ops.device(device_fn):
        table_variables = _create_partitioned_variables(
            name=embedding_variable_name,
            num_hosts=self._num_hosts,
            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
            embedding_dimension=self._table_to_config_dict[table].dimension,
            initializer=self._table_to_config_dict[table].initializer,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        embedding_variables_by_table[table] = table_variables

        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
            self._optimizer_handler.create_variables_and_ops(
                table, slot_variable_names, self._num_hosts,
                self._table_to_config_dict[table], table_variables)
        )
        slot_variables_by_table[table] = slot_variables_for_table
        load_op_fns.append(load_ops_fn)
        retrieve_op_fns.append(retrieve_ops_fn)

    def load_ops():
      """Calls and returns the load ops for each embedding table.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_ops_list = []
      for load_op_fn in load_op_fns:
        load_ops_list.extend(load_op_fn())
      return load_ops_list

    def retrieve_ops():
      """Calls and returns the retrieve ops for each embedding table.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_ops_list = []
      for retrieve_op_fn in retrieve_op_fns:
        retrieve_ops_list.extend(retrieve_op_fn())
      return retrieve_ops_list

    return VariablesAndOps(embedding_variables_by_table,
                           slot_variables_by_table,
                           load_ops, retrieve_ops)

  def generate_enqueue_ops(self, sparse_features_list):
    """Generate enqueue ops.

    Args:
      sparse_features_list: a list of dictionaries mapping from string
        of feature names to sparse tensor. Each dictionary is for one
        TPU core. Dictionaries for the same host should be contiguous
        in the list (see the layout sketch in the comment following this
        method).

    Returns:
      Ops to enqueue to TPU for embedding.
    """
    self._validate_generate_enqueue_ops_sparse_features_list(
        sparse_features_list)
    return [
        self._generate_enqueue_op(
            sparse_features, device_ordinal=i % self._num_cores_per_host)
        for i, sparse_features in enumerate(sparse_features_list)
    ]
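
  # Illustrative sketch of the expected `sparse_features_list` layout for a
  # hypothetical setup with 2 hosts and 2 TPU cores per host (4 cores total).
  # Entries for cores on the same host must be adjacent, and each entry maps
  # every feature name to a tensor placed on that host:
  #
  #   sparse_features_list = [
  #       host0_core0_features,  # device_ordinal 0
  #       host0_core1_features,  # device_ordinal 1
  #       host1_core0_features,  # device_ordinal 0
  #       host1_core1_features,  # device_ordinal 1
  #   ]
  #
  # `generate_enqueue_ops()` assigns `device_ordinal = i % num_cores_per_host`,
  # so interleaving hosts would trip the contiguity check in
  # `_validate_generate_enqueue_ops_sparse_features_list()`.
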
  def _validate_generate_enqueue_ops_sparse_features_list(
      self, sparse_features_list):
    """Validate `sparse_features_list`."""
    if len(sparse_features_list) != self._num_cores:
      raise ValueError('Length of `sparse_features_list` should match the '
                       'number of cores; '
                       '`len(sparse_features_list)` is {}, '
                       'number of cores is {}.'.format(
                           len(sparse_features_list), self._num_cores))

    feature_set = set(self._feature_to_table_dict.keys())
    contiguous_device = None
    for i, sparse_features in enumerate(sparse_features_list):
      used_feature_set = set(sparse_features.keys())

      # Check features are valid.
      missing_feature_set = feature_set - used_feature_set
      if missing_feature_set:
        raise ValueError('`sparse_features_list[{}]` misses a feature that is '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, missing_feature_set))

      extra_feature_set = used_feature_set - feature_set
      if extra_feature_set:
        raise ValueError('`sparse_features_list[{}]` has a feature that is '
                         'not in `feature_to_config_dict`: {}.'.format(
                             i, extra_feature_set))

      device = None
      device_feature = None
      for feature, tensor in six.iteritems(sparse_features):
        combiner = self._table_to_config_dict[
            self._feature_to_table_dict[feature]].combiner
        if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
          raise ValueError('`sparse_features_list[{}]` has a feature that is '
                           'not mapped to `SparseTensor` and has a combiner. '
                           '`feature`: {}, combiner: {}'.format(
                               i, feature, combiner))

        # Check all features are on the same device.
        if device is None:
          device = tensor.op.device
          device_feature = feature
        else:
          if device != tensor.op.device:
            raise ValueError('Devices are different between features in '
                             '`sparse_features_list[{}]`; '
                             'devices: {}, {}; features: {}, {}.'.format(
                                 i, device, tensor.op.device, feature,
                                 device_feature))

      if i % self._num_cores_per_host:
        if device != contiguous_device:
          raise ValueError('We expect the `sparse_features` which are on the '
                           'same host to be contiguous in '
                           '`sparse_features_list`, '
                           '`sparse_features_list[{}]` is on device {}, '
                           'but is expected to be on device {}.'.format(
                               i, device, contiguous_device))
      else:
        contiguous_device = device

  def _generate_enqueue_op(self, sparse_features, device_ordinal):
    with ops.colocate_with(list(sparse_features.values())[0]):
      sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
          self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
          sample_idcs,
          embedding_idcs,
          aggregation_weights,
          table_ids,
          device_ordinal=device_ordinal,
          combiners=self._combiners)

  def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
    """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.

    Args:
      sparse_features: a `Dict` of tensors for embedding. Can be sparse or
        dense.

    Returns:
      Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
    """

    sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
        list(), list(), list(), list())
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      for feature in features:
        tensor = sparse_features[feature]
        if not isinstance(tensor, sparse_tensor.SparseTensor):
          sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32))
          embedding_idcs.append(tensor)
        else:
          sample_idcs.append(tensor.indices)
          embedding_idcs.append(tensor.values)
        aggregation_weights.append(array_ops.zeros([0]))
        table_ids.append(table_id)

    return sample_idcs, embedding_idcs, aggregation_weights, table_ids

  def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
    `tpu.replicate` and friends.

    Returns:
      A dictionary mapping from `String` of feature name to `Tensor`
      of activation.
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      for lookup_id, feature in enumerate(features):
        stride = len(self._table_to_features_dict[table])
        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
    return activations
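
  # Illustrative note (assumed shapes): `recv_tpu_embedding_activations`
  # returns one activation tensor per table, with the rows for that table's
  # features interleaved. For a hypothetical table backing the features
  # ['favorited', 'watched'] with `batch_size_per_core=2`, the per-table
  # activation rows are laid out as:
  #
  #   row 0: favorited, sample 0
  #   row 1: watched,   sample 0
  #   row 2: favorited, sample 1
  #   row 3: watched,   sample 1
  #
  # so the strided slice `recv_activations[table_id][lookup_id::stride, :]`
  # above recovers a `[batch_size_per_core, dimension]` tensor per feature,
  # and `generate_send_gradients_op()` below re-interleaves gradients into the
  # same layout via `array_ops.stack(..., axis=1)` followed by a reshape.
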
  def generate_send_gradients_op(self, feature_to_gradient_dict):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
    """
    if self._mode != TRAINING:
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))
    gradients = []
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      table_gradients = [
          feature_to_gradient_dict[feature] for feature in features
      ]
      interleaved_table_grads = array_ops.reshape(
          array_ops.stack(table_gradients, axis=1),
          [-1, table_gradients[0].shape[1]])
      gradients.append(interleaved_table_grads)
    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients, config=self.config_proto.SerializeToString())


def _validate_table_to_config_dict(table_to_config_dict):
  """Validate `table_to_config_dict`."""
  for k, v in six.iteritems(table_to_config_dict):
    if not isinstance(v, TableConfig):
      raise ValueError('Value of `table_to_config_dict` must be of type '
                       '`TableConfig`, got {} for {}.'.format(type(v), k))


def _validate_feature_to_table_dict(table_to_config_dict,
                                    feature_to_table_dict):
  """Validate `feature_to_table_dict`."""
  used_table_set = set(feature_to_table_dict.values())
  table_set = set(table_to_config_dict.keys())

  unused_table_set = table_set - used_table_set
  if unused_table_set:
    raise ValueError('`table_to_config_dict` specifies a table that is not '
                     'used in `feature_to_table_dict`: {}.'
                     .format(unused_table_set))

  extra_table_set = used_table_set - table_set
  if extra_table_set:
    raise ValueError('`feature_to_table_dict` refers to a table that is not '
                     'specified in `table_to_config_dict`: {}.'
                     .format(extra_table_set))


def _validate_batch_size(batch_size, num_cores):
  if batch_size % num_cores:
    raise ValueError('`batch_size` is not a multiple of the number of '
                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
                         batch_size, num_cores))


def _validate_optimization_parameters(optimization_parameters):
  if not isinstance(optimization_parameters, _OptimizationParameters):
    raise ValueError('`optimization_parameters` must inherit from '
                     '`_OptimizationParameters`. '
                     '`type(optimization_parameters)`={}'.format(
                         type(optimization_parameters)))


class _OptimizerHandler(object):
  """Interface class for handling optimizer specific logic."""

  def __init__(self, optimization_parameters):
    self._optimization_parameters = optimization_parameters

  def set_optimization_parameters(self, table_descriptor):
    raise NotImplementedError()

  def get_default_slot_variable_names(self, table):
    raise NotImplementedError()

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables):
    raise NotImplementedError()


class _AdagradHandler(_OptimizerHandler):
  """Handles Adagrad specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdagradHandler, self).__init__(optimization_parameters)
    self._table_to_accumulator_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adagrad.SetInParent()

  def get_default_slot_variable_names(self, table):
    return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    slot_variables = AdagradSlotVariable(accumulator_variables)

    def load_ops_fn():
      """Returns the load ops for AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      for host_id, table_variable, accumulator_variable in (zip(
          range(num_hosts), table_variables, accumulator_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable in (zip(
          range(num_hosts), table_variables, accumulator_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator = (
              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator))
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _AdamHandler(_OptimizerHandler):
  """Handles Adam specific logic."""

  def __init__(self, optimization_parameters):
    super(_AdamHandler, self).__init__(optimization_parameters)
    self._table_to_m_variables_dict = {}
    self._table_to_v_variables_dict = {}

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adam.beta1 = (
        self._optimization_parameters.beta1)
    table_descriptor.optimization_parameters.adam.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.adam.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
        not self._optimization_parameters.lazy_adam)
    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
        self._optimization_parameters.sum_inside_sqrt)

  def get_default_slot_variable_names(self, table):
    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
                                 '{}/{}/v'.format(table, 'Adam'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables):
    m_initializer = init_ops.zeros_initializer()
    m_variables = _create_partitioned_variables(
        name=slot_variable_names.m,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=m_initializer)
    v_initializer = init_ops.zeros_initializer()
    v_variables = _create_partitioned_variables(
        name=slot_variable_names.v,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=v_initializer)
    slot_variables = AdamSlotVariables(m_variables, v_variables)

    def load_ops_fn():
      """Returns the load ops for Adam embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_variable,
                  momenta=m_variable,
                  velocities=v_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Adam embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """

      retrieve_op_list = []
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_m, retrieved_v = (
              tpu_ops.retrieve_tpu_embedding_adam_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(m_variable, retrieved_m),
              state_ops.assign(v_variable, retrieved_v))

        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles stochastic gradient descent specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.stochastic_gradient_descent
     .SetInParent())

  def get_default_slot_variable_names(self, table):
    return None

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables):
    del table_config

    def load_ops_fn():
      """Returns the load ops for SGD embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops
              .load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))

        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for SGD embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """

      retrieve_op_list = []
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table = (
              tpu_ops
              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table))

        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return None, load_ops_fn, retrieve_ops_fn


def _get_optimization_handler(optimization_parameters):
  if isinstance(optimization_parameters, AdagradParameters):
    return _AdagradHandler(optimization_parameters)
  elif isinstance(optimization_parameters, AdamParameters):
    return _AdamHandler(optimization_parameters)
  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
    return _StochasticGradientDescentHandler(optimization_parameters)
  else:
    raise NotImplementedError()


def _create_ordered_dict(d):
  """Create an OrderedDict from Dict."""
  return collections.OrderedDict((k, d[k]) for k in sorted(d))


def _create_combiners(table_to_config_dict, table_to_features_dict):
  """Create a per feature list of combiners, ordered by table."""
  combiners = []
  for table in table_to_config_dict:
    combiner = table_to_config_dict[table].combiner or 'sum'
    combiners.extend([combiner] * len(table_to_features_dict[table]))
  return combiners


def _create_table_to_features_dict(feature_to_table_dict):
  """Create mapping from table to a list of its features."""
  table_to_features_dict_tmp = {}
  for feature, table in six.iteritems(feature_to_table_dict):
    if table in table_to_features_dict_tmp:
      table_to_features_dict_tmp[table].append(feature)
    else:
      table_to_features_dict_tmp[table] = [feature]

  table_to_features_dict = collections.OrderedDict()
  for table in sorted(table_to_features_dict_tmp):
    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
  return table_to_features_dict
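

# Illustrative sketch (hypothetical inputs) of how the two helpers above
# behave. For
#
#   feature_to_table_dict = {'watched': 'video',
#                            'favorited': 'video',
#                            'friends': 'user'}
#
# `_create_table_to_features_dict()` sorts tables and features, yielding
#
#   OrderedDict([('user', ['friends']), ('video', ['favorited', 'watched'])])
#
# and `_create_combiners()` then emits one combiner per feature in that same
# order, e.g. ['mean', 'mean', 'mean'] for tables using the default combiner
# (a `None` combiner falls back to 'sum').

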
def _create_device_fn(hosts):
  """Create device_fn() to use with _create_partitioned_variables()."""

  def device_fn(op):
    """Returns the `device` for `op`."""
    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)

    if part_match:
      idx = int(part_match.group(1))
    else:
      raise RuntimeError('Internal Error: '
                         'Expected %s to contain /part_*.' % op.name)

    device = hosts[idx]
    return device

  return device_fn


def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates PartitionedVariables based on `num_hosts` for `table`."""
  # TODO(shizhiw): automatically place embedding lookup elsewhere?
  if vocabulary_size < num_hosts:
    raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). '
                     'As TPU embedding is not optimized for small tables, '
                     'please consider other ways for this embedding lookup.'
                     .format(vocabulary_size, num_hosts))

  return list(variable_scope.get_variable(
      name,
      shape=(vocabulary_size, embedding_dimension),
      partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
      dtype=dtypes.float32,
      initializer=initializer,
      collections=collections,
      trainable=False))
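

# Illustrative sketch (hypothetical names) of how the helpers above cooperate.
# `_create_partitioned_variables()` shards a table into `num_hosts` parts
# named '<name>/part_0', '<name>/part_1', ..., and the `device_fn` returned by
# `_create_device_fn(hosts)` pins each part to the host whose index matches
# the part number:
#
#   hosts = ['/job:worker/task:0/device:CPU:0',
#            '/job:worker/task:1/device:CPU:0']
#   device_fn = _create_device_fn(hosts)
#   # An op named 'video/part_1/Initializer/...' would be placed on hosts[1].
#
# This is why `TPUEmbedding.create_variables_and_ops()` wraps variable
# creation in `with ops.device(_create_device_fn(self._hosts)):`.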