# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Companion classes for mid level API for TPU Embeddings in TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import abc
import math
import typing
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union

from absl import logging
import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export


TableVariable = TypeVar("TableVariable", sharded_variable.ShardedVariable,
                        tf_variables.Variable)
SlotVarCreationFnType = Callable[
    [TableVariable, List[Text], List[init_ops_v2.Initializer]],
    Dict[Text, TableVariable]]
ClipValueType = Union[Tuple[float, float], float]


@six.add_metaclass(abc.ABCMeta)
class _Optimizer(object):
  """Base class for all optimizers, with common parameters."""

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]],
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: bool,
      clipvalue: Optional[ClipValueType] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    if not use_gradient_accumulation and clipvalue is not None:
      raise ValueError(
          f"When `use_gradient_accumulation` is False, gradient clipping "
          f"cannot be used and `clipvalue` should be left as None. "
          f"Received value {clipvalue} for argument `clipvalue`.")
    if clipvalue is None:
      clipvalue = (None, None)
    elif not isinstance(clipvalue, tuple):
      clipvalue = (-1. * clipvalue, clipvalue)
    self.clip_gradient_min, self.clip_gradient_max = clipvalue

    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)

    if (slot_variable_creation_fn is not None and
        not callable(slot_variable_creation_fn)):
      raise ValueError(
          f"Argument `slot_variable_creation_fn` must be either None or a "
          f"callable. Received: {slot_variable_creation_fn}")
    self.slot_variable_creation_fn = slot_variable_creation_fn
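
  # A sketch (not part of this module) of the signature expected for
  # `slot_variable_creation_fn`: it is passed the table variable, the list of
  # slot names and a parallel list of initializers, and must return a dict of
  # created variables keyed by slot name, for example:
  #
  #   def create_slots(table, slot_names, slot_initializers):
  #     return {
  #         name: tf.Variable(init(shape=table.shape, dtype=table.dtype))
  #         for name, init in zip(slot_names, slot_initializers)
  #     }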

  @abc.abstractmethod
  def _slot_names(self) -> List[Text]:
    """Returns the names of all the slot variables.

    This does not include the 'parameters' variable and these names must match
    the names of the slot variables as used in the corresponding
    `tpu_ops.load_tpu_embedding_*` ops.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    """Returns initializers for slot variables.

    This returns a parallel list to self._slot_names().
    """
    raise NotImplementedError

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    """Sets the optimizer fields in the OptimizationParameters."""
    if self.use_gradient_accumulation:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
    else:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)

    if self.clip_weight_min is not None:
      parameters.clipping_limits.lower.value = self.clip_weight_min

    if self.clip_weight_max is not None:
      parameters.clipping_limits.upper.value = self.clip_weight_max

    if self.clip_gradient_min is not None:
      parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min

    if self.clip_gradient_max is not None:
      parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max

    if self.weight_decay_factor:
      parameters.weight_decay_factor = self.weight_decay_factor
      if self.multiply_weight_decay_factor_by_learning_rate:
        parameters.multiply_weight_decay_factor_by_learning_rate = True

  @abc.abstractmethod
  def _load(self) -> Callable[..., ops.Operation]:
    """Returns the load function for the optimizer."""
    raise NotImplementedError

  @abc.abstractmethod
  def _retrieve(self) -> Callable[..., core.Tensor]:
    """Returns the retrieve function for the optimizer."""
    raise NotImplementedError

  def _create_slots(
      self, table: "TableConfig",
      variable_creator: Callable[[Text, init_ops_v2.Initializer],
                                 tf_variables.Variable]
  ) -> Dict[Text, tf_variables.Variable]:
    """Creates slot variables for table.

    Args:
      table: The table variable to create slots for.
      variable_creator: A function which creates variables. Takes parameters
        'name', 'initializer'.

    Returns:
      A dict of variables, keyed by self._slot_names().
    """
    if self.slot_variable_creation_fn is not None:
      return self.slot_variable_creation_fn(table, self._slot_names(),
                                            self._slot_initializers())
    else:
      slots = {}
      for slot, initializer in zip(self._slot_names(),
                                   self._slot_initializers()):
        slots[slot] = variable_creator(slot, initializer)
      return slots


@tf_export("tpu.experimental.embedding.SGD")
class SGD(_Optimizer):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```
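
  The learning rate may also be given as a callable taking no arguments, which
  lets the value change over the course of training. A minimal sketch (the
  variable name is illustrative; the host training loop would assign new
  values to it as training progresses):

  ```python
  lr = tf.Variable(0.1, trainable=False)

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=lambda: lr))
  ```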

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters of the global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.SGD(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(self,
               learning_rate: Union[float, Callable[[], float]] = 0.01,
               clip_weight_min: Optional[float] = None,
               clip_weight_max: Optional[float] = None,
               weight_decay_factor: Optional[float] = None,
               multiply_weight_decay_factor_by_learning_rate: bool = None,
               clipvalue: Optional[ClipValueType] = None):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: The learning rate. It should be a floating point value or
        a callable taking no arguments for a dynamic learning rate.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed. Weights are decayed by multiplying the weight
        by this factor each step.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction. Note that if
        this is set, you may see a decrease in performance as gradient
        accumulation will be enabled (it is normally off for SGD as it has no
        effect on accuracy). See
        'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for more
        information on gradient accumulation and its impact on TPU embeddings.
    """
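    # Gradient clipping requires gradient accumulation, so passing any
    # `clipvalue` implicitly enables accumulation for SGD. As an illustration,
    # `clipvalue=5.0` clips gradients to [-5.0, 5.0], while
    # `clipvalue=(None, 1.0)` clips only from above.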
    use_gradient_accumulation = clipvalue is not None

    super(SGD, self).__init__(
        learning_rate, use_gradient_accumulation, clip_weight_min,
        clip_weight_max, weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate, clipvalue)

  def _slot_names(self) -> List[Text]:
    return []

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    return []

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super(SGD, self)._set_optimization_parameters(parameters)
    parameters.stochastic_gradient_descent.SetInParent()

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters


@tf_export("tpu.experimental.embedding.Adagrad")
class Adagrad(_Optimizer):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters of the global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      initial_accumulator_value: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: bool = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: The learning rate. It should be a floating point value or
        a callable taking no arguments for a dynamic learning rate.
      initial_accumulator_value: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names
        as keys and the created variables as values with types matching the
        table variable. When set to None (the default), uses the built-in
        variable creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction.
    """
    super(Adagrad, self).__init__(
        learning_rate, use_gradient_accumulation, clip_weight_min,
        clip_weight_max, weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate, clipvalue,
        slot_variable_creation_fn)
    if initial_accumulator_value <= 0:
      raise ValueError(
          f"Argument `initial_accumulator_value` must be a positive float. "
          f"Received: {initial_accumulator_value}")
    self.initial_accumulator_value = initial_accumulator_value

  def _slot_names(self) -> List[Text]:
    return ["accumulators"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    return [init_ops_v2.Constant(self.initial_accumulator_value)]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super(Adagrad, self)._set_optimization_parameters(parameters)
    parameters.adagrad.SetInParent()

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_adagrad_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_adagrad_parameters


@tf_export("tpu.experimental.embedding.FTRL")
class FTRL(_Optimizer):
  """Optimization parameters for FTRL with TPU embeddings.

  See Algorithm 1 of this
  [paper](https://research.google.com/pubs/archive/41159.pdf).

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
  ```
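
  The regularization terms and the learning rate schedule from the paper are
  exposed as constructor arguments; for example (the values shown are purely
  illustrative):

  ```python
  optimizer = tf.tpu.experimental.embedding.FTRL(
      learning_rate=0.05,
      learning_rate_power=-0.5,
      l1_regularization_strength=0.001,
      l2_regularization_strength=0.001)
  ```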

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters of the global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.FTRL(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      learning_rate_power: float = -0.5,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      beta: float = 0.0,
      initial_accumulator_value: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: bool = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None,
      multiply_linear_by_learning_rate: bool = False,
      allow_zero_accumulator: bool = False):
    """Optimization parameters for FTRL.

    Args:
      learning_rate: The learning rate. It should be a floating point value or
        a callable taking no arguments for a dynamic learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      beta: A float value, representing the beta value from the paper.
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names
        as keys and the created variables as values with types matching the
        table variable. When set to None (the default), uses the built-in
        variable creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction.
      multiply_linear_by_learning_rate: If set to True, a modified formula is
        used for FTRL that treats the "linear" accumulator as being
        pre-multiplied by the learning rate (i.e., the accumulator named
        "linear" actually stores "linear * learning_rate"). Other than
        checkpoint compatibility, this is mathematically equivalent for a
        static learning rate; for a dynamic learning rate, it is nearly the
        same as long as the learning rate does not change quickly. The benefit
        of this is that the modified formula handles zero and near-zero
        learning rates without producing NaNs, improving flexibility for
        learning rate ramp-up.
      allow_zero_accumulator: If set to True, changes some internal formulas to
        allow zero and near-zero accumulator values at the cost of some
        performance; this only needs to be set if you are using an initial
        accumulator value of zero, which is uncommon.
    """
    super().__init__(learning_rate, use_gradient_accumulation, clip_weight_min,
                     clip_weight_max, weight_decay_factor,
                     multiply_weight_decay_factor_by_learning_rate, clipvalue,
                     slot_variable_creation_fn)
    if initial_accumulator_value <= 0:
      raise ValueError(
          f"Argument `initial_accumulator_value` must be a positive float. "
          f"Received: {initial_accumulator_value}")
    self.initial_accumulator_value = initial_accumulator_value
    self.learning_rate_power = learning_rate_power
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.beta = beta
    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
    self.allow_zero_accumulator = allow_zero_accumulator

  def _slot_names(self) -> List[Text]:
    return ["accumulators", "linears"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    return [
        init_ops_v2.Constant(self.initial_accumulator_value),
        init_ops_v2.Constant()
    ]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super()._set_optimization_parameters(parameters)
    ftrl = parameters.ftrl
    ftrl.l1 = self.l1_regularization_strength
    ftrl.l2 = self.l2_regularization_strength
    ftrl.lr_power = self.learning_rate_power
    ftrl.beta = self.beta
    ftrl.multiply_linear_by_lr = self.multiply_linear_by_learning_rate
    ftrl.allow_zero_accumulator = self.allow_zero_accumulator

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_ftrl_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_ftrl_parameters


@tf_export("tpu.experimental.embedding.Adam")
class Adam(_Optimizer):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient
  update of zero to rows that were not looked up. You can change this behavior
  by setting `lazy_adam` to `False`.

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```
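
  The laziness referred to in the note above is controlled by the `lazy_adam`
  argument. A sketch of requesting non-lazy updates (the values shown are
  illustrative); gradient accumulation must remain enabled in this case:

  ```python
  optimizer = tf.tpu.experimental.embedding.Adam(
      learning_rate=0.001,
      lazy_adam=False,
      use_gradient_accumulation=True)
  ```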

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters of the global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      beta_1: float = 0.9,
      beta_2: float = 0.999,
      epsilon: float = 1e-07,
      lazy_adam: bool = True,
      sum_inside_sqrt: bool = True,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: bool = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None):
    """Optimization parameters for Adam.

    See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
    complete description of these parameters and their impacts on the optimizer
    algorithm.

    Args:
      learning_rate: The learning rate. It should be a floating point value or
        a callable taking no arguments for a dynamic learning rate.
      beta_1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta_2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
      sum_inside_sqrt: When this is true, the Adam update formula is changed
        from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This
        option improves the performance of TPU training and is not expected to
        harm model quality.
      use_gradient_accumulation: Setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names
        as keys and the created variables as values with types matching the
        table variable. When set to None (the default), uses the built-in
        variable creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction.
    """
    super(Adam, self).__init__(
        learning_rate, use_gradient_accumulation, clip_weight_min,
        clip_weight_max, weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate, clipvalue,
        slot_variable_creation_fn)
    if beta_1 < 0. or beta_1 >= 1.:
      raise ValueError(
          f"Argument `beta_1` must be >= 0 and < 1. Received: {beta_1}.")
    if beta_2 < 0. or beta_2 >= 1.:
      raise ValueError(
          f"Argument `beta_2` must be >= 0 and < 1. Received: {beta_2}.")
    if epsilon <= 0.:
      raise ValueError("epsilon must be positive; got {}.".format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          "When disabling lazy Adam (`lazy_adam=False`), "
          "gradient accumulation must be used. "
          "Set `use_gradient_accumulation` to True.")

    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt

  def _slot_names(self) -> List[Text]:
    return ["momenta", "velocities"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    return [init_ops_v2.Constant(), init_ops_v2.Constant()]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super(Adam, self)._set_optimization_parameters(parameters)
    parameters.adam.beta1 = self.beta_1
    parameters.adam.beta2 = self.beta_2
    parameters.adam.epsilon = self.epsilon
    parameters.adam.use_non_lazy_adam = not self.lazy_adam
    parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_adam_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_adam_parameters


@tf_export("tpu.experimental.embedding.TableConfig")
class TableConfig(object):
  """Configuration data for one embedding table.

  This class holds the configuration data for a single embedding table. It is
  used as the `table` parameter of a
  `tf.tpu.experimental.embedding.FeatureConfig`. Multiple
  `tf.tpu.experimental.embedding.FeatureConfig` objects can use the same
  `tf.tpu.experimental.embedding.TableConfig` object. In this case a shared
  table will be created for those feature lookups.

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```

  The above configuration has two tables and three features. The first two
  features will be looked up in the first table and the third feature will be
  looked up in the second table.
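
  Each table is initialized by default with a truncated normal distribution
  with mean `0.0` and standard deviation `1/sqrt(dim)`. To use a different
  initialization, pass any callable that takes the shape of the variable as
  its argument; the initializer below is purely illustrative:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02))
  ```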

  """

  def __init__(self,
               vocabulary_size: int,
               dim: int,
               initializer: Optional[Callable[[Any], None]],
               optimizer: Optional[_Optimizer] = None,
               combiner: Text = "mean",
               name: Optional[Text] = None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Size of the table's vocabulary (number of rows).
      dim: The embedding dimension (width) of the table.
      initializer: A callable initializer taking one parameter, the shape of
        the variable that will be initialized. Will be called once per task, to
        initialize that task's shard of the embedding table. If not specified,
        defaults to `truncated_normal_initializer` with mean `0.0` and standard
        deviation `1/sqrt(dim)`.
      optimizer: An optional instance of an optimizer parameters class: one of
        `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. If set, will override the global
        optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with
        'mean' the default. 'sqrtn' often achieves good accuracy, in particular
        with bag-of-words columns. For more information, see
        `tf.nn.embedding_lookup_sparse`.
      name: An optional string used to name the table. Useful for debugging.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dim` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError(
          f"Argument `vocabulary_size` must be an int and must be >= 1. "
          f"Received: {vocabulary_size}")

    if not isinstance(dim, int) or dim < 1:
      raise ValueError(
          f"Argument `dim` (embedding dimension) "
          f"must be an int and must be >= 1. Received: {dim}")

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError(
          f"Argument `initializer` must be a callable (or None). "
          f"Received: {initializer}")
    if initializer is None:
      initializer = init_ops_v2.TruncatedNormal(mean=0.0,
                                                stddev=1/math.sqrt(dim))
    accepted_combiners = ("mean", "sum", "sqrtn")
    if combiner not in accepted_combiners:
      raise ValueError(
          f"Argument `combiner` must be one of {accepted_combiners}. "
          f"Received: {combiner}")

    self.vocabulary_size = vocabulary_size
    self.dim = dim
    self.initializer = initializer
    self.optimizer = optimizer
    self.combiner = combiner
    self.name = name

  def __repr__(self):
    # If using the default initializer, just print "None" for clarity.
    initializer = self.initializer

    if isinstance(initializer, init_ops_v2.TruncatedNormal):
      # PY2 type checking can't infer type of initializer even after if.
      initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
      if (initializer.mean == 0.0
          and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))):  # pytype: disable=module-attr (math.isclose not in PY2)
        initializer = None

    return (
        "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
        "initializer={initializer!r}, optimizer={optimizer!r}, "
        "combiner={combiner!r}, name={name!r})".format(
            vocabulary_size=self.vocabulary_size,
            dim=self.dim,
            initializer=initializer,
            optimizer=self.optimizer,
            combiner=self.combiner,
            name=self.name,)
    )


@tf_export("tpu.experimental.embedding.FeatureConfig")
class FeatureConfig(object):
  """Configuration data for one embedding feature.

  This class holds the configuration data for a single embedding feature. The
  main use is to assign features to `tf.tpu.experimental.embedding.TableConfig`s
  via the table parameter:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```

  The above configuration has two tables and three features. The first two
  features will be looked up in the first table and the third feature will be
  looked up in the second table.

  When feeding features into `embedding.enqueue` they can be `tf.Tensor`s,
  `tf.SparseTensor`s or `tf.RaggedTensor`s. When the argument
  `max_sequence_length` is 0, the default, you should expect an output of
  `embedding.dequeue` for this feature of shape `(batch_size, dim)`. If
  `max_sequence_length` is greater than 0, the feature is embedded as a
  sequence and padded up to the given length. The shape of the output for this
  feature will be `(batch_size, max_sequence_length, dim)`.
  """

  def __init__(self,
               table: TableConfig,
               max_sequence_length: int = 0,
               validate_weights_and_indices: bool = True,
               name: Optional[Text] = None):
    """Feature configuration.

    Args:
      table: An instance of `tf.tpu.experimental.embedding.TableConfig`,
        describing the table in which this feature should be looked up.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      validate_weights_and_indices: If true, uses safe_embedding_lookup during
        serving which ensures there are no empty rows and all weights and ids
        are positive at the expense of extra compute cost.
      name: An optional name for the feature, useful for debugging.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `table` is not an instance of
        `tf.tpu.experimental.embedding.TableConfig`.
      ValueError: if `max_sequence_length` is not an integer or is negative.
    """
    if not isinstance(table, TableConfig):
      raise ValueError(f"Argument `table` has invalid type {type(table)}. "
                       "Expected `tf.tpu.experimental.embedding.TableConfig`.")

    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError(
          f"Argument `max_sequence_length` must be an int and must be >= 0. "
          f"Received: {max_sequence_length}")

    self.table = table
    self.max_sequence_length = max_sequence_length
    self.name = name

    if not isinstance(validate_weights_and_indices, bool):
      raise ValueError(
          f"Argument `validate_weights_and_indices` must be a boolean. "
          f"Received: {validate_weights_and_indices}")

    self.validate_weights_and_indices = validate_weights_and_indices

  def __repr__(self):
    return ("FeatureConfig(table={table!r}, "
            "max_sequence_length={max_sequence_length!r}, "
            "validate_weights_and_indices={"
            "validate_weights_and_indices!r}, name={name!r})".format(
                table=self.table,
                max_sequence_length=self.max_sequence_length,
                validate_weights_and_indices=self.validate_weights_and_indices,
                name=self.name))


def log_tpu_embedding_configuration(
    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
  """Logs a TPUEmbeddingConfiguration proto across multiple statements.

  Args:
    config: TPUEmbeddingConfiguration proto to log. Necessary because
      logging.info has a maximum length to each log statement, which
      particularly large configs can exceed.
  """
  logging.info("Beginning log of TPUEmbeddingConfiguration.")
  for line in str(config).splitlines():
    logging.info(line)
  logging.info("Done with log of TPUEmbeddingConfiguration.")
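
# Illustrative usage (assumes `config` is a TPUEmbeddingConfiguration proto
# built elsewhere, e.g. by the mid level API):
#
#   log_tpu_embedding_configuration(config)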