1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Test async checkpointing.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import numpy as np 24 25from tensorflow.python.compat import v2_compat 26from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.lib.io import file_io 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import metrics as metrics_lib 32from tensorflow.python.ops import variable_scope 33from tensorflow.python.ops.losses import losses 34from tensorflow.python.platform import flags 35from tensorflow.python.platform import test 36from tensorflow.python.platform import tf_logging as logging 37from tensorflow.python.tpu import async_checkpoint 38from tensorflow.python.tpu import tpu_config 39from tensorflow.python.tpu import tpu_estimator 40from tensorflow.python.tpu import tpu_optimizer 41from tensorflow.python.training import basic_session_run_hooks 42from tensorflow.python.training import training 43from tensorflow_estimator.python.estimator import estimator as estimator_lib 44from tensorflow_estimator.python.estimator import model_fn as model_fn_lib 45 46FLAGS = flags.FLAGS 47flags.DEFINE_string('tpu', '', 'TPU to use in this test.') 48flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.') 49flags.DEFINE_string('project', None, 'Name of GCP project with TPU.') 50flags.DEFINE_string( 51 'model_dir', 52 os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), 53 'GCS path to store model and checkpoints.') 54 55 56def input_fn(params): 57 """Return a dataset of source and target sequences for training.""" 58 return (constant_op.constant( 59 np.random.randn(params['batch_size'], 1000), dtype=dtypes.float32), 60 constant_op.constant( 61 np.random.randint(0, 10, params['batch_size']), 62 dtype=dtypes.int32)) 63 64 65def model_fn(features, labels, mode, params): 66 del params # unused 67 with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE): 68 w = variable_scope.get_variable('W', shape=[1000, 10]) 69 logits = math_ops.matmul(features, w) 70 loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 71 72 if mode == model_fn_lib.ModeKeys.TRAIN: 73 optimizer = training.RMSPropOptimizer(learning_rate=0.01) 74 optimizer = tpu_optimizer.CrossShardOptimizer(optimizer) 75 train_op = optimizer.minimize(loss, training.get_global_step()) 76 return tpu_estimator.TPUEstimatorSpec( 77 mode=model_fn_lib.ModeKeys.TRAIN, 78 loss=loss, 79 train_op=train_op, 80 ) 81 elif mode == model_fn_lib.ModeKeys.EVAL: 82 83 def metric_fn(labels, logits): 84 labels = math_ops.cast(labels, dtypes.int64) 85 logging.info('LABELS %s %s', labels, logits) 86 return { 87 'recall@1': metrics_lib.recall_at_k(labels, logits, 1), 88 'recall@5': metrics_lib.recall_at_k(labels, logits, 5), 89 } 90 91 loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 92 eval_metrics = (metric_fn, [labels, logits]) 93 return tpu_estimator.TPUEstimatorSpec( 94 mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics) 95 96 97class AsyncCheckpointingTest(test.TestCase): 98 99 def testAsyncCheckpointHookEnabled(self): 100 resolver = tpu_cluster_resolver.TPUClusterResolver( 101 tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project) 102 103 checkpoint_interval = 5 104 config = tpu_config.RunConfig( 105 master=resolver.master(), 106 model_dir=os.path.join(FLAGS.model_dir, 'runconfig'), 107 save_checkpoints_steps=1000, 108 keep_checkpoint_max=11, # off by one 109 tpu_config=tpu_config.TPUConfig( 110 iterations_per_loop=checkpoint_interval,)) 111 112 estimator = tpu_estimator.TPUEstimator( 113 use_tpu=True, 114 model_fn=model_fn, 115 config=config, 116 train_batch_size=32, 117 eval_batch_size=32, 118 predict_batch_size=1, 119 params={}, 120 ) 121 122 i = 10 123 mock_listener = test.mock.create_autospec( 124 basic_session_run_hooks.CheckpointSaverListener) 125 estimator.train( 126 input_fn=input_fn, 127 max_steps=i * 10, 128 hooks=[ 129 async_checkpoint.AsyncCheckpointSaverHook( 130 FLAGS.model_dir, 131 save_steps=checkpoint_interval, 132 listeners=[mock_listener]) 133 ]) 134 135 current_step = estimator_lib._load_global_step_from_checkpoint_dir( 136 FLAGS.model_dir) # pylint: disable=protected-access 137 138 # TODO(power) -- identify a better way to count the number of checkpoints. 139 checkpoints = file_io.get_matching_files( 140 FLAGS.model_dir + '/model.ckpt*.meta') 141 checkpoint_count = len(checkpoints) 142 logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints) 143 self.assertLessEqual(checkpoint_count, 10) 144 self.assertEqual(current_step, i * 10) 145 mock_listener.before_save.assert_called() 146 mock_listener.after_save.assert_called() 147 148 def testAsyncCheckpointHookWithoutListeners(self): 149 resolver = tpu_cluster_resolver.TPUClusterResolver( 150 tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project) 151 152 checkpoint_interval = 5 153 keep_checkpoint_max = 10 154 config = tpu_config.RunConfig( 155 master=resolver.master(), 156 model_dir=os.path.join(FLAGS.model_dir, 'runconfig'), 157 save_checkpoints_steps=1000, 158 keep_checkpoint_max=keep_checkpoint_max+1, # off by one 159 tpu_config=tpu_config.TPUConfig( 160 iterations_per_loop=checkpoint_interval,)) 161 162 estimator = tpu_estimator.TPUEstimator( 163 use_tpu=True, 164 model_fn=model_fn, 165 config=config, 166 train_batch_size=32, 167 eval_batch_size=32, 168 predict_batch_size=1, 169 params={}, 170 ) 171 172 max_steps = 100 173 estimator.train( 174 input_fn=input_fn, 175 max_steps=max_steps, 176 hooks=[ 177 async_checkpoint.AsyncCheckpointSaverHook( 178 FLAGS.model_dir, 179 save_steps=checkpoint_interval) 180 ]) 181 182 current_step = estimator_lib._load_global_step_from_checkpoint_dir( 183 FLAGS.model_dir) # pylint: disable=protected-access 184 185 # TODO(power) -- identify a better way to count the number of checkpoints. 186 checkpoints = file_io.get_matching_files( 187 FLAGS.model_dir + '/model.ckpt*.meta') 188 checkpoint_count = len(checkpoints) 189 logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints) 190 self.assertLessEqual(checkpoint_count, keep_checkpoint_max) 191 self.assertEqual(current_step, max_steps) 192 193 194if __name__ == '__main__': 195 v2_compat.disable_v2_behavior() 196 test.main() 197