1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Correctness test for tf.keras Embedding models using DistributionStrategy.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python import keras 23from tensorflow.python.distribute import combinations as ds_combinations 24from tensorflow.python.distribute import multi_process_runner 25from tensorflow.python.keras.distribute import keras_correctness_test_base 26from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras 27 28 29class DistributionStrategyEmbeddingModelCorrectnessTest( 30 keras_correctness_test_base 31 .TestDistributionStrategyEmbeddingModelCorrectnessBase): 32 33 def get_model(self, 34 max_words=10, 35 initial_weights=None, 36 distribution=None, 37 input_shapes=None): 38 del input_shapes 39 with keras_correctness_test_base.MaybeDistributionScope(distribution): 40 word_ids = keras.layers.Input( 41 shape=(max_words,), dtype=np.int32, name='words') 42 word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids) 43 if self.use_distributed_dense: 44 word_embed = keras.layers.TimeDistributed(keras.layers.Dense(4))( 45 word_embed) 46 avg = keras.layers.GlobalAveragePooling1D()(word_embed) 47 preds = keras.layers.Dense(2, activation='softmax')(avg) 48 model = keras.Model(inputs=[word_ids], outputs=[preds]) 49 50 if initial_weights: 51 model.set_weights(initial_weights) 52 53 model.compile( 54 optimizer=gradient_descent_keras.SGD(learning_rate=0.1), 55 loss='sparse_categorical_crossentropy', 56 metrics=['sparse_categorical_accuracy']) 57 return model 58 59 @ds_combinations.generate( 60 keras_correctness_test_base.test_combinations_for_embedding_model() + 61 keras_correctness_test_base.multi_worker_mirrored_eager()) 62 def test_embedding_model_correctness(self, distribution, use_numpy, 63 use_validation_data): 64 65 self.use_distributed_dense = False 66 self.run_correctness_test(distribution, use_numpy, use_validation_data) 67 68 @ds_combinations.generate( 69 keras_correctness_test_base.test_combinations_for_embedding_model() + 70 keras_correctness_test_base.multi_worker_mirrored_eager()) 71 def test_embedding_time_distributed_model_correctness( 72 self, distribution, use_numpy, use_validation_data): 73 self.use_distributed_dense = True 74 self.run_correctness_test(distribution, use_numpy, use_validation_data) 75 76 77class DistributionStrategySiameseEmbeddingModelCorrectnessTest( 78 keras_correctness_test_base 79 .TestDistributionStrategyEmbeddingModelCorrectnessBase): 80 81 def get_model(self, 82 max_words=10, 83 initial_weights=None, 84 distribution=None, 85 input_shapes=None): 86 del input_shapes 87 with keras_correctness_test_base.MaybeDistributionScope(distribution): 88 word_ids_a = keras.layers.Input( 89 shape=(max_words,), dtype=np.int32, name='words_a') 90 word_ids_b = keras.layers.Input( 91 shape=(max_words,), dtype=np.int32, name='words_b') 92 93 def submodel(embedding, word_ids): 94 word_embed = embedding(word_ids) 95 rep = keras.layers.GlobalAveragePooling1D()(word_embed) 96 return keras.Model(inputs=[word_ids], outputs=[rep]) 97 98 word_embed = keras.layers.Embedding( 99 input_dim=20, 100 output_dim=10, 101 input_length=max_words, 102 embeddings_initializer=keras.initializers.RandomUniform(0, 1)) 103 104 a_rep = submodel(word_embed, word_ids_a).outputs[0] 105 b_rep = submodel(word_embed, word_ids_b).outputs[0] 106 sim = keras.layers.Dot(axes=1, normalize=True)([a_rep, b_rep]) 107 108 model = keras.Model(inputs=[word_ids_a, word_ids_b], outputs=[sim]) 109 110 if initial_weights: 111 model.set_weights(initial_weights) 112 113 # TODO(b/130808953): Switch back to the V1 optimizer after global_step 114 # is made mirrored. 115 model.compile( 116 optimizer=gradient_descent_keras.SGD(learning_rate=0.1), 117 loss='mse', 118 metrics=['mse']) 119 return model 120 121 def get_data(self, 122 count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE * 123 keras_correctness_test_base._EVAL_STEPS), 124 min_words=5, 125 max_words=10, 126 max_word_id=19, 127 num_classes=2): 128 features_a, labels_a, _ = ( 129 super(DistributionStrategySiameseEmbeddingModelCorrectnessTest, 130 self).get_data(count, min_words, max_words, max_word_id, 131 num_classes)) 132 133 features_b, labels_b, _ = ( 134 super(DistributionStrategySiameseEmbeddingModelCorrectnessTest, 135 self).get_data(count, min_words, max_words, max_word_id, 136 num_classes)) 137 138 y_train = np.zeros((count, 1), dtype=np.float32) 139 y_train[labels_a == labels_b] = 1.0 140 y_train[labels_a != labels_b] = -1.0 141 # TODO(b/123360757): Add tests for using list as inputs for multi-input 142 # models. 143 x_train = { 144 'words_a': features_a, 145 'words_b': features_b, 146 } 147 x_predict = x_train 148 149 return x_train, y_train, x_predict 150 151 @ds_combinations.generate( 152 keras_correctness_test_base.test_combinations_for_embedding_model() + 153 keras_correctness_test_base.multi_worker_mirrored_eager()) 154 def test_siamese_embedding_model_correctness(self, distribution, use_numpy, 155 use_validation_data): 156 self.run_correctness_test(distribution, use_numpy, use_validation_data) 157 158 159if __name__ == '__main__': 160 multi_process_runner.test_main() 161