• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""learn_main tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import json
22import os
23
24from tensorflow.contrib.learn.python.learn import evaluable  # pylint: disable=g-import-not-at-top
25from tensorflow.contrib.learn.python.learn import experiment
26from tensorflow.contrib.learn.python.learn import learn_runner
27from tensorflow.contrib.learn.python.learn import trainable
28
29from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
30from tensorflow.contrib.training.python.training import hparam as hparam_lib
31from tensorflow.python.estimator import run_config as core_run_config_lib
32from tensorflow.python.platform import test
33from tensorflow.python.platform import tf_logging
34
# Shorthand for the mock `patch` helper exposed through the test module.
patch = test.mock.patch

# Directory handed to `learn_runner.run` as the output/model directory.
_MODIR_DIR = "/tmp"
# Hyperparameters forwarded to experiment builders that accept `hparams`.
_HPARAMS = hparam_lib.HParams(learning_rate=0.01)

# Everything below is used as a regex pattern with `assertRaisesRegexp`.

# Argument validation when running with `output_dir`.
_MUST_SPECIFY_OUTPUT_DIR_MSG = "Must specify an output directory"
_MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG = (
    "Must set value for `output_dir` or `run_config`")
_CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG = (
    "Cannot provide both `output_dir` and `run_config`")
_HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG = (
    "Must set `hparams` as None for `experiment_fn` with `output_dir`.")

# Argument validation when running with `run_config`.
_MISSING_MODEL_DIR_ERR_MSG = (
    "Must specify a model directory `model_dir` in `run_config`.")
_INVALID_RUN_CONFIG_TYPE_MSG = (
    "`run_config` must be `tf.contrib.learn.RunConfig` instance")
_INVALID_HPARAMS_ERR_MSG = "`hparams` must be `HParams` instance"
_RUN_CONFIG_UID_CHECK_ERR_MSG = (
    "`RunConfig` instance is expected to be used by the `Estimator`")
_MISSING_RUN_CONFIG_UID_ERR_MSG = (
    "Pass `run_config` argument of the `experiment_fn` to the Estimator")

# Experiment builder and schedule validation (both calling styles).
_EXP_NOT_CALLABLE_MSG = "Experiment builder .* is not callable"
_NOT_EXP_TYPE_MSG = "Experiment builder did not return an Experiment"
_NON_EXIST_TASK_MSG = "Schedule references non-existent task"
_NON_CALLABLE_MSG = "Schedule references non-callable member"
59
60
class TestExperiment(experiment.Experiment):
  """Stub `Experiment` whose schedules return marker strings.

  Nothing is trained or evaluated: each schedule method returns a string
  naming the schedule (and, for most of them, the model directory), so tests
  can assert which schedule `learn_runner.run` ended up invoking.
  """

  def __init__(self, default=None, config=None, model_dir=None):
    """Builds the stub around a minimal nested estimator.

    Args:
      default: Value stored as the `default` attribute and echoed back by
        `simple_task`. Being a plain attribute rather than a method, the name
        "default" also serves tests as a non-callable schedule.
      config: Optional run config; stored on the experiment and exposed as the
        nested estimator's `config` class attribute.
      model_dir: Optional model directory. When falsy, `config.model_dir` is
        used if a `config` was given; otherwise the directory is `None`.
    """
    self.default = default
    self.config = config
    # Guard the fallback: reading `config.model_dir` without a config used to
    # raise AttributeError when neither `model_dir` nor `config` was supplied.
    if model_dir:
      internal_model_dir = model_dir
    elif config is not None:
      internal_model_dir = config.model_dir
    else:
      internal_model_dir = None
    self._model_dir = internal_model_dir

    class Estimator(evaluable.Evaluable, trainable.Trainable):
      """Estimator stand-in; only `config` and `model_dir` are usable."""
      config = self.config

      @property
      def model_dir(self):
        return internal_model_dir

      def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
              monitors=None, max_steps=None):
        raise NotImplementedError

      def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
                   batch_size=None, steps=None, metrics=None, name=None,
                   checkpoint_path=None, hooks=None):
        raise NotImplementedError

    super(TestExperiment, self).__init__(Estimator(), None, None)

  def local_run(self):
    """Returns the "local_run" marker string with the model directory."""
    return "local_run-{}".format(self._model_dir)

  def train(self):
    """Returns the "train" marker string with the model directory."""
    return "train-{}".format(self._model_dir)

  def run_std_server(self):
    """Returns the "run_std_server" marker string with the model directory."""
    return "run_std_server-{}".format(self._model_dir)

  def train_and_evaluate(self):
    """Returns the "train_and_evaluate" marker string with the model dir."""
    return "train_and_evaluate-{}".format(self._model_dir)

  def simple_task(self):
    """Returns a marker string that echoes the `default` attribute."""
    return "simple_task, default=%s." % self.default
