1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""learn_main tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import json 22import os 23 24from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top 25from tensorflow.contrib.learn.python.learn import experiment 26from tensorflow.contrib.learn.python.learn import learn_runner 27from tensorflow.contrib.learn.python.learn import trainable 28 29from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib 30from tensorflow.contrib.training.python.training import hparam as hparam_lib 31from tensorflow.python.estimator import run_config as core_run_config_lib 32from tensorflow.python.platform import test 33from tensorflow.python.platform import tf_logging 34 35patch = test.mock.patch 36 37_MODIR_DIR = "/tmp" 38_HPARAMS = hparam_lib.HParams(learning_rate=0.01) 39_MUST_SPECIFY_OUTPUT_DIR_MSG = "Must specify an output directory" 40_MISSING_MODEL_DIR_ERR_MSG = ( 41 "Must specify a model directory `model_dir` in `run_config`.") 42_EXP_NOT_CALLABLE_MSG = "Experiment builder .* is not callable" 43_INVALID_HPARAMS_ERR_MSG = "`hparams` must be `HParams` instance" 44_NOT_EXP_TYPE_MSG = "Experiment builder did not return an Experiment" 45_NON_EXIST_TASK_MSG = "Schedule references non-existent task" 46_NON_CALLABLE_MSG = "Schedule references non-callable member" 47_MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG = ( 48 "Must set value for `output_dir` or `run_config`") 49_HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG = ( 50 "Must set `hparams` as None for `experiment_fn` with `output_dir`.") 51_CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG = ( 52 "Cannot provide both `output_dir` and `run_config`") 53_INVALID_RUN_CONFIG_TYPE_MSG = ( 54 "`run_config` must be `tf.contrib.learn.RunConfig` instance") 55_RUN_CONFIG_UID_CHECK_ERR_MSG = ( 56 "`RunConfig` instance is expected to be used by the `Estimator`") 57_MISSING_RUN_CONFIG_UID_ERR_MSG = ( 58 "Pass `run_config` argument of the `experiment_fn` to the Estimator") 59 60 61class TestExperiment(experiment.Experiment): 62 63 def __init__(self, default=None, config=None, model_dir=None): 64 self.default = default 65 self.config = config 66 internal_model_dir = model_dir or config.model_dir 67 self._model_dir = internal_model_dir 68 69 class Estimator(evaluable.Evaluable, trainable.Trainable): 70 config = self.config 71 72 @property 73 def model_dir(self): 74 return internal_model_dir 75 76 def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, 77 monitors=None, max_steps=None): 78 raise NotImplementedError 79 80 def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None, 81 batch_size=None, steps=None, metrics=None, name=None, 82 checkpoint_path=None, hooks=None): 83 raise NotImplementedError 84 85 super(TestExperiment, self).__init__(Estimator(), None, None) 86 87 def local_run(self): 88 return "local_run-{}".format(self._model_dir) 89 90 def train(self): 91 return "train-{}".format(self._model_dir) 92 93 def run_std_server(self): 94 return "run_std_server-{}".format(self._model_dir) 95 96 def train_and_evaluate(self): 97 return "train_and_evaluate-{}".format(self._model_dir) 98 99 def simple_task(self): 100 return "simple_task, default=%s." % self.default 101 102 103# pylint: disable=unused-argument 104def build_experiment(output_dir): 105 tf_logging.info("In default build_experiment.") 106 return TestExperiment(model_dir=output_dir) 107 108 109def build_experiment_fn_for_output_dir(run_config=None): 110 def _build_experiment(output_dir): 111 tf_logging.info("In default build_experiment.") 112 return TestExperiment(config=run_config, model_dir=output_dir) 113 return _build_experiment 114 115 116def build_experiment_for_run_config(run_config, hparams): 117 if hparams is not None and hparams != _HPARAMS: 118 raise ValueError("hparams is not set correctly") 119 return TestExperiment(config=run_config) 120 121 122def build_non_experiment(output_dir): 123 return "Ceci n'est pas un Experiment." 124 125 126# pylint: enable=unused-argument 127 128 129def build_distributed_cluster_spec(): 130 return { 131 run_config_lib.TaskType.PS: ["localhost:1234", "localhost:1235"], 132 run_config_lib.TaskType.WORKER: ["localhost:1236", "localhost:1237"], 133 run_config_lib.TaskType.MASTER: ["localhost:1238"], 134 "foo_has_no_default_schedule": ["localhost:1239"] 135 } 136 137 138def build_non_distributed_cluster_spec(): 139 return {"foo": ["localhost:1234"]} 140 141 142class LearnRunnerRunWithOutputDirTest(test.TestCase): 143 144 def setUp(self): 145 # Ensure the TF_CONFIG environment variable is unset for all tests. 146 os.environ.pop("TF_CONFIG", None) 147 148 def test_run_with_custom_schedule(self): 149 self.assertEqual( 150 "simple_task, default=None.", 151 learn_runner.run(build_experiment, 152 output_dir=_MODIR_DIR, 153 schedule="simple_task")) 154 155 def test_run_with_explicit_local_run(self): 156 self.assertEqual( 157 "local_run-" + _MODIR_DIR, 158 learn_runner.run(build_experiment, 159 output_dir=_MODIR_DIR, 160 schedule="local_run")) 161 162 def test_fail_output_dir_and_run_config_are_both_set(self): 163 with self.assertRaisesRegexp( 164 ValueError, _CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG): 165 learn_runner.run(build_experiment, 166 output_dir=_MODIR_DIR, 167 schedule="simple_task", 168 run_config=run_config_lib.RunConfig()) 169 170 def test_fail_empty_output_dir(self): 171 with self.assertRaisesRegexp(ValueError, _MUST_SPECIFY_OUTPUT_DIR_MSG): 172 learn_runner.run(build_experiment, output_dir="", schedule="simple_task") 173 174 def test_fail_no_output_dir(self): 175 with self.assertRaisesRegexp( 176 ValueError, _MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG): 177 learn_runner.run(build_experiment, None, "simple_task") 178 179 def test_fail_hparams_are_set(self): 180 hparams = _HPARAMS 181 with self.assertRaisesRegexp( 182 ValueError, _HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG): 183 learn_runner.run( 184 build_experiment, _MODIR_DIR, schedule="simple_task", hparams=hparams) 185 186 def test_fail_non_callable(self): 187 with self.assertRaisesRegexp(TypeError, _EXP_NOT_CALLABLE_MSG): 188 learn_runner.run("not callable", _MODIR_DIR, "simple_test") 189 190 def test_fail_not_experiment(self): 191 with self.assertRaisesRegexp(TypeError, _NOT_EXP_TYPE_MSG): 192 learn_runner.run(build_non_experiment, _MODIR_DIR, "simple_test") 193 194 def test_fail_non_existent_task(self): 195 with self.assertRaisesRegexp(ValueError, _NON_EXIST_TASK_MSG): 196 learn_runner.run(build_experiment, _MODIR_DIR, "mirage") 197 198 def test_fail_non_callable_task(self): 199 with self.assertRaisesRegexp(TypeError, _NON_CALLABLE_MSG): 200 learn_runner.run(build_experiment, _MODIR_DIR, "default") 201 202 203class LearnRunnerRunWithRunConfigTest(test.TestCase): 204 205 def setUp(self): 206 # Ensure the TF_CONFIG environment variable is unset for all tests. 207 os.environ.pop("TF_CONFIG", None) 208 209 def test_run_with_custom_schedule(self): 210 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 211 self.assertEqual( 212 "simple_task, default=None.", 213 learn_runner.run(build_experiment_for_run_config, 214 run_config=run_config, 215 schedule="simple_task")) 216 217 def test_run_with_hparams(self): 218 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 219 self.assertEqual( 220 "simple_task, default=None.", 221 learn_runner.run(build_experiment_for_run_config, 222 run_config=run_config, 223 schedule="simple_task", 224 hparams=_HPARAMS)) 225 226 def test_run_with_explicit_local_run(self): 227 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 228 self.assertEqual( 229 "local_run-" + _MODIR_DIR, 230 learn_runner.run(build_experiment_for_run_config, 231 run_config=run_config, 232 schedule="local_run")) 233 234 def test_fail_empty_output_dir(self): 235 run_config = run_config_lib.RunConfig(model_dir="") 236 with self.assertRaisesRegexp(ValueError, _MISSING_MODEL_DIR_ERR_MSG): 237 learn_runner.run(build_experiment_for_run_config, 238 run_config=run_config, 239 schedule="local_run") 240 241 def test_fail_no_output_dir(self): 242 run_config = run_config_lib.RunConfig() 243 with self.assertRaisesRegexp(ValueError, _MISSING_MODEL_DIR_ERR_MSG): 244 learn_runner.run(build_experiment_for_run_config, 245 run_config=run_config, 246 schedule="local_run") 247 248 def test_fail_invalid_run_config_type(self): 249 run_config = "invalid_run_config" 250 with self.assertRaisesRegexp(ValueError, _INVALID_RUN_CONFIG_TYPE_MSG): 251 learn_runner.run(build_experiment_for_run_config, 252 run_config=run_config, 253 schedule="local_run") 254 255 def test_fail_invalid_hparams_type(self): 256 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 257 with self.assertRaisesRegexp(ValueError, _INVALID_HPARAMS_ERR_MSG): 258 learn_runner.run(build_experiment_for_run_config, 259 run_config=run_config, 260 schedule="local_run", 261 hparams=["hparams"]) 262 263 def test_fail_non_callable(self): 264 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 265 with self.assertRaisesRegexp(TypeError, _EXP_NOT_CALLABLE_MSG): 266 learn_runner.run("not callable", 267 run_config=run_config, 268 schedule="simple_task") 269 270 def test_fail_not_experiment(self): 271 def _experiment_fn(run_config, hparams): 272 del run_config, hparams # unused. 273 return "not experiment" 274 275 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 276 with self.assertRaisesRegexp(TypeError, _NOT_EXP_TYPE_MSG): 277 learn_runner.run(_experiment_fn, 278 run_config=run_config, 279 schedule="simple_task") 280 281 def test_fail_non_existent_task(self): 282 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 283 with self.assertRaisesRegexp(ValueError, _NON_EXIST_TASK_MSG): 284 learn_runner.run(build_experiment_for_run_config, 285 run_config=run_config, 286 schedule="mirage") 287 288 def test_fail_non_callable_task(self): 289 run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 290 with self.assertRaisesRegexp(TypeError, _NON_CALLABLE_MSG): 291 learn_runner.run(build_experiment_for_run_config, 292 run_config=run_config, 293 schedule="default") 294 295 def test_basic_run_config_uid_check(self): 296 expected_run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 297 298 def _experiment_fn(run_config, hparams): 299 del run_config, hparams # unused. 300 # Explicitly use a new run_config. 301 new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123") 302 303 return TestExperiment(config=new_config) 304 305 with self.assertRaisesRegexp(RuntimeError, _RUN_CONFIG_UID_CHECK_ERR_MSG): 306 learn_runner.run(experiment_fn=_experiment_fn, 307 run_config=expected_run_config) 308 309 def test_fail_invalid_experiment_config_type(self): 310 expected_run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR) 311 312 def _experiment_fn(run_config, hparams): 313 del run_config, hparams # unused. 314 # Explicitly use a new run_config without `uid` method. 315 new_config = core_run_config_lib.RunConfig( 316 model_dir=_MODIR_DIR + "/123") 317 318 return TestExperiment(config=new_config) 319 320 with self.assertRaisesRegexp(RuntimeError, 321 _MISSING_RUN_CONFIG_UID_ERR_MSG): 322 learn_runner.run(experiment_fn=_experiment_fn, 323 run_config=expected_run_config) 324 325 326class LearnRunnerDefaultScheduleTest(test.TestCase): 327 328 def setUp(self): 329 # Ensure the TF_CONFIG environment variable is unset for all tests. 330 os.environ.pop("TF_CONFIG", None) 331 332 def test_schedule_from_tf_config_runs_train_on_worker(self): 333 os.environ["TF_CONFIG"] = json.dumps({ 334 "cluster": build_distributed_cluster_spec(), 335 "task": { 336 "type": run_config_lib.TaskType.WORKER 337 } 338 }) 339 # RunConfig constructor will set job_name from TF_CONFIG. 340 config = run_config_lib.RunConfig() 341 self.assertEqual( 342 "train-" + _MODIR_DIR, 343 learn_runner.run( 344 build_experiment_fn_for_output_dir(config), 345 output_dir=_MODIR_DIR)) 346 347 def test_schedule_from_tf_config_runs_train_and_evaluate_on_master(self): 348 tf_config = { 349 "cluster": build_distributed_cluster_spec(), 350 "task": { 351 "type": run_config_lib.TaskType.MASTER 352 } 353 } 354 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 355 config = run_config_lib.RunConfig() 356 self.assertEqual( 357 "train_and_evaluate-" + _MODIR_DIR, 358 learn_runner.run( 359 build_experiment_fn_for_output_dir(config), 360 output_dir=_MODIR_DIR)) 361 362 def test_schedule_from_tf_config_runs_serve_on_ps(self): 363 tf_config = { 364 "cluster": build_distributed_cluster_spec(), 365 "task": { 366 "type": run_config_lib.TaskType.PS 367 } 368 } 369 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 370 config = run_config_lib.RunConfig() 371 self.assertEqual( 372 "run_std_server-" + _MODIR_DIR, 373 learn_runner.run( 374 build_experiment_fn_for_output_dir(config), 375 output_dir=_MODIR_DIR)) 376 377 def test_no_schedule_and_no_config_runs_train_and_evaluate(self): 378 self.assertEqual( 379 "train_and_evaluate-" + _MODIR_DIR, 380 learn_runner.run(build_experiment, output_dir=_MODIR_DIR)) 381 382 def test_no_schedule_and_non_distributed_runs_train_and_evaluate(self): 383 tf_config = {"cluster": build_non_distributed_cluster_spec()} 384 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 385 config = run_config_lib.RunConfig() 386 self.assertEqual( 387 "train_and_evaluate-" + _MODIR_DIR, 388 learn_runner.run( 389 build_experiment_fn_for_output_dir(config), 390 output_dir=_MODIR_DIR)) 391 392 def test_fail_task_type_with_no_default_schedule(self): 393 tf_config = { 394 "cluster": build_distributed_cluster_spec(), 395 "task": { 396 "type": "foo_has_no_default_schedule" 397 } 398 } 399 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 400 config = run_config_lib.RunConfig() 401 create_experiment_fn = lambda output_dir: TestExperiment(config=config) 402 self.assertRaisesRegexp(ValueError, 403 "No default schedule", 404 learn_runner.run, 405 create_experiment_fn, 406 _MODIR_DIR) 407 408 def test_fail_schedule_from_config_with_no_task_type(self): 409 tf_config = {"cluster": build_distributed_cluster_spec()} 410 with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}): 411 config = run_config_lib.RunConfig() 412 self.assertRaisesRegexp( 413 ValueError, 414 "Must specify a schedule", 415 learn_runner.run, 416 lambda output_dir: TestExperiment(config=config), 417 output_dir=_MODIR_DIR) 418 419 420if __name__ == "__main__": 421 test.main() 422