# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API on TPU."""
import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.keras import optimizer_v2
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.tests import tpu_embedding_base_test


_SLOT_NAME_MAPPING = {
    # Slot names in Keras optimizer v2 differ from the slot names in our API.
    optimizer_v2.adagrad.Adagrad: {
        'accumulators': 'accumulator'
    },
    optimizer_v2.adam.Adam: {
        'momenta': 'm',
        'velocities': 'v'
    },
    optimizer_v2.ftrl.Ftrl: {
        'accumulators': 'accumulator',
        'linears': 'linear'
    },
}


class TPUEmbeddingV0CorrectnessTest(
    tpu_embedding_base_test.TPUEmbeddingBaseTest):

  def _get_strategy(self):
    if hasattr(self, 'strategy'):
      return self.strategy
    return super(TPUEmbeddingV0CorrectnessTest, self)._get_strategy()

  def _create_mid_level(self, optimizer=None):
    # Create `TPUEmbedding` object.
    if optimizer is None:
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    return tpu_embedding_v1.TPUEmbeddingV0(
        feature_config=self.feature_config, optimizer=optimizer)

  def _get_slot_variable_creation_fn(self, optimizer):
    # This is needed so that the mid level API can create slots using a
    # user-passed optimizer rather than the built-in methods. This allows a
    # user to train the same model on CPU and TPU.
    def slot_variable_creation_fn(table, slot_names, slot_initializers):
      slots = {}
      for slot, initializer in zip(slot_names, slot_initializers):
        slots[slot] = optimizer.add_slot(
            table, _SLOT_NAME_MAPPING[type(optimizer)][slot], initializer)
      return slots

    return slot_variable_creation_fn

  def _create_strategy_and_mid_level(self, optimizer_name):
    strategy = self._get_strategy()

    # Keras optimizers have to be translated to an embedding optimizer with the
    # slot variable creation fn properly populated.
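    # In this test the Keras optimizer applies the gradients to the embedding
    # tables directly, while the embedding optimizer (which shares the Keras
    # optimizer's slot variables via the creation fn above) is what the mid
    # level API is configured with.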
    with strategy.scope():
      if optimizer_name == 'sgd':
        optimizer = optimizer_v2.gradient_descent.SGD(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      elif optimizer_name == 'adagrad':
        optimizer = optimizer_v2.adagrad.Adagrad(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.Adagrad(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      elif optimizer_name == 'adam':
        optimizer = optimizer_v2.adam.Adam(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.Adam(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      elif optimizer_name == 'ftrl':
        optimizer = optimizer_v2.ftrl.Ftrl(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.FTRL(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      else:
        raise ValueError('optimizer is not recognized: ', optimizer_name)

      mid_level_api = self._create_mid_level(optimizer=embedding_optimizer)

    return strategy, mid_level_api, optimizer

  @parameterized.parameters(
      *itertools.product(['sgd', 'adagrad', 'adam', 'ftrl'], [True, False],
                         [True, False], [True, False]))
  def test_embedding(self, optimizer_name, training, sparse,
                     is_high_dimensional):
    strategy, mid_level_api, optimizer = (
        self._create_strategy_and_mid_level(optimizer_name))

    if sparse:
      if is_high_dimensional:
        dataset = self._create_high_dimensional_sparse_dataset(strategy)
      else:
        dataset = self._create_sparse_dataset(strategy)
    else:
      if is_high_dimensional:
        dataset = self._create_high_dimensional_sparse_dataset(strategy)
      else:
        dataset = self._create_ragged_dataset(strategy)

    dist = strategy.experimental_distribute_dataset(
        dataset,
        options=distribute_lib.InputOptions(experimental_fetch_to_device=False))
    dist_iter = iter(dist)

    @def_function.function
    def test_fn():
      """Create and run computation that returns the embedding activations."""

      def step(data):
        if not training:
          activations = mid_level_api(data)
          total_loss = self._get_total_loss_tensor(activations)
          ret_val = [total_loss] + list(activations)
          return ret_val
        else:
          with backprop.GradientTape() as tape:
            tape.watch(mid_level_api.embedding_tables.values())
            activations = mid_level_api(data)
            total_loss = self._get_total_loss_tensor(activations)
            loss_per_replica = total_loss / strategy.num_replicas_in_sync
          gradients = tape.gradient(loss_per_replica,
                                    mid_level_api.embedding_tables.values())
          optimizer.apply_gradients(
              list(zip(gradients, mid_level_api.embedding_tables.values())))
          ret_val = [total_loss] + list(activations)
          return ret_val

      return strategy.run(step, args=(next(dist_iter),))

    # Run model.
    shard_out_val = test_fn()

    # Compute sparse tensors for global batch.
    if is_high_dimensional:
      input_data = next(
          iter(self._create_high_dimensional_sparse_dataset(strategy)))
    else:
      input_data = next(iter(self._create_sparse_dataset(strategy)))

    # Check results.
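    # `_check_results` is provided by the base test; for the training case it
    # dispatches to `_check_embedding_and_slot_variables` below, which verifies
    # the updated tables and slots against NumPy reference computations.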
    self._check_results(strategy, shard_out_val, training, input_data,
                        mid_level_api._variables, optimizer,
                        is_high_dimensional)

  def _check_embedding_and_slot_variables(self, embedding_table_user_before,
                                          gradients_wrt_user,
                                          embedding_table_video_before,
                                          gradients_wrt_video, optimizer,
                                          table_to_variable):
    if isinstance(optimizer, optimizer_v2.gradient_descent.SGD):
      check_fn = self._check_embedding_and_slot_variables_for_sgd
    elif isinstance(optimizer, optimizer_v2.adagrad.Adagrad):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad
    elif isinstance(optimizer, optimizer_v2.adam.Adam):
      check_fn = self._check_embedding_and_slot_variables_for_adam
    elif isinstance(optimizer, optimizer_v2.ftrl.Ftrl):
      check_fn = self._check_embedding_and_slot_variables_for_ftrl
    else:
      raise ValueError('optimizer is not recognized: ', type(optimizer))
    check_fn(embedding_table_user_before, gradients_wrt_user, optimizer,
             table_to_variable[self.table_user.name])
    check_fn(embedding_table_video_before, gradients_wrt_video, optimizer,
             table_to_variable[self.table_video.name])

  def _check_embedding_and_slot_variables_for_sgd(self, embedding_table_before,
                                                  gradients, optimizer,
                                                  variables):
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    embedding_table -= config['learning_rate'] * np.sum(gradients, axis=0)
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(), embedding_table)

  def _check_embedding_and_slot_variables_for_adagrad(self,
                                                      embedding_table_before,
                                                      gradients, optimizer,
                                                      variables):
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    accumulator = (
        config['initial_accumulator_value'] + np.sum(gradients, axis=0)**2)
    embedding_table -= (
        config['learning_rate'] * np.sum(gradients, axis=0) /
        np.sqrt(accumulator))
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(), embedding_table)
    self.assertAllClose(
        self._get_variable(variables['accumulators']).numpy(), accumulator)

  def _check_embedding_and_slot_variables_for_adam(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variables):
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    g = np.sum(gradients, axis=0)
    v = g**2 * (1 - config['beta_2'])
    m = g * (1 - config['beta_1'])
    epsilon = config['epsilon']
    lr_modifier = np.sqrt(1 - config['beta_2']) / (1 - config['beta_1'])
    embedding_table -= (
        m * config['learning_rate'] * lr_modifier / (np.sqrt(v) + epsilon))
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(),
        embedding_table,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variables['momenta']).numpy(), m, rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variables['velocities']).numpy(), v, rtol=1e-4)

  def _check_embedding_and_slot_variables_for_ftrl(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variables):
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    neg_lr_p = -config['learning_rate_power']
    accumulator = (
        config['initial_accumulator_value'] + np.sum(gradients, axis=0)**2)
    sigma = (accumulator**neg_lr_p - config['initial_accumulator_value']**
             neg_lr_p) / config['learning_rate']
    linear = np.sum(gradients, axis=0) - sigma * embedding_table
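    # With the default l1/l2 regularization strengths of 0, the FTRL update
    # reduces to the closed form -linear / quadratic, which is what the
    # parameters are compared against below.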
    quadratic = accumulator**neg_lr_p / config['learning_rate']
    embedding_table = -linear / quadratic
    actual_parameters = self._get_variable(variables['parameters']).numpy()
    # For entries where `linear` == 0, it is not worth comparing since the
    # initial values have not been touched yet and they will not agree with
    # what the actual values should be.
    actual_parameters *= (linear != 0.0)
    # FTRL shows a slightly larger precision difference on the parameters.
    self.assertAllClose(actual_parameters, embedding_table, rtol=5e-5)
    self.assertAllClose(
        self._get_variable(variables['linears']).numpy(), linear, rtol=5e-4)
    self.assertAllClose(
        self._get_variable(variables['accumulators']).numpy(), accumulator)

  @parameterized.parameters(True, False)
  def test_enqueue_with_weights(self, ragged):
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    weight = 0.5
    if ragged:
      dataset = self._create_ragged_dataset(
          strategy, include_weights=True, weight=weight)
    else:
      dataset = self._create_sparse_dataset(
          strategy, include_weights=True, weight=weight)

    dataset_iter = iter(
        strategy.experimental_distribute_dataset(
            dataset,
            options=distribute_lib.InputOptions(
                experimental_fetch_to_device=False)))

    @def_function.function
    def embedding_lookup(features, weights):

      def step(features, weights):
        return mid_level_api(features, weights)

      return strategy.run(step, args=(features, weights))

    features, weights = next(dataset_iter)
    # Replace the weight for the second feature with None to check that it is
    # ignored.
    weights = (weights[0], None, weights[2])

    no_weights_activations = embedding_lookup(features, weights=None)
    weights_activations = embedding_lookup(features, weights=weights)

    no_weights0 = (self._unpack(strategy, no_weights_activations[0]),
                   self._unpack(strategy, no_weights_activations[1]),
                   self._unpack(strategy, no_weights_activations[2]))
    weights0 = (self._unpack(strategy, weights_activations[0]),
                self._unpack(strategy, weights_activations[1]),
                self._unpack(strategy, weights_activations[2]))
    # The videos table has a sum combiner and the users table has a mean
    # combiner, i.e. the users table lookups aren't affected by the weights as
    # all the weights are the same.
    # Tuple entries 0 and 1 are the watched and favorited features from the
    # videos table and entry 2 is the friends feature from the users table.
    # Note that None was passed as a weight for entry 1, so that weight should
    # have no effect.
    weight = (0.5, 1.0, 1.0)
    golden = tuple([no_weight * w for no_weight, w in zip(no_weights0, weight)])

    self.assertAllClose(golden, weights0)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()