1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Correctness tests for tf.keras using DistributionStrategy.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import functools 21 22from absl.testing import parameterized 23import numpy as np 24import six 25from tensorflow.python import keras 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.distribute import strategy_combinations 28from tensorflow.python.distribute import tpu_strategy 29from tensorflow.python.eager import context 30from tensorflow.python.framework import random_seed 31from tensorflow.python.framework import test_combinations as combinations 32from tensorflow.python.framework import test_util 33from tensorflow.python.keras.distribute import distributed_training_utils 34from tensorflow.python.keras.distribute.strategy_combinations import all_strategies 35from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies 36from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu 37from tensorflow.python.keras.mixed_precision import policy 38from tensorflow.python.keras.preprocessing import sequence 39from tensorflow.python.platform import test 40from tensorflow.python.util import nest 41 42_RANDOM_SEED = 1337 43_EVAL_STEPS = 20 44_GLOBAL_BATCH_SIZE = 64 45 46# Note: Please make sure the tests in this file are also covered in 47# keras_backward_compat_test for features that are supported with both APIs. 48 49 50def eager_mode_test_configuration(): 51 return combinations.combine( 52 mode='eager', use_numpy=[True, False], use_validation_data=[True, False]) 53 54 55def graph_mode_test_configuration(): 56 return combinations.combine( 57 mode='graph', use_numpy=[True, False], use_validation_data=[True, False]) 58 59 60def all_strategy_and_input_config_combinations(): 61 return (combinations.times( 62 combinations.combine(distribution=all_strategies), 63 eager_mode_test_configuration() + graph_mode_test_configuration())) 64 65 66def all_strategy_and_input_config_combinations_eager(): 67 return (combinations.times( 68 combinations.combine(distribution=all_strategies), 69 eager_mode_test_configuration())) 70 71 72def strategy_minus_tpu_and_input_config_combinations_eager(): 73 return (combinations.times( 74 combinations.combine(distribution=strategies_minus_tpu), 75 eager_mode_test_configuration())) 76 77 78def strategies_for_embedding_models(): 79 """Returns distribution strategies to test for embedding models. 80 81 Since embedding models take longer to train, we disregard DefaultStrategy 82 in order to prevent testing timeouts. 83 """ 84 85 return [ 86 s for s in all_strategies if s.required_tpu or s.required_gpus or 87 s is strategy_combinations.one_device_strategy 88 ] 89 90 91def test_combinations_for_embedding_model(): 92 # TODO(sourabhbajaj): Enable tests for eager mode 93 eager_mode_strategies = [ 94 s for s in strategies_for_embedding_models() if not s.required_tpu 95 ] 96 97 return (combinations.times( 98 combinations.combine( 99 distribution=strategies_for_embedding_models()), 100 (graph_mode_test_configuration())) + combinations.times( 101 combinations.combine( 102 distribution=eager_mode_strategies), 103 (eager_mode_test_configuration()))) 104 105 106def test_combinations_with_tpu_strategies_graph(): 107 tpu_strategies = [ 108 strategy_combinations.tpu_strategy, 109 ] 110 111 return (combinations.times( 112 combinations.combine(distribution=tpu_strategies), 113 graph_mode_test_configuration())) 114 115 116def multi_worker_mirrored_eager(): 117 return combinations.times( 118 combinations.combine(distribution=multi_worker_mirrored_strategies), 119 eager_mode_test_configuration()) 120 121 122def multi_worker_mirrored_eager_and_graph(): 123 return combinations.times( 124 combinations.combine(distribution=multi_worker_mirrored_strategies), 125 eager_mode_test_configuration() + graph_mode_test_configuration()) 126 127 128class MaybeDistributionScope(object): 129 """Provides a context allowing no distribution strategy.""" 130 131 def __init__(self, distribution): 132 self._distribution = distribution 133 self._scope = None 134 135 def __enter__(self): 136 if self._distribution: 137 self._scope = self._distribution.scope() 138 self._scope.__enter__() 139 140 def __exit__(self, exc_type, value, traceback): 141 if self._distribution: 142 self._scope.__exit__(exc_type, value, traceback) 143 self._scope = None 144 145 146def batch_wrapper(dataset, batch_size, repeat=None): 147 if repeat: 148 dataset = dataset.repeat(repeat) 149 return dataset.batch(batch_size) 150 151 152def get_batch_size(global_batch_size, distribution): 153 batch_size = global_batch_size 154 # TODO(b/118776054): Use global batch size for Keras/DS support. 155 use_per_core_batch_size = ( 156 distribution and 157 not distributed_training_utils.global_batch_size_supported(distribution)) 158 if use_per_core_batch_size: 159 batch_size //= distribution.num_replicas_in_sync 160 return batch_size 161 162 163def get_data_size(data): 164 """Gets the size of data in list, tuple, dict, or a numpy array.""" 165 assert isinstance(data, (np.ndarray, list, dict, tuple)) 166 167 if isinstance(data, np.ndarray): 168 return len(data) 169 170 if isinstance(data, (list, tuple)): 171 return len(data[0]) 172 173 return len(six.next(six.itervalues(data))) 174 175 176def get_shapes(data): 177 shapes = None 178 if all(hasattr(x, 'shape') for x in nest.flatten(data)): 179 shapes = nest.map_structure(lambda x: x.shape, data) 180 return shapes 181 182 183def get_correctness_test_inputs(use_numpy, use_validation_data, 184 with_distribution, x_train, y_train, x_eval, 185 y_eval, x_predict, training_epochs): 186 """Generates the inputs for correctness check when enable Keras with DS.""" 187 global_batch_size = _GLOBAL_BATCH_SIZE 188 batch_size = get_batch_size(global_batch_size, with_distribution) 189 190 if use_numpy: 191 training_inputs = { 192 'batch_size': batch_size, 193 'x': x_train, 194 'y': y_train, 195 'epochs': training_epochs, 196 'shuffle': False, 197 } 198 199 if use_validation_data: 200 eval_inputs = None 201 training_inputs['validation_data'] = (x_eval, y_eval) 202 else: 203 eval_inputs = { 204 'batch_size': batch_size, 205 'x': x_eval, 206 'y': y_eval, 207 } 208 predict_inputs = {'x': x_predict} 209 else: 210 training_data_size = get_data_size(x_train) 211 # For dataset inputs, we do not pass batch_size to 212 # keras.fit/evaluate/predict. The batch size is part of the dataset. 213 train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) 214 x = batch_wrapper(train_dataset, batch_size, repeat=training_epochs) 215 216 steps_per_epoch = int(np.ceil(1.0 * training_data_size / global_batch_size)) 217 training_inputs = { 218 'batch_size': None, 219 'x': x, 220 'y': None, 221 'epochs': training_epochs, 222 'shuffle': False, 223 'steps_per_epoch': steps_per_epoch 224 } 225 if use_validation_data: 226 eval_inputs = None # Remove the eval_inputs 227 eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_eval, y_eval)) 228 x = batch_wrapper(eval_dataset, batch_size) 229 training_inputs['validation_data'] = x 230 training_inputs['validation_steps'] = 5 231 else: 232 eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_eval, y_eval)) 233 x = batch_wrapper(eval_dataset, batch_size) 234 eval_steps = int(np.ceil(1.0 * get_data_size(x_eval) / global_batch_size)) 235 eval_inputs = { 236 'batch_size': None, 237 'x': x, 238 'y': None, 239 'steps': eval_steps, 240 } 241 242 predict_batch_size = get_batch_size( 243 get_data_size(x_predict), with_distribution) 244 predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) 245 predict_dataset = batch_wrapper(predict_dataset, predict_batch_size) 246 predict_inputs = { 247 'steps': 1, 248 'x': predict_dataset, 249 } 250 251 return training_inputs, eval_inputs, predict_inputs 252 253 254def fit_eval_and_predict(initial_weights, 255 input_fn, 256 model_fn, 257 distribution=None, 258 is_stateful_model=False): 259 """Generates results for fit/predict/evaluate for given model.""" 260 training_inputs, eval_inputs, predict_inputs = input_fn() 261 model = model_fn( 262 initial_weights=initial_weights, 263 distribution=distribution, 264 input_shapes=get_shapes(training_inputs['x'])) 265 266 result = {} 267 result['training_history_1'] = model.fit(**training_inputs).history 268 269 if eval_inputs is not None: 270 result['eval_result_1'] = model.evaluate(**eval_inputs) 271 272 result['weights_1'] = model.get_weights() 273 274 if predict_inputs is not None: 275 # Check correctness of the result of predict() invoked 276 # multiple times -- as for stateful models, result of 277 # predict may differ for each batch. 278 predict_length = 1 279 if is_stateful_model: 280 predict_length = 3 281 for i in range(predict_length): 282 result_key = 'predict_result_{}'.format(i) 283 result[result_key] = model.predict(**predict_inputs) 284 285 # Train and eval again to mimic user's flow. 286 287 result['training_history_2'] = model.fit(**training_inputs).history 288 289 if eval_inputs is not None: 290 result['eval_result_2'] = model.evaluate(**eval_inputs) 291 292 result['weights_2'] = model.get_weights() 293 294 return result 295 296 297def compare_results(results_with_ds, 298 results_without_ds, 299 distribution, 300 testcase, 301 partial_last_batch=None): 302 """Compares results of model compiled with/without distribution strategy.""" 303 if policy.global_policy().compute_dtype in ('float16', 'bfloat16'): 304 default_tolerance = 1e-2 305 relaxed_tolerance = 1e-2 306 elif partial_last_batch == 'train_and_eval': 307 # We relax the tolerance a lot in the partial last batch case as 308 # 1. the examples in uneven batches may have different weights when 309 # applying the gradients in the distributed case. 310 # 2. TF Keras and TF Keras DS have different ways to handle the case when 311 # training with epochs > 1 with numpy inputs. In TF Keras, every epoch 312 # may have a partial batch. While in TF Keras DS, as we convert 313 # numpy inputs into dataset, it will do a repeat() first and calculate 314 # steps_per_epoch, so it will at most have one partial batch. This 315 # makes the 1-CPU result even different. 316 default_tolerance = 1e-3 317 relaxed_tolerance = 1e-3 318 else: 319 default_tolerance = 4e-5 320 relaxed_tolerance = 1e-4 321 322 def _get_compare_result_tolerance(key): 323 """Returns tolerance to compare results.""" 324 # See b/119257215 for more details. DS test run on GPU could have larger 325 # variance then test on CPU. 326 if (test_util.is_gpu_available() and 327 key.startswith(('weights_1', 'weights_2', 'predict_result'))): 328 return relaxed_tolerance 329 330 return default_tolerance 331 332 for key in sorted(results_with_ds.keys()): 333 if (key.startswith('training_history') and 334 isinstance(distribution, 335 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)) and 336 distribution.extended.steps_per_run > 1): 337 # TODO(b/119894254): Enable this test for all cases once the 338 # underlying bug is fixed. 339 continue 340 341 tolerance = _get_compare_result_tolerance(key) 342 343 # We don't compare the loss as loss is currently not computed as metric 344 # in Keras, the loss value is inaccurate for last partial batch due to 345 # more weights for the last batch samples. 346 if partial_last_batch is not None: 347 if key.startswith('eval_result'): 348 results_with_ds[key] = results_with_ds[key][1:] 349 results_without_ds[key] = results_without_ds[key][1:] 350 if key.startswith('training_history'): 351 results_with_ds[key]['val_loss'] = 0 352 results_without_ds[key]['val_loss'] = 0 353 354 testcase.assertAllClose( 355 results_with_ds[key], 356 results_without_ds[key], 357 atol=tolerance, 358 rtol=tolerance, 359 msg='Fail to assert {}.'.format(key)) 360 361 362def should_skip_tpu_with_eager(distribution): 363 return (context.executing_eagerly() and 364 isinstance(distribution, 365 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))) 366 367 368class LearningRateBatchScheduler(keras.callbacks.Callback): 369 """Scheduler that dynamically sets the learning rate of model.""" 370 371 def __init__(self, update_freq=None): 372 self._update_freq = update_freq 373 374 def on_batch_begin(self, batch, logs=None): 375 if self._update_freq and batch % self._update_freq != 0: 376 return 377 378 # To avoid divergence, limit the value range. 379 lr = 0.001 * (batch % 10) 380 keras.backend.set_value(self.model.optimizer.lr, lr) 381 382 383class TestDistributionStrategyCorrectnessBase(test.TestCase, 384 parameterized.TestCase): 385 """Model agnostic testing infra to test correctness of Keras models.""" 386 387 def set_up_test_config(self, 388 use_numpy=False, 389 use_validation_data=False, 390 with_batch_norm=None): 391 self.use_numpy = use_numpy 392 self.use_validation_data = use_validation_data 393 self.with_batch_norm = with_batch_norm 394 395 keras.backend.set_image_data_format('channels_last') 396 np.random.seed(_RANDOM_SEED) 397 random_seed.set_random_seed(_RANDOM_SEED) 398 399 def get_data(self): 400 num_samples = 10000 401 x_train = np.random.randint(0, 2, num_samples) 402 x_train = np.reshape(x_train, (num_samples, 1)) 403 y_train = x_train 404 return (x_train.astype('float32'), y_train.astype('float32'), None) 405 406 def get_data_with_partial_last_batch(self): 407 raise NotImplementedError 408 409 def get_data_with_partial_last_batch_eval(self): 410 raise NotImplementedError 411 412 def get_input_for_correctness_test(self, **kwargs): 413 """Generates inputs that are dictionaries. 414 415 We only provide a default implementation of this method here. If you need 416 more customized way of providing input to your model, overwrite this method. 417 418 Args: 419 **kwargs: key word arguments about how to create the input dictionaries 420 421 Returns: 422 Three dictionaries representing the input for fit(), evaluate() and 423 predict() 424 """ 425 426 return get_correctness_test_inputs(**kwargs) 427 428 def get_model(self, 429 distribution=None, 430 input_shapes=None): 431 raise NotImplementedError 432 433 def run_correctness_test(self, 434 distribution, 435 use_numpy, 436 use_validation_data, 437 with_batch_norm=None, 438 is_stateful_model=False, 439 partial_last_batch=None, 440 training_epochs=2): 441 with self.cached_session(): 442 self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm) 443 444 if partial_last_batch == 'eval': 445 x_train, y_train, x_eval, y_eval, x_predict = ( 446 self.get_data_with_partial_last_batch_eval()) 447 elif partial_last_batch == 'train_and_eval': 448 x_train, y_train, x_eval, y_eval, x_predict = ( 449 self.get_data_with_partial_last_batch()) 450 else: 451 x_train, y_train, x_predict = self.get_data() 452 x_eval = x_train 453 y_eval = y_train 454 455 # The model is built once and the initial weights are saved. 456 # This is used to initialize the model for both the distribution and 457 # non-distribution run. 458 model = self.get_model( 459 input_shapes=get_shapes(x_train)) 460 initial_weights = model.get_weights() 461 462 ds_input_fn = functools.partial( 463 self.get_input_for_correctness_test, 464 use_numpy=use_numpy, 465 use_validation_data=use_validation_data, 466 with_distribution=distribution, 467 x_train=x_train, 468 y_train=y_train, 469 x_eval=x_eval, 470 y_eval=y_eval, 471 x_predict=x_predict, 472 training_epochs=training_epochs) 473 474 nods_input_fn = functools.partial( 475 self.get_input_for_correctness_test, 476 use_numpy=use_numpy, 477 use_validation_data=use_validation_data, 478 with_distribution=None, 479 x_train=x_train, 480 y_train=y_train, 481 x_eval=x_eval, 482 y_eval=y_eval, 483 x_predict=x_predict, 484 training_epochs=training_epochs) 485 486 results_with_ds = fit_eval_and_predict( 487 initial_weights, 488 input_fn=ds_input_fn, 489 model_fn=self.get_model, 490 distribution=distribution, 491 is_stateful_model=is_stateful_model) 492 results_without_ds = fit_eval_and_predict( 493 initial_weights, 494 input_fn=nods_input_fn, 495 model_fn=self.get_model, 496 distribution=None, 497 is_stateful_model=is_stateful_model) 498 499 # First, special case, for multi-replica distributed training, batch 500 # norm is not aggregated globally. So it is expected to have different 501 # weights. 502 if (self.with_batch_norm == 'regular' and 503 distribution.num_replicas_in_sync > 1): 504 with self.assertRaises(AssertionError): 505 compare_results( 506 results_with_ds, 507 results_without_ds, 508 distribution, 509 testcase=self, 510 partial_last_batch=partial_last_batch) 511 else: 512 compare_results( 513 results_with_ds, 514 results_without_ds, 515 distribution, 516 testcase=self, 517 partial_last_batch=partial_last_batch) 518 519 def get_input_for_dynamic_lr_test(self, **kwargs): 520 """Generates inputs that are dictionaries. 521 522 We only provide a default implementation of this method here. If you need 523 more customized way of providing input to your model, overwrite this method. 524 525 Args: 526 **kwargs: key word arguments about how to create the input dictionaries 527 528 Returns: 529 Three dictionaries representing the input for fit(), evaluate() and 530 predict() 531 """ 532 533 training_input = kwargs 534 return training_input, None, None 535 536 def run_dynamic_lr_test(self, 537 distribution): 538 with self.cached_session(): 539 self.set_up_test_config() 540 541 x_train, y_train, _ = self.get_data() 542 model = self.get_model( 543 input_shapes=get_shapes(x_train)) 544 initial_weights = model.get_weights() 545 update_freq = None 546 547 if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and 548 distribution.extended.steps_per_run > 1): 549 # For TPUStrategy with steps_per_run > 1, the callback is not invoked 550 # every step. So, to compare the CPU/TPU, we let the CPU to behave the 551 # same as TPU. 552 update_freq = distribution.extended.steps_per_run 553 554 training_epochs = 2 555 global_batch_size = 64 556 557 ds_batch_size = get_batch_size(global_batch_size, distribution) 558 nods_batch_size = get_batch_size(global_batch_size, None) 559 560 ds_input_fn = functools.partial( 561 self.get_input_for_dynamic_lr_test, 562 x=x_train, 563 y=y_train, 564 batch_size=ds_batch_size, 565 shuffle=False, 566 epochs=training_epochs, 567 callbacks=[LearningRateBatchScheduler(update_freq)], 568 validation_data=(x_train, y_train)) 569 570 nods_input_fn = functools.partial( 571 self.get_input_for_dynamic_lr_test, 572 x=x_train, 573 y=y_train, 574 batch_size=nods_batch_size, 575 shuffle=False, 576 epochs=training_epochs, 577 callbacks=[LearningRateBatchScheduler(update_freq)], 578 validation_data=(x_train, y_train)) 579 580 results_with_ds = fit_eval_and_predict( 581 initial_weights, 582 input_fn=ds_input_fn, 583 model_fn=self.get_model, 584 distribution=distribution) 585 results_without_ds = fit_eval_and_predict( 586 initial_weights, 587 input_fn=nods_input_fn, 588 model_fn=self.get_model, 589 distribution=None) 590 compare_results( 591 results_with_ds, results_without_ds, distribution, testcase=self) 592 593 594class TestDistributionStrategyEmbeddingModelCorrectnessBase( 595 TestDistributionStrategyCorrectnessBase): 596 """Base class to test correctness of Keras models with embedding layers.""" 597 598 def get_data(self, 599 count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS), 600 min_words=5, 601 max_words=10, 602 max_word_id=19, 603 num_classes=2): 604 distribution = [] 605 for _ in range(num_classes): 606 dist = np.abs(np.random.randn(max_word_id)) 607 dist /= np.sum(dist) 608 distribution.append(dist) 609 610 features = [] 611 labels = [] 612 for _ in range(count): 613 label = np.random.randint(0, num_classes, size=1)[0] 614 num_words = np.random.randint(min_words, max_words, size=1)[0] 615 word_ids = np.random.choice( 616 max_word_id, size=num_words, replace=True, p=distribution[label]) 617 word_ids = word_ids 618 labels.append(label) 619 features.append(word_ids) 620 621 features = sequence.pad_sequences( 622 features, maxlen=max_words) 623 x_train = np.asarray(features, dtype=np.float32) 624 y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1)) 625 x_predict = x_train[:_GLOBAL_BATCH_SIZE] 626 return x_train, y_train, x_predict 627 628 629if __name__ == '__main__': 630 test.main() 631