# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for TPU Embeddings mid level API on TPU."""
import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.keras import optimizer_v2
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.tests import tpu_embedding_base_test


_SLOT_NAME_MAPPING = {
    # Slot names in the Keras v2 optimizers differ from the slot names used in
    # our mid-level API.
    optimizer_v2.adagrad.Adagrad: {
        'accumulators': 'accumulator'
    },
    optimizer_v2.adam.Adam: {
        'momenta': 'm',
        'velocities': 'v'
    },
    optimizer_v2.ftrl.Ftrl: {
        'accumulators': 'accumulator',
        'linears': 'linear'
    },
}


class TPUEmbeddingV0CorrectnessTest(
    tpu_embedding_base_test.TPUEmbeddingBaseTest):

  def _get_strategy(self):
    if hasattr(self, 'strategy'):
      return self.strategy
    return super(TPUEmbeddingV0CorrectnessTest, self)._get_strategy()

  def _create_mid_level(self, optimizer=None):
    # Create the `TPUEmbeddingV0` object.
    if optimizer is None:
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    return tpu_embedding_v1.TPUEmbeddingV0(
        feature_config=self.feature_config, optimizer=optimizer)

  def _get_slot_variable_creation_fn(self, optimizer):
    # This is needed so that the mid-level API can create slots using a
    # user-passed optimizer rather than the built-in methods, which allows a
    # user to train the same model on CPU and TPU.
    def slot_variable_creation_fn(table, slot_names, slot_initializers):
      slots = {}
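      # Translate each of our slot names to the corresponding Keras slot name
      # and let the Keras optimizer create and track the slot variable.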
      for slot, initializer in zip(slot_names, slot_initializers):
        slots[slot] = optimizer.add_slot(
            table, _SLOT_NAME_MAPPING[type(optimizer)][slot], initializer)
      return slots

    return slot_variable_creation_fn

  def _create_strategy_and_mid_level(self, optimizer_name):
    strategy = self._get_strategy()

    # Keras optimizers have to be translated into embedding optimizers with the
    # slot variable creation fn properly populated.
    with strategy.scope():
      if optimizer_name == 'sgd':
        optimizer = optimizer_v2.gradient_descent.SGD(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      elif optimizer_name == 'adagrad':
        optimizer = optimizer_v2.adagrad.Adagrad(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.Adagrad(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      elif optimizer_name == 'adam':
        optimizer = optimizer_v2.adam.Adam(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.Adam(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      elif optimizer_name == 'ftrl':
        optimizer = optimizer_v2.ftrl.Ftrl(learning_rate=0.1)
        embedding_optimizer = tpu_embedding_v2_utils.FTRL(
            learning_rate=0.1,
            slot_variable_creation_fn=self._get_slot_variable_creation_fn(
                optimizer))
      else:
        raise ValueError(f'optimizer is not recognized: {optimizer_name}')

      mid_level_api = self._create_mid_level(optimizer=embedding_optimizer)

    return strategy, mid_level_api, optimizer

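  # Exercises each optimizer (SGD, Adagrad, Adam, FTRL) with sparse and ragged
  # inputs, with and without training, and with 2-D and high-dimensional
  # features.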
  @parameterized.parameters(
      *itertools.product(['sgd', 'adagrad', 'adam', 'ftrl'], [True, False],
                         [True, False], [True, False]))
  def test_embedding(self, optimizer_name, training, sparse,
                     is_high_dimensional):
    strategy, mid_level_api, optimizer = (
        self._create_strategy_and_mid_level(optimizer_name))

    if sparse:
      if is_high_dimensional:
        dataset = self._create_high_dimensional_sparse_dataset(strategy)
      else:
        dataset = self._create_sparse_dataset(strategy)
    else:
      if is_high_dimensional:
        # Assumes the base test provides a ragged counterpart to the
        # high-dimensional sparse dataset helper.
        dataset = self._create_high_dimensional_ragged_dataset(strategy)
      else:
        dataset = self._create_ragged_dataset(strategy)

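    # Embedding inputs are read on the host, so do not prefetch dataset
    # elements to the TPU device.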
    dist = strategy.experimental_distribute_dataset(
        dataset,
        options=distribute_lib.InputOptions(experimental_fetch_to_device=False))
    dist_iter = iter(dist)

    @def_function.function
    def test_fn():
      """Create and run computation that returns the embedding activations."""

      def step(data):
        if not training:
          activations = mid_level_api(data)
          total_loss = self._get_total_loss_tensor(activations)
          ret_val = [total_loss] + list(activations)
          return ret_val
        else:
          with backprop.GradientTape() as tape:
            tape.watch(mid_level_api.embedding_tables.values())
            activations = mid_level_api(data)
            total_loss = self._get_total_loss_tensor(activations)
            loss_per_replica = total_loss / strategy.num_replicas_in_sync
          gradients = tape.gradient(loss_per_replica,
                                    mid_level_api.embedding_tables.values())
          optimizer.apply_gradients(
              list(zip(gradients, mid_level_api.embedding_tables.values())))
        ret_val = [total_loss] + list(activations)
        return ret_val

      return strategy.run(step, args=(next(dist_iter),))

    # Run model.
    shard_out_val = test_fn()

    # Compute sparse tensors for global batch.
    if is_high_dimensional:
      input_data = next(
          iter(self._create_high_dimensional_sparse_dataset(strategy)))
    else:
      input_data = next(iter(self._create_sparse_dataset(strategy)))

    # Check results.
    self._check_results(strategy, shard_out_val, training, input_data,
                        mid_level_api._variables, optimizer,
                        is_high_dimensional)

  def _check_embedding_and_slot_variables(self, embedding_table_user_before,
                                          gradients_wrt_user,
                                          embedding_table_video_before,
                                          gradients_wrt_video, optimizer,
                                          table_to_variable):
    if isinstance(optimizer, optimizer_v2.gradient_descent.SGD):
      check_fn = self._check_embedding_and_slot_variables_for_sgd
    elif isinstance(optimizer, optimizer_v2.adagrad.Adagrad):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad
    elif isinstance(optimizer, optimizer_v2.adam.Adam):
      check_fn = self._check_embedding_and_slot_variables_for_adam
    elif isinstance(optimizer, optimizer_v2.ftrl.Ftrl):
      check_fn = self._check_embedding_and_slot_variables_for_ftrl
    else:
      raise ValueError(f'optimizer is not recognized: {type(optimizer)}')
    check_fn(embedding_table_user_before, gradients_wrt_user, optimizer,
             table_to_variable[self.table_user.name])
    check_fn(embedding_table_video_before, gradients_wrt_video, optimizer,
             table_to_variable[self.table_video.name])

  def _check_embedding_and_slot_variables_for_sgd(self, embedding_table_before,
                                                  gradients, optimizer,
                                                  variables):
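    # One SGD step: subtract learning_rate times the gradient summed over the
    # global batch from the initial embedding table.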
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    embedding_table -= config['learning_rate'] * np.sum(gradients, axis=0)
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(), embedding_table)

  def _check_embedding_and_slot_variables_for_adagrad(self,
                                                      embedding_table_before,
                                                      gradients, optimizer,
                                                      variables):
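    # One Adagrad step: the accumulator adds the squared summed gradient to its
    # initial value, and the update divides the learning rate by the square
    # root of the accumulator.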
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    accumulator = (
        config['initial_accumulator_value'] + np.sum(gradients, axis=0)**2)
    embedding_table -= (
        config['learning_rate'] * np.sum(gradients, axis=0) /
        np.sqrt(accumulator))
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(), embedding_table)
    self.assertAllClose(
        self._get_variable(variables['accumulators']).numpy(), accumulator)

  def _check_embedding_and_slot_variables_for_adam(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variables):
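    # One Adam step starting from zero slot values: m = (1 - beta_1) * g and
    # v = (1 - beta_2) * g**2, so the first-step bias correction reduces to
    # sqrt(1 - beta_2) / (1 - beta_1).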
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    g = np.sum(gradients, axis=0)
    v = g**2 * (1 - config['beta_2'])
    m = g * (1 - config['beta_1'])
    epsilon = config['epsilon']
    lr_modifier = np.sqrt(1 - config['beta_2']) / (1 - config['beta_1'])
    embedding_table -= (
        m * config['learning_rate'] * lr_modifier / (np.sqrt(v) + epsilon))
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(),
        embedding_table,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variables['momenta']).numpy(), m, rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variables['velocities']).numpy(), v, rtol=1e-4)

  def _check_embedding_and_slot_variables_for_ftrl(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variables):
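    # One FTRL step with the default l1 = l2 = 0: given the accumulator and the
    # linear term, the closed-form parameter value is -linear / quadratic.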
    embedding_table = np.copy(embedding_table_before)
    config = optimizer.get_config()
    neg_lr_p = -config['learning_rate_power']
    accumulator = (
        config['initial_accumulator_value'] + np.sum(gradients, axis=0)**2)
    sigma = (accumulator**neg_lr_p - config['initial_accumulator_value']**
             neg_lr_p) / config['learning_rate']
    linear = np.sum(gradients, axis=0) - sigma * embedding_table
    quadratic = accumulator**neg_lr_p / config['learning_rate']
    embedding_table = -linear / quadratic
    actual_parameters = self._get_variable(variables['parameters']).numpy()
    # Rows where `linear` == 0 were never touched by the update, so they still
    # hold their initial values rather than the closed-form solution; zero them
    # out before comparing.
    actual_parameters *= (linear != 0.0)
    # FTRL parameters show slightly larger numerical differences, so use a
    # looser tolerance.
    self.assertAllClose(actual_parameters, embedding_table, rtol=5e-5)
    self.assertAllClose(
        self._get_variable(variables['linears']).numpy(), linear, rtol=5e-4)
    self.assertAllClose(
        self._get_variable(variables['accumulators']).numpy(), accumulator)

  @parameterized.parameters(True, False)
  def test_enqueue_with_weights(self, ragged):
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    weight = 0.5
    if ragged:
      dataset = self._create_ragged_dataset(
          strategy, include_weights=True, weight=weight)
    else:
      dataset = self._create_sparse_dataset(
          strategy, include_weights=True, weight=weight)

    dataset_iter = iter(
        strategy.experimental_distribute_dataset(
            dataset,
            options=distribute_lib.InputOptions(
                experimental_fetch_to_device=False)))

    @def_function.function
    def embedding_lookup(features, weights):

      def step(features, weights):
        return mid_level_api(features, weights)

      return strategy.run(step, args=(features, weights))

    features, weights = next(dataset_iter)
    # Replace the weight for the second feature with None to check that passing
    # None leaves that feature unweighted.
    weights = (weights[0], None, weights[2])

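    # Look up the same batch with and without weights; comparing the two sets
    # of activations isolates the effect of the weights.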
    no_weights_activations = embedding_lookup(features, weights=None)
    weights_activations = embedding_lookup(features, weights=weights)

    no_weights0 = (self._unpack(strategy, no_weights_activations[0]),
                   self._unpack(strategy, no_weights_activations[1]),
                   self._unpack(strategy, no_weights_activations[2]))
    weights0 = (self._unpack(strategy, weights_activations[0]),
                self._unpack(strategy, weights_activations[1]),
                self._unpack(strategy, weights_activations[2]))
    # The videos table uses a sum combiner and the users table uses a mean
    # combiner, i.e. the users table lookups are not affected by the weights
    # because all the weights are equal.
    # Tuple entries 0 and 1 are the watched and favorited features from the
    # videos table, and entry 2 is the friends feature from the users table.
    # Note that None was passed as the weight for entry 1, so the weight should
    # have no effect there.
    weight = (0.5, 1.0, 1.0)
    golden = tuple([no_weight * w for no_weight, w in zip(no_weights0, weight)])

    self.assertAllClose(golden, weights0)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()