# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Test async checkpointing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
import os

import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator
from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'TPU to use in this test.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string(
    'model_dir',
    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
    'GCS path to store model and checkpoints.')


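# Note: TPUEstimator passes `batch_size` to `input_fn` via `params`; the
# features and labels are random because these tests only exercise the
# checkpointing machinery, not model quality.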
def input_fn(params):
  """Return random features and integer labels for training."""
  return (constant_op.constant(
      np.random.randn(params['batch_size'], 1000), dtype=dtypes.float32),
          constant_op.constant(
              np.random.randint(0, 10, params['batch_size']),
              dtype=dtypes.int32))


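# A deliberately tiny model: a single 1000x10 linear layer trained with
# RMSProp wrapped in CrossShardOptimizer, which averages gradients across
# the TPU shards.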
def model_fn(features, labels, mode, params):
  del params  # unused
  with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
    w = variable_scope.get_variable('W', shape=[1000, 10])
  logits = math_ops.matmul(features, w)
  loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == model_fn_lib.ModeKeys.TRAIN:
    optimizer = training.RMSPropOptimizer(learning_rate=0.01)
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, training.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
    )
  elif mode == model_fn_lib.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      labels = math_ops.cast(labels, dtypes.int64)
      logging.info('LABELS %s %s', labels, logits)
      return {
          'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
          'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
      }

    eval_metrics = (metric_fn, [labels, logits])
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics)


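# These tests require a real TPU (selected via --tpu/--zone/--project) and a
# writable --model_dir; they run a short training job and then inspect the
# checkpoints that AsyncCheckpointSaverHook produced.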
class AsyncCheckpointingTest(test.TestCase):

  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))
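    # The RunConfig saver writes into a separate 'runconfig' subdirectory, so
    # any checkpoint found directly under FLAGS.model_dir was written by the
    # async hook attached below.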

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
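    # The mock listener (an autospec of CheckpointSaverListener) lets the test
    # verify that the async hook fires the before_save/after_save callbacks.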
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
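    # Each checkpoint produces exactly one `.meta` file, so the match count
    # below equals the number of retained checkpoints.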
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()

  def testAsyncCheckpointHookWithoutListeners(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    keep_checkpoint_max = 10
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=keep_checkpoint_max + 1,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    max_steps = 100
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval)
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, keep_checkpoint_max)
    self.assertEqual(current_step, max_steps)


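# Example invocation (TPU name, zone, project, and bucket are hypothetical):
#   python async_checkpoint_test.py --tpu=my-tpu --zone=us-central1-b \
#     --project=my-gcp-project --model_dir=gs://my-bucket/async-ckpt-test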
if __name__ == '__main__':
  v2_compat.disable_v2_behavior()
  test.main()