101
102
103# pylint: disable=unused-argument
def build_experiment(output_dir):
  """Experiment builder taking `output_dir`; returns a `TestExperiment`."""
  tf_logging.info("In default build_experiment.")
  stub_experiment = TestExperiment(model_dir=output_dir)
  return stub_experiment
107
108
def build_experiment_fn_for_output_dir(run_config=None):
  """Returns an `output_dir`-style builder bound to the given `run_config`.

  The returned callable takes `output_dir` and creates a `TestExperiment`
  with both the captured `run_config` and that directory.
  """

  def _experiment_for_dir(output_dir):
    tf_logging.info("In default build_experiment.")
    stub_experiment = TestExperiment(config=run_config, model_dir=output_dir)
    return stub_experiment

  return _experiment_for_dir
114
115
def build_experiment_for_run_config(run_config, hparams):
  """Experiment builder taking `run_config` and `hparams`.

  Raises:
    ValueError: if `hparams` is given but is not equal to `_HPARAMS`.
  """
  if hparams is not None:
    # Only the module-level hyperparameters are expected to be passed through.
    if hparams != _HPARAMS:
      raise ValueError("hparams is not set correctly")
  stub_experiment = TestExperiment(config=run_config)
  return stub_experiment
120
121
def build_non_experiment(output_dir):
  """Builder that returns a plain string rather than an `Experiment`."""
  not_an_experiment = "Ceci n'est pas un Experiment."
  return not_an_experiment
124
125
126# pylint: enable=unused-argument
127
128
def build_distributed_cluster_spec():
  """Returns a cluster dict with PS, worker, master and one custom job.

  The "foo_has_no_default_schedule" job is keyed by a plain string instead of
  a `TaskType` constant; a test uses it as a task type and expects a
  "No default schedule" error.
  """
  task_type = run_config_lib.TaskType
  cluster = {}
  cluster[task_type.PS] = ["localhost:1234", "localhost:1235"]
  cluster[task_type.WORKER] = ["localhost:1236", "localhost:1237"]
  cluster[task_type.MASTER] = ["localhost:1238"]
  cluster["foo_has_no_default_schedule"] = ["localhost:1239"]
  return cluster
136
137
def build_non_distributed_cluster_spec():
  """Returns a cluster dict holding a single job, "foo", with one address."""
  single_job_cluster = {}
  single_job_cluster["foo"] = ["localhost:1234"]
  return single_job_cluster
140
141
class LearnRunnerRunWithOutputDirTest(test.TestCase):
  """Tests for `learn_runner.run` when it is given an `output_dir`."""

  def setUp(self):
    # Every test starts without TF_CONFIG in the environment.
    if "TF_CONFIG" in os.environ:
      del os.environ["TF_CONFIG"]

  def _assert_run_raises(self, error_type, message_regex, *args, **kwargs):
    # Runs `learn_runner.run(*args, **kwargs)` and expects it to raise
    # `error_type` with a message matching `message_regex`.
    with self.assertRaisesRegexp(error_type, message_regex):
      learn_runner.run(*args, **kwargs)

  def test_run_with_custom_schedule(self):
    outcome = learn_runner.run(
        build_experiment, output_dir=_MODIR_DIR, schedule="simple_task")
    self.assertEqual("simple_task, default=None.", outcome)

  def test_run_with_explicit_local_run(self):
    outcome = learn_runner.run(
        build_experiment, output_dir=_MODIR_DIR, schedule="local_run")
    self.assertEqual("local_run-" + _MODIR_DIR, outcome)

  def test_fail_output_dir_and_run_config_are_both_set(self):
    # The RunConfig is built inside the `with`, so this stays inline.
    expect_error = self.assertRaisesRegexp(
        ValueError, _CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG)
    with expect_error:
      learn_runner.run(
          build_experiment,
          output_dir=_MODIR_DIR,
          schedule="simple_task",
          run_config=run_config_lib.RunConfig())

  def test_fail_empty_output_dir(self):
    self._assert_run_raises(
        ValueError, _MUST_SPECIFY_OUTPUT_DIR_MSG,
        build_experiment, output_dir="", schedule="simple_task")

  def test_fail_no_output_dir(self):
    self._assert_run_raises(
        ValueError, _MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG,
        build_experiment, None, "simple_task")

  def test_fail_hparams_are_set(self):
    self._assert_run_raises(
        ValueError, _HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG,
        build_experiment, _MODIR_DIR, schedule="simple_task", hparams=_HPARAMS)

  def test_fail_non_callable(self):
    self._assert_run_raises(
        TypeError, _EXP_NOT_CALLABLE_MSG,
        "not callable", _MODIR_DIR, "simple_test")

  def test_fail_not_experiment(self):
    self._assert_run_raises(
        TypeError, _NOT_EXP_TYPE_MSG,
        build_non_experiment, _MODIR_DIR, "simple_test")

  def test_fail_non_existent_task(self):
    self._assert_run_raises(
        ValueError, _NON_EXIST_TASK_MSG,
        build_experiment, _MODIR_DIR, "mirage")

  def test_fail_non_callable_task(self):
    # "default" names a plain attribute of TestExperiment, not a method.
    self._assert_run_raises(
        TypeError, _NON_CALLABLE_MSG,
        build_experiment, _MODIR_DIR, "default")
201
202
class LearnRunnerRunWithRunConfigTest(test.TestCase):
  """Tests for `learn_runner.run` when it is given a `run_config`."""

  def setUp(self):
    # Every test starts without TF_CONFIG in the environment.
    if "TF_CONFIG" in os.environ:
      del os.environ["TF_CONFIG"]

  def _assert_run_raises(self, error_type, message_regex, *args, **kwargs):
    # Runs `learn_runner.run(*args, **kwargs)` and expects it to raise
    # `error_type` with a message matching `message_regex`.
    with self.assertRaisesRegexp(error_type, message_regex):
      learn_runner.run(*args, **kwargs)

  def test_run_with_custom_schedule(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    outcome = learn_runner.run(
        build_experiment_for_run_config,
        run_config=config,
        schedule="simple_task")
    self.assertEqual("simple_task, default=None.", outcome)

  def test_run_with_hparams(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    outcome = learn_runner.run(
        build_experiment_for_run_config,
        run_config=config,
        schedule="simple_task",
        hparams=_HPARAMS)
    self.assertEqual("simple_task, default=None.", outcome)

  def test_run_with_explicit_local_run(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    outcome = learn_runner.run(
        build_experiment_for_run_config,
        run_config=config,
        schedule="local_run")
    self.assertEqual("local_run-" + _MODIR_DIR, outcome)

  def test_fail_empty_output_dir(self):
    config = run_config_lib.RunConfig(model_dir="")
    self._assert_run_raises(
        ValueError, _MISSING_MODEL_DIR_ERR_MSG,
        build_experiment_for_run_config,
        run_config=config,
        schedule="local_run")

  def test_fail_no_output_dir(self):
    config = run_config_lib.RunConfig()
    self._assert_run_raises(
        ValueError, _MISSING_MODEL_DIR_ERR_MSG,
        build_experiment_for_run_config,
        run_config=config,
        schedule="local_run")

  def test_fail_invalid_run_config_type(self):
    self._assert_run_raises(
        ValueError, _INVALID_RUN_CONFIG_TYPE_MSG,
        build_experiment_for_run_config,
        run_config="invalid_run_config",
        schedule="local_run")

  def test_fail_invalid_hparams_type(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    self._assert_run_raises(
        ValueError, _INVALID_HPARAMS_ERR_MSG,
        build_experiment_for_run_config,
        run_config=config,
        schedule="local_run",
        hparams=["hparams"])

  def test_fail_non_callable(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    self._assert_run_raises(
        TypeError, _EXP_NOT_CALLABLE_MSG,
        "not callable",
        run_config=config,
        schedule="simple_task")

  def test_fail_not_experiment(self):
    def _experiment_fn(run_config, hparams):
      del run_config, hparams  # unused.
      return "not experiment"

    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    self._assert_run_raises(
        TypeError, _NOT_EXP_TYPE_MSG,
        _experiment_fn,
        run_config=config,
        schedule="simple_task")

  def test_fail_non_existent_task(self):
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    self._assert_run_raises(
        ValueError, _NON_EXIST_TASK_MSG,
        build_experiment_for_run_config,
        run_config=config,
        schedule="mirage")

  def test_fail_non_callable_task(self):
    # "default" names a plain attribute of TestExperiment, not a method.
    config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
    self._assert_run_raises(
        TypeError, _NON_CALLABLE_MSG,
        build_experiment_for_run_config,
        run_config=config,
        schedule="default")

  def test_basic_run_config_uid_check(self):
    expected_run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)

    def _experiment_fn(run_config, hparams):
      del run_config, hparams  # unused.
      # Deliberately ignore the provided config and build a different one.
      other_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123")
      return TestExperiment(config=other_config)

    self._assert_run_raises(
        RuntimeError, _RUN_CONFIG_UID_CHECK_ERR_MSG,
        experiment_fn=_experiment_fn,
        run_config=expected_run_config)

  def test_fail_invalid_experiment_config_type(self):
    expected_run_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)

    def _experiment_fn(run_config, hparams):
      del run_config, hparams  # unused.
      # Deliberately use a core RunConfig, which has no `uid` method.
      other_config = core_run_config_lib.RunConfig(
          model_dir=_MODIR_DIR + "/123")
      return TestExperiment(config=other_config)

    self._assert_run_raises(
        RuntimeError, _MISSING_RUN_CONFIG_UID_ERR_MSG,
        experiment_fn=_experiment_fn,
        run_config=expected_run_config)
324
325
class LearnRunnerDefaultScheduleTest(test.TestCase):
  """Tests the schedule `learn_runner.run` picks when none is passed.

  Each test describes the cluster and task through the TF_CONFIG environment
  variable, patched in for the duration of the test only.
  """

  def setUp(self):
    # Ensure the TF_CONFIG environment variable is unset for all tests.
    os.environ.pop("TF_CONFIG", None)

  def test_schedule_from_tf_config_runs_train_on_worker(self):
    tf_config = {
        "cluster": build_distributed_cluster_spec(),
        "task": {
            "type": run_config_lib.TaskType.WORKER
        }
    }
    # Patch the environment instead of assigning to it, so TF_CONFIG is
    # restored when the test ends rather than leaking into later tests.
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      # RunConfig constructor will set job_name from TF_CONFIG.
      config = run_config_lib.RunConfig()
      self.assertEqual(
          "train-" + _MODIR_DIR,
          learn_runner.run(
              build_experiment_fn_for_output_dir(config),
              output_dir=_MODIR_DIR))

  def test_schedule_from_tf_config_runs_train_and_evaluate_on_master(self):
    tf_config = {
        "cluster": build_distributed_cluster_spec(),
        "task": {
            "type": run_config_lib.TaskType.MASTER
        }
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertEqual(
          "train_and_evaluate-" + _MODIR_DIR,
          learn_runner.run(
              build_experiment_fn_for_output_dir(config),
              output_dir=_MODIR_DIR))

  def test_schedule_from_tf_config_runs_serve_on_ps(self):
    tf_config = {
        "cluster": build_distributed_cluster_spec(),
        "task": {
            "type": run_config_lib.TaskType.PS
        }
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertEqual(
          "run_std_server-" + _MODIR_DIR,
          learn_runner.run(
              build_experiment_fn_for_output_dir(config),
              output_dir=_MODIR_DIR))

  def test_no_schedule_and_no_config_runs_train_and_evaluate(self):
    self.assertEqual(
        "train_and_evaluate-" + _MODIR_DIR,
        learn_runner.run(build_experiment, output_dir=_MODIR_DIR))

  def test_no_schedule_and_non_distributed_runs_train_and_evaluate(self):
    tf_config = {"cluster": build_non_distributed_cluster_spec()}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertEqual(
          "train_and_evaluate-" + _MODIR_DIR,
          learn_runner.run(
              build_experiment_fn_for_output_dir(config),
              output_dir=_MODIR_DIR))

  def test_fail_task_type_with_no_default_schedule(self):
    tf_config = {
        "cluster": build_distributed_cluster_spec(),
        "task": {
            "type": "foo_has_no_default_schedule"
        }
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

      def create_experiment_fn(output_dir):
        del output_dir  # unused.
        return TestExperiment(config=config)

      self.assertRaisesRegexp(ValueError,
                              "No default schedule",
                              learn_runner.run,
                              create_experiment_fn,
                              _MODIR_DIR)

  def test_fail_schedule_from_config_with_no_task_type(self):
    # The cluster is distributed but TF_CONFIG carries no "task" entry.
    tf_config = {"cluster": build_distributed_cluster_spec()}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertRaisesRegexp(
          ValueError,
          "Must specify a schedule",
          learn_runner.run,
          lambda output_dir: TestExperiment(config=config),
          output_dir=_MODIR_DIR)
418
419
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
422