# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras using DistributionStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized
import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.preprocessing import sequence
from tensorflow.python.platform import test
from tensorflow.python.util import nest
_RANDOM_SEED = 1337
_EVAL_STEPS = 20
_GLOBAL_BATCH_SIZE = 64

# Note: Please make sure the tests in this file are also covered in
# keras_backward_compat_test for features that are supported with both APIs.


def eager_mode_test_configuration():
  return combinations.combine(
      mode='eager', use_numpy=[True, False], use_validation_data=[True, False])


def graph_mode_test_configuration():
  return combinations.combine(
      mode='graph', use_numpy=[True, False], use_validation_data=[True, False])
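
# Each configuration helper above expands into the cross product of its
# list-valued arguments: e.g. graph_mode_test_configuration() yields four
# configs (use_numpy in {True, False} x use_validation_data in {True, False}),
# all with mode='graph'.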


def all_strategy_and_input_config_combinations():
  return (combinations.times(
      combinations.combine(distribution=all_strategies),
      eager_mode_test_configuration() + graph_mode_test_configuration()))


def all_strategy_and_input_config_combinations_eager():
  return (combinations.times(
      combinations.combine(distribution=all_strategies),
      eager_mode_test_configuration()))


def strategy_minus_tpu_and_input_config_combinations_eager():
  return (combinations.times(
      combinations.combine(distribution=strategies_minus_tpu),
      eager_mode_test_configuration()))


def strategies_for_embedding_models():
  """Returns distribution strategies to test for embedding models.

  Since embedding models take longer to train, we disregard DefaultStrategy
  in order to prevent testing timeouts.
  """

  return [
      s for s in all_strategies if s.required_tpu or s.required_gpus or
      s is strategy_combinations.one_device_strategy
  ]


def test_combinations_for_embedding_model():
  # TODO(sourabhbajaj): Enable tests for eager mode
  eager_mode_strategies = [
      s for s in strategies_for_embedding_models() if not s.required_tpu
  ]

  return (combinations.times(
      combinations.combine(distribution=strategies_for_embedding_models()),
      graph_mode_test_configuration()) + combinations.times(
          combinations.combine(distribution=eager_mode_strategies),
          eager_mode_test_configuration()))


def test_combinations_with_tpu_strategies_graph():
  tpu_strategies = [
      strategy_combinations.tpu_strategy,
  ]

  return (combinations.times(
      combinations.combine(distribution=tpu_strategies),
      graph_mode_test_configuration()))


def multi_worker_mirrored_eager():
  return combinations.times(
      combinations.combine(distribution=multi_worker_mirrored_strategies),
      eager_mode_test_configuration())


def multi_worker_mirrored_eager_and_graph():
  return combinations.times(
      combinations.combine(distribution=multi_worker_mirrored_strategies),
      eager_mode_test_configuration() + graph_mode_test_configuration())


class MaybeDistributionScope(object):
  """Provides a context allowing no distribution strategy."""

  def __init__(self, distribution):
    self._distribution = distribution
    self._scope = None

  def __enter__(self):
    if self._distribution:
      self._scope = self._distribution.scope()
      self._scope.__enter__()

  def __exit__(self, exc_type, value, traceback):
    if self._distribution:
      self._scope.__exit__(exc_type, value, traceback)
      self._scope = None
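
# A minimal usage sketch (`strategy` and `make_model` are hypothetical names;
# `strategy` is a tf.distribute strategy instance, or None to run without one):
#
#   with MaybeDistributionScope(strategy):
#     model = make_model()  # built under strategy.scope() only when set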


def batch_wrapper(dataset, batch_size, repeat=None):
  if repeat:
    dataset = dataset.repeat(repeat)
  return dataset.batch(batch_size)


def get_batch_size(global_batch_size, distribution):
  batch_size = global_batch_size
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  use_per_core_batch_size = (
      distribution and
      not distributed_training_utils.global_batch_size_supported(distribution))
  if use_per_core_batch_size:
    batch_size //= distribution.num_replicas_in_sync
  return batch_size
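
# For example, assuming a strategy with 2 replicas in sync that does not
# support a global batch size, get_batch_size(64, strategy) returns the
# per-replica size 32, while get_batch_size(64, None) returns 64 unchanged.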


def get_data_size(data):
  """Gets the size of data in list, tuple, dict, or a numpy array."""
  assert isinstance(data, (np.ndarray, list, dict, tuple))

  if isinstance(data, np.ndarray):
    return len(data)

  if isinstance(data, (list, tuple)):
    return len(data[0])

  return len(six.next(six.itervalues(data)))
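
# For example, get_data_size({'a': np.zeros((8, 2))}) returns 8: for dicts it
# measures an arbitrary value, since all inputs are assumed to share the same
# leading (sample) dimension.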


def get_shapes(data):
  shapes = None
  if all(hasattr(x, 'shape') for x in nest.flatten(data)):
    shapes = nest.map_structure(lambda x: x.shape, data)
  return shapes


def get_correctness_test_inputs(use_numpy, use_validation_data,
                                with_distribution, x_train, y_train, x_eval,
                                y_eval, x_predict, training_epochs):
  """Generates the inputs for the correctness check of Keras with DS."""
  global_batch_size = _GLOBAL_BATCH_SIZE
  batch_size = get_batch_size(global_batch_size, with_distribution)

  if use_numpy:
    training_inputs = {
        'batch_size': batch_size,
        'x': x_train,
        'y': y_train,
        'epochs': training_epochs,
        'shuffle': False,
    }

    if use_validation_data:
      eval_inputs = None
      training_inputs['validation_data'] = (x_eval, y_eval)
    else:
      eval_inputs = {
          'batch_size': batch_size,
          'x': x_eval,
          'y': y_eval,
      }
    predict_inputs = {'x': x_predict}
  else:
    training_data_size = get_data_size(x_train)
    # For dataset inputs, we do not pass batch_size to
    # keras.fit/evaluate/predict. The batch size is part of the dataset.
    train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
    x = batch_wrapper(train_dataset, batch_size, repeat=training_epochs)

    steps_per_epoch = int(
        np.ceil(1.0 * training_data_size / global_batch_size))
    training_inputs = {
        'batch_size': None,
        'x': x,
        'y': None,
        'epochs': training_epochs,
        'shuffle': False,
        'steps_per_epoch': steps_per_epoch
    }
    if use_validation_data:
      eval_inputs = None  # Evaluation happens via fit()'s validation_data.
      eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_eval, y_eval))
      x = batch_wrapper(eval_dataset, batch_size)
      training_inputs['validation_data'] = x
      training_inputs['validation_steps'] = 5
    else:
      eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_eval, y_eval))
      x = batch_wrapper(eval_dataset, batch_size)
      eval_steps = int(
          np.ceil(1.0 * get_data_size(x_eval) / global_batch_size))
      eval_inputs = {
          'batch_size': None,
          'x': x,
          'y': None,
          'steps': eval_steps,
      }

    predict_batch_size = get_batch_size(
        get_data_size(x_predict), with_distribution)
    predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
    predict_dataset = batch_wrapper(predict_dataset, predict_batch_size)
    predict_inputs = {
        'steps': 1,
        'x': predict_dataset,
    }

  return training_inputs, eval_inputs, predict_inputs
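
# Example of the steps arithmetic above, using the default dataset from
# TestDistributionStrategyCorrectnessBase.get_data() below: 10000 training
# samples with _GLOBAL_BATCH_SIZE = 64 gives
# steps_per_epoch = ceil(10000 / 64) = 157.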


def fit_eval_and_predict(initial_weights,
                         input_fn,
                         model_fn,
                         distribution=None,
                         is_stateful_model=False):
  """Generates results for fit/predict/evaluate for the given model."""
  training_inputs, eval_inputs, predict_inputs = input_fn()
  model = model_fn(
      initial_weights=initial_weights,
      distribution=distribution,
      input_shapes=get_shapes(training_inputs['x']))

  result = {}
  result['training_history_1'] = model.fit(**training_inputs).history

  if eval_inputs is not None:
    result['eval_result_1'] = model.evaluate(**eval_inputs)

  result['weights_1'] = model.get_weights()

  if predict_inputs is not None:
    # Invoke predict() multiple times to check its correctness: for stateful
    # models, the result of predict() may differ from batch to batch.
    predict_length = 1
    if is_stateful_model:
      predict_length = 3
    for i in range(predict_length):
      result_key = 'predict_result_{}'.format(i)
      result[result_key] = model.predict(**predict_inputs)

  # Train and eval again to mimic a user's flow.

  result['training_history_2'] = model.fit(**training_inputs).history

  if eval_inputs is not None:
    result['eval_result_2'] = model.evaluate(**eval_inputs)

  result['weights_2'] = model.get_weights()

  return result
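
# The returned dict holds 'training_history_1'/'_2', 'weights_1'/'_2',
# optionally 'eval_result_1'/'_2', and one 'predict_result_<i>' key per
# predict() call; compare_results() below checks each key pairwise across
# the with-strategy and without-strategy runs.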


def compare_results(results_with_ds,
                    results_without_ds,
                    distribution,
                    testcase,
                    partial_last_batch=None):
  """Compares results of model compiled with/without distribution strategy."""
  if policy.global_policy().compute_dtype in ('float16', 'bfloat16'):
    default_tolerance = 1e-2
    relaxed_tolerance = 1e-2
  elif partial_last_batch == 'train_and_eval':
    # We relax the tolerance a lot in the partial last batch case as
    #   1. the examples in uneven batches may have different weights when
    #      applying the gradients in the distributed case.
    #   2. TF Keras and TF Keras DS handle training with epochs > 1 and numpy
    #      inputs differently. In TF Keras, every epoch may end with a partial
    #      batch. In TF Keras DS, numpy inputs are converted into a dataset
    #      that is repeat()-ed first and steps_per_epoch is computed from it,
    #      so there is at most one partial batch overall. This can make even
    #      the 1-CPU results differ.
    default_tolerance = 1e-3
    relaxed_tolerance = 1e-3
  else:
    default_tolerance = 4e-5
    relaxed_tolerance = 1e-4

  def _get_compare_result_tolerance(key):
    """Returns the tolerance to use when comparing results."""
    # See b/119257215 for more details. DS tests run on GPU may have larger
    # variance than tests on CPU.
    if (test_util.is_gpu_available() and
        key.startswith(('weights_1', 'weights_2', 'predict_result'))):
      return relaxed_tolerance

    return default_tolerance

  for key in sorted(results_with_ds.keys()):
    if (key.startswith('training_history') and
        isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)) and
        distribution.extended.steps_per_run > 1):
      # TODO(b/119894254): Enable this test for all cases once the
      # underlying bug is fixed.
      continue

    tolerance = _get_compare_result_tolerance(key)

    # We don't compare the loss, as loss is currently not computed as a metric
    # in Keras; its value is inaccurate for the last partial batch because the
    # samples in that batch carry more weight.
    if partial_last_batch is not None:
      if key.startswith('eval_result'):
        results_with_ds[key] = results_with_ds[key][1:]
        results_without_ds[key] = results_without_ds[key][1:]
      if key.startswith('training_history'):
        results_with_ds[key]['val_loss'] = 0
        results_without_ds[key]['val_loss'] = 0

    testcase.assertAllClose(
        results_with_ds[key],
        results_without_ds[key],
        atol=tolerance,
        rtol=tolerance,
        msg='Failed to assert {}.'.format(key))


def should_skip_tpu_with_eager(distribution):
  return (context.executing_eagerly() and
          isinstance(distribution,
                     (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)))


class LearningRateBatchScheduler(keras.callbacks.Callback):
  """Scheduler that dynamically sets the learning rate of the model."""

  def __init__(self, update_freq=None):
    super(LearningRateBatchScheduler, self).__init__()
    self._update_freq = update_freq

  def on_batch_begin(self, batch, logs=None):
    if self._update_freq and batch % self._update_freq != 0:
      return

    # To avoid divergence, limit the value range.
    lr = 0.001 * (batch % 10)
    keras.backend.set_value(self.model.optimizer.lr, lr)
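
# For example, with update_freq=None the scheduler above cycles the learning
# rate through 0.000, 0.001, ..., 0.009 as the batch index advances, wrapping
# around every 10 batches.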


class TestDistributionStrategyCorrectnessBase(test.TestCase,
                                              parameterized.TestCase):
  """Model-agnostic testing infra to test correctness of Keras models."""

  def set_up_test_config(self,
                         use_numpy=False,
                         use_validation_data=False,
                         with_batch_norm=None):
    self.use_numpy = use_numpy
    self.use_validation_data = use_validation_data
    self.with_batch_norm = with_batch_norm

    keras.backend.set_image_data_format('channels_last')
    np.random.seed(_RANDOM_SEED)
    random_seed.set_random_seed(_RANDOM_SEED)

  def get_data(self):
    num_samples = 10000
    x_train = np.random.randint(0, 2, num_samples)
    x_train = np.reshape(x_train, (num_samples, 1))
    y_train = x_train
    return (x_train.astype('float32'), y_train.astype('float32'), None)

  def get_data_with_partial_last_batch(self):
    raise NotImplementedError

  def get_data_with_partial_last_batch_eval(self):
    raise NotImplementedError

  def get_input_for_correctness_test(self, **kwargs):
    """Generates inputs that are dictionaries.

    We only provide a default implementation of this method here. If you need
    a more customized way of providing input to your model, override this
    method.

    Args:
      **kwargs: keyword arguments specifying how to create the input
        dictionaries.

    Returns:
      Three dictionaries representing the input for fit(), evaluate() and
      predict().
    """

    return get_correctness_test_inputs(**kwargs)

  def get_model(self, distribution=None, input_shapes=None):
    raise NotImplementedError

  def run_correctness_test(self,
                           distribution,
                           use_numpy,
                           use_validation_data,
                           with_batch_norm=None,
                           is_stateful_model=False,
                           partial_last_batch=None,
                           training_epochs=2):
    with self.cached_session():
      self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)

      if partial_last_batch == 'eval':
        x_train, y_train, x_eval, y_eval, x_predict = (
            self.get_data_with_partial_last_batch_eval())
      elif partial_last_batch == 'train_and_eval':
        x_train, y_train, x_eval, y_eval, x_predict = (
            self.get_data_with_partial_last_batch())
      else:
        x_train, y_train, x_predict = self.get_data()
        x_eval = x_train
        y_eval = y_train

      # The model is built once and the initial weights are saved.
      # This is used to initialize the model for both the distribution and
      # non-distribution run.
      model = self.get_model(input_shapes=get_shapes(x_train))
      initial_weights = model.get_weights()

      ds_input_fn = functools.partial(
          self.get_input_for_correctness_test,
          use_numpy=use_numpy,
          use_validation_data=use_validation_data,
          with_distribution=distribution,
          x_train=x_train,
          y_train=y_train,
          x_eval=x_eval,
          y_eval=y_eval,
          x_predict=x_predict,
          training_epochs=training_epochs)

      nods_input_fn = functools.partial(
          self.get_input_for_correctness_test,
          use_numpy=use_numpy,
          use_validation_data=use_validation_data,
          with_distribution=None,
          x_train=x_train,
          y_train=y_train,
          x_eval=x_eval,
          y_eval=y_eval,
          x_predict=x_predict,
          training_epochs=training_epochs)

      results_with_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=ds_input_fn,
          model_fn=self.get_model,
          distribution=distribution,
          is_stateful_model=is_stateful_model)
      results_without_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=nods_input_fn,
          model_fn=self.get_model,
          distribution=None,
          is_stateful_model=is_stateful_model)

      # Special case: in multi-replica distributed training, batch norm
      # statistics are not aggregated globally, so the weights are expected
      # to differ.
      if (self.with_batch_norm == 'regular' and
          distribution.num_replicas_in_sync > 1):
        with self.assertRaises(AssertionError):
          compare_results(
              results_with_ds,
              results_without_ds,
              distribution,
              testcase=self,
              partial_last_batch=partial_last_batch)
      else:
        compare_results(
            results_with_ds,
            results_without_ds,
            distribution,
            testcase=self,
            partial_last_batch=partial_last_batch)

  def get_input_for_dynamic_lr_test(self, **kwargs):
    """Generates inputs that are dictionaries.

    We only provide a default implementation of this method here. If you need
    a more customized way of providing input to your model, override this
    method.

    Args:
      **kwargs: keyword arguments specifying how to create the input
        dictionaries.

    Returns:
      Three dictionaries representing the input for fit(), evaluate() and
      predict().
    """

    training_input = kwargs
    return training_input, None, None

  def run_dynamic_lr_test(self, distribution):
    with self.cached_session():
      self.set_up_test_config()

      x_train, y_train, _ = self.get_data()
      model = self.get_model(input_shapes=get_shapes(x_train))
      initial_weights = model.get_weights()
      update_freq = None

      if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
          distribution.extended.steps_per_run > 1):
        # For TPUStrategy with steps_per_run > 1, the callback is not invoked
        # every step. So, to compare CPU and TPU results, we let the CPU
        # behave the same as the TPU.
        update_freq = distribution.extended.steps_per_run

      training_epochs = 2
      global_batch_size = 64

      ds_batch_size = get_batch_size(global_batch_size, distribution)
      nods_batch_size = get_batch_size(global_batch_size, None)

      ds_input_fn = functools.partial(
          self.get_input_for_dynamic_lr_test,
          x=x_train,
          y=y_train,
          batch_size=ds_batch_size,
          shuffle=False,
          epochs=training_epochs,
          callbacks=[LearningRateBatchScheduler(update_freq)],
          validation_data=(x_train, y_train))

      nods_input_fn = functools.partial(
          self.get_input_for_dynamic_lr_test,
          x=x_train,
          y=y_train,
          batch_size=nods_batch_size,
          shuffle=False,
          epochs=training_epochs,
          callbacks=[LearningRateBatchScheduler(update_freq)],
          validation_data=(x_train, y_train))

      results_with_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=ds_input_fn,
          model_fn=self.get_model,
          distribution=distribution)
      results_without_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=nods_input_fn,
          model_fn=self.get_model,
          distribution=None)
      compare_results(
          results_with_ds, results_without_ds, distribution, testcase=self)


class TestDistributionStrategyEmbeddingModelCorrectnessBase(
    TestDistributionStrategyCorrectnessBase):
  """Base class to test correctness of Keras models with embedding layers."""

  def get_data(self,
               count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS),
               min_words=5,
               max_words=10,
               max_word_id=19,
               num_classes=2):
    distribution = []
    for _ in range(num_classes):
      dist = np.abs(np.random.randn(max_word_id))
      dist /= np.sum(dist)
      distribution.append(dist)

    features = []
    labels = []
    for _ in range(count):
      label = np.random.randint(0, num_classes, size=1)[0]
      num_words = np.random.randint(min_words, max_words, size=1)[0]
      word_ids = np.random.choice(
          max_word_id, size=num_words, replace=True, p=distribution[label])
      labels.append(label)
      features.append(word_ids)

    features = sequence.pad_sequences(features, maxlen=max_words)
    x_train = np.asarray(features, dtype=np.float32)
    y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1))
    x_predict = x_train[:_GLOBAL_BATCH_SIZE]
    return x_train, y_train, x_predict
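
  # With the defaults above (count = 64 * 20 = 1280), x_train has shape
  # (1280, 10): word ids padded to max_words and cast to float32, while
  # y_train has shape (1280, 1) and holds int32 class labels.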


if __name__ == '__main__':
  test.main()