• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Correctness test for tf.keras Embedding models using DistributionStrategy."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python import keras
23from tensorflow.python.distribute import combinations as ds_combinations
24from tensorflow.python.distribute import multi_process_runner
25from tensorflow.python.keras.distribute import keras_correctness_test_base
26from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
27
28
class DistributionStrategyEmbeddingModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a single-input Keras embedding classifier.

  Each test sets `use_distributed_dense` and then delegates to
  `run_correctness_test`. `get_model` reads that flag to decide whether a
  `TimeDistributed(Dense(4))` layer is inserted after the embedding.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the embedding classification model.

    Args:
      max_words: Length of the integer word-id input, i.e. the input has
        shape `(max_words,)`.
      initial_weights: If truthy, passed to `model.set_weights` before the
        model is compiled.
      distribution: Forwarded to `MaybeDistributionScope`; the model is built
        and compiled inside that scope.
      input_shapes: Unused by this model.

    Returns:
      The compiled `keras.Model`.
    """
    del input_shapes  # Not used by this model.
    scope = keras_correctness_test_base.MaybeDistributionScope(distribution)
    with scope:
      inputs = keras.layers.Input(
          shape=(max_words,), dtype=np.int32, name='words')

      embedding_layer = keras.layers.Embedding(input_dim=20, output_dim=10)
      features = embedding_layer(inputs)

      # Optional per-timestep projection, toggled by the calling test.
      if self.use_distributed_dense:
        per_step_dense = keras.layers.TimeDistributed(keras.layers.Dense(4))
        features = per_step_dense(features)

      pooled = keras.layers.GlobalAveragePooling1D()(features)
      probabilities = keras.layers.Dense(2, activation='softmax')(pooled)
      model = keras.Model(inputs=[inputs], outputs=[probabilities])

      if initial_weights:
        model.set_weights(initial_weights)

      sgd = gradient_descent_keras.SGD(learning_rate=0.1)
      model.compile(
          optimizer=sgd,
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_embedding_model_correctness(self, distribution, use_numpy,
                                       use_validation_data):
    # Plain variant: no TimeDistributed Dense layer after the embedding.
    self.use_distributed_dense = False
    self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_embedding_time_distributed_model_correctness(
      self, distribution, use_numpy, use_validation_data):
    # Variant that inserts TimeDistributed(Dense(4)) after the embedding.
    self.use_distributed_dense = True
    self.run_correctness_test(distribution, use_numpy, use_validation_data)
75
76
class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a two-input model sharing one embedding layer.

  Both inputs go through the same `Embedding` layer and are average-pooled;
  the model output is the normalized dot product of the two pooled vectors.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the siamese embedding model.

    Args:
      max_words: Length of each integer word-id input; also passed to the
        embedding layer as `input_length`.
      initial_weights: If truthy, passed to `model.set_weights` before the
        model is compiled.
      distribution: Forwarded to `MaybeDistributionScope`; the model is built
        and compiled inside that scope.
      input_shapes: Unused by this model.

    Returns:
      The compiled `keras.Model` with inputs named 'words_a' and 'words_b'.
    """
    del input_shapes  # Not used by this model.
    scope = keras_correctness_test_base.MaybeDistributionScope(distribution)
    with scope:
      input_a = keras.layers.Input(
          shape=(max_words,), dtype=np.int32, name='words_a')
      input_b = keras.layers.Input(
          shape=(max_words,), dtype=np.int32, name='words_b')

      def pooled_representation(embedding, word_ids):
        """Returns the first output of an embed-then-average-pool submodel."""
        embedded = embedding(word_ids)
        pooled = keras.layers.GlobalAveragePooling1D()(embedded)
        tower = keras.Model(inputs=[word_ids], outputs=[pooled])
        return tower.outputs[0]

      # A single embedding layer instance is applied to both inputs.
      shared_embedding = keras.layers.Embedding(
          input_dim=20,
          output_dim=10,
          input_length=max_words,
          embeddings_initializer=keras.initializers.RandomUniform(0, 1))

      rep_a = pooled_representation(shared_embedding, input_a)
      rep_b = pooled_representation(shared_embedding, input_b)
      similarity = keras.layers.Dot(axes=1, normalize=True)([rep_a, rep_b])

      model = keras.Model(inputs=[input_a, input_b], outputs=[similarity])

      if initial_weights:
        model.set_weights(initial_weights)

      # TODO(b/130808953): Switch back to the V1 optimizer after global_step
      # is made mirrored.
      sgd = gradient_descent_keras.SGD(learning_rate=0.1)
      model.compile(optimizer=sgd, loss='mse', metrics=['mse'])
    return model

  def get_data(self,
               count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE *
                      keras_correctness_test_base._EVAL_STEPS),
               min_words=5,
               max_words=10,
               max_word_id=19,
               num_classes=2):
    """Builds paired inputs and similarity targets for the siamese model.

    The base class `get_data` is called twice with the same arguments to
    obtain two feature/label sets.

    Args:
      count: Forwarded to the base `get_data`; also the number of rows in the
        returned targets.
      min_words: Forwarded to the base `get_data`.
      max_words: Forwarded to the base `get_data`.
      max_word_id: Forwarded to the base `get_data`.
      num_classes: Forwarded to the base `get_data`.

    Returns:
      A tuple `(x_train, y_train, x_predict)`. `x_train` is a dict keyed by
      the model input names 'words_a' and 'words_b'. `y_train` is a float32
      array of shape `(count, 1)` holding 1.0 where the two label sets agree
      and -1.0 where they differ. `x_predict` is the same object as `x_train`.
    """
    parent = super(DistributionStrategySiameseEmbeddingModelCorrectnessTest,
                   self)
    base_args = (count, min_words, max_words, max_word_id, num_classes)

    features_a, labels_a, _ = parent.get_data(*base_args)
    features_b, labels_b, _ = parent.get_data(*base_args)

    targets = np.zeros((count, 1), dtype=np.float32)
    targets[labels_a == labels_b] = 1.0
    targets[labels_a != labels_b] = -1.0

    # TODO(b/123360757): Add tests for using list as inputs for multi-input
    # models.
    inputs = {'words_a': features_a, 'words_b': features_b}

    # The prediction inputs reuse the training inputs (same dict object).
    return inputs, targets, inputs

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
                                               use_validation_data):
    # All model and data specifics come from get_model/get_data above.
    self.run_correctness_test(distribution, use_numpy, use_validation_data)
157
158
if __name__ == '__main__':
  # Run the tests through multi_process_runner's test entry point when this
  # file is executed as a script.
  multi_process_runner.test_main()
161