• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""run_config.py tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import json
23
24from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.python.estimator import run_config as core_run_config
27from tensorflow.python.platform import test
28from tensorflow.python.training import server_lib
29
# Model directories handed to RunConfig by the tests below.
TEST_DIR = "test_dir"
ANOTHER_TEST_DIR = "another_test_dir"
# Value passed as `tf_random_seed` when a deterministic seed is needed.
RANDOM_SEED = 123
# NOTE(review): not referenced by any test visible in this file.
MASTER = "master_"

# Short alias for the mock `patch` helper exposed through the TF test module.
patch = test.mock.patch
36
37
def _create_run_config_with_cluster_spec(tf_config_str):
  """Builds a RunConfig while TF_CONFIG is temporarily set to `tf_config_str`.

  Args:
    tf_config_str: String installed as the TF_CONFIG environment variable only
      while the RunConfig is being constructed (callers pass JSON).

  Returns:
    A `run_config_lib.RunConfig` built with `RANDOM_SEED` and `TEST_DIR`.
  """
  env_override = {"TF_CONFIG": tf_config_str}
  with patch.dict("os.environ", env_override):
    run_config = run_config_lib.RunConfig(
        tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
  # patch.dict has restored os.environ by this point; the config keeps the
  # values it parsed during construction.
  return run_config
42
43
class RunConfigTest(test.TestCase):
  """Tests for the contrib `RunConfig`, chiefly its parsing of TF_CONFIG."""

  @staticmethod
  def _ps_and_worker_cluster():
    """Returns a fresh cluster dict with two PS tasks and three workers."""
    return {
        run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
        run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
    }

  @staticmethod
  def _ps_master_and_worker_cluster():
    """Returns a fresh cluster dict with two PS, one master, three workers."""
    return {
        run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
        run_config_lib.TaskType.MASTER: ["host3:3"],
        run_config_lib.TaskType.WORKER: ["host4:4", "host5:5", "host6:6"]
    }

  def test_instance_of_core_run_config(self):
    config = run_config_lib.RunConfig()
    self.assertIsInstance(config, core_run_config.RunConfig)

  def test_defaults_with_no_tf_config(self):
    config = run_config_lib.RunConfig()
    self.assertEqual(config.master, "")
    self.assertEqual(config.task_id, 0)
    self.assertEqual(config.num_ps_replicas, 0)
    self.assertEqual(config.cluster_spec, {})
    self.assertIsNone(config.task_type)
    self.assertTrue(config.is_chief)
    self.assertEqual(config.evaluation_master, "")

  def test_values_from_tf_config(self):
    tf_config = {
        "cluster": self._ps_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 1
        }
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

    self.assertEqual(config.master, "grpc://host4:4")
    self.assertEqual(config.task_id, 1)
    self.assertEqual(config.num_ps_replicas, 2)
    self.assertEqual(config.num_worker_replicas, 3)
    self.assertEqual(config.cluster_spec.as_dict(), tf_config["cluster"])
    self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER)
    self.assertFalse(config.is_chief)
    self.assertEqual(config.evaluation_master, "")

  def test_explicitly_specified_values(self):
    cluster_spec = {
        run_config_lib.TaskType.PS: ["localhost:9990"],
        "my_job_name": ["localhost:9991", "localhost:9992", "localhost:0"]
    }
    tf_config = {
        "cluster": cluster_spec,
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 2
        }
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master="localhost:0", evaluation_master="localhost:9991")

    self.assertEqual(config.master, "localhost:0")
    self.assertEqual(config.task_id, 2)
    self.assertEqual(config.num_ps_replicas, 1)
    self.assertEqual(config.num_worker_replicas, 0)
    self.assertEqual(config.cluster_spec, server_lib.ClusterSpec(cluster_spec))
    self.assertEqual(config.task_type, run_config_lib.TaskType.WORKER)
    self.assertFalse(config.is_chief)
    self.assertEqual(config.evaluation_master, "localhost:9991")

  def test_single_node_in_cluster_spec_produces_empty_master(self):
    tf_config = {"cluster": {run_config_lib.TaskType.WORKER: ["host1:1"]}}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertEqual(config.master, "")

  def test_no_task_type_produces_empty_master(self):
    tf_config = {
        "cluster": self._ps_and_worker_cluster(),
        # Deliberately omits "task": {"type": "worker"}.
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()
      self.assertEqual(config.master, "")

  def test_invalid_job_name_raises(self):
    tf_config = {
        "cluster": self._ps_and_worker_cluster(),
        "task": {
            "type": "not_in_cluster_spec"
        }
    }
    expected_msg_regexp = "not_in_cluster_spec is not a valid task"
    with patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp(
            ValueError, expected_msg_regexp):
      run_config_lib.RunConfig()

  def test_illegal_task_index_raises(self):
    # The worker job has three tasks (indices 0-2), so index 3 is invalid.
    tf_config = {
        "cluster": self._ps_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 3
        }
    }
    expected_msg_regexp = "3 is not a valid task_id"
    with patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), self.assertRaisesRegexp(
            ValueError, expected_msg_regexp):
      run_config_lib.RunConfig()

  def test_is_chief_from_cloud_tf_config(self):
    # is_chief should be true when ["task"]["type"] == "master",
    # ["task"]["index"] == 0 and the top-level ["environment"] == "cloud".
    # Note that test_values_from_tf_config covers the non-master case.
    tf_config = {
        "cluster": self._ps_master_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.MASTER,
            "index": 0
        },
        "environment": "cloud"
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

    self.assertTrue(config.is_chief)

  def test_is_chief_from_noncloud_tf_config(self):
    # is_chief should be true when ["task"]["type"] == "worker" and
    # ["task"]["index"] == 0 if the top-level ["environment"] != "cloud".
    tf_config = {
        "cluster": self._ps_master_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 0
        },
        "environment": "random"
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

    self.assertTrue(config.is_chief)

    # But task 0 for a job named "master" should not be.
    tf_config = {
        "cluster": self._ps_master_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.MASTER,
            "index": 0
        },
        "environment": "random"
    }
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

    self.assertFalse(config.is_chief)

  def test_default_is_chief_from_tf_config_without_job_name(self):
    tf_config = {"cluster": {}, "task": {}}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      config = run_config_lib.RunConfig()

    self.assertTrue(config.is_chief)

  def test_model_dir(self):
    empty_config = run_config_lib.RunConfig()
    self.assertIsNone(empty_config.model_dir)

    config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    self.assertEqual(TEST_DIR, config.model_dir)

  def test_model_dir_in_tf_config(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      run_config = run_config_lib.RunConfig()
    self.assertEqual(TEST_DIR, run_config.model_dir)

  def test_model_dir_both_in_tf_config_and_constructor(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      run_config = run_config_lib.RunConfig(model_dir=TEST_DIR)
    self.assertEqual(TEST_DIR, run_config.model_dir)

  def test_model_dir_fail_if_constructor_value_mismatch_tf_config(self):
    tf_config = {"model_dir": TEST_DIR}
    with patch.dict("os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
      with self.assertRaisesRegexp(
          ValueError,
          "`model_dir` provided in RunConfig .* must have "
          "the same value .* in TF_CONFIG"):
        run_config_lib.RunConfig(model_dir=TEST_DIR + "/sub_dir")

  def test_replace(self):
    config = run_config_lib.RunConfig(
        tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
    self.assertEqual(TEST_DIR, config.model_dir)
    self.assertEqual(RANDOM_SEED, config.tf_random_seed)

    new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
    self.assertEqual(RANDOM_SEED, new_config.tf_random_seed)
    # replace() must not mutate the original config.
    self.assertEqual(RANDOM_SEED, config.tf_random_seed)

  def test_uid_for_different_configs(self):
    config = run_config_lib.RunConfig(
        tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)

    expected_uid = config.uid()
    # uid() of an unchanged config must be stable across repeated calls.
    for _ in range(10):
      self.assertEqual(expected_uid, config.uid())

    new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    self.assertEqual(TEST_DIR, config.model_dir)
    self.assertNotEqual(expected_uid, new_config.uid())
    self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)

  def test_uid_for_whitelist(self):
    # Properties named in the whitelist are ignored when computing uid().
    whitelist = ["model_dir"]
    config = run_config_lib.RunConfig(
        tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)

    expected_uid = config.uid(whitelist)
    self.assertEqual(expected_uid, config.uid(whitelist))

    new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
    self.assertEqual(TEST_DIR, config.model_dir)
    self.assertEqual(expected_uid, new_config.uid(whitelist))
    self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)

  def test_uid_for_default_whitelist(self):
    config = run_config_lib.RunConfig(
        tf_random_seed=11,
        save_summary_steps=12,
        save_checkpoints_steps=13,
        save_checkpoints_secs=14,
        session_config=config_pb2.ConfigProto(allow_soft_placement=True),
        keep_checkpoint_max=16,
        keep_checkpoint_every_n_hours=17)
    self.assertEqual(11, config.tf_random_seed)
    self.assertEqual(12, config.save_summary_steps)
    self.assertEqual(13, config.save_checkpoints_steps)
    self.assertEqual(14, config.save_checkpoints_secs)
    self.assertEqual(config_pb2.ConfigProto(allow_soft_placement=True),
                     config.session_config)
    self.assertEqual(16, config.keep_checkpoint_max)
    self.assertEqual(17, config.keep_checkpoint_every_n_hours)

    new_config = run_config_lib.RunConfig(
        tf_random_seed=21,
        save_summary_steps=22,
        save_checkpoints_steps=23,
        save_checkpoints_secs=24,
        session_config=config_pb2.ConfigProto(allow_soft_placement=False),
        keep_checkpoint_max=26,
        keep_checkpoint_every_n_hours=27)
    # Every property that differs above is expected to be on the default
    # whitelist, so the default uid() ignores the differences...
    self.assertEqual(config.uid(), new_config.uid())
    # ...while an empty whitelist takes them into account.
    self.assertNotEqual(config.uid(whitelist=[]),
                        new_config.uid(whitelist=[]))
    # model_dir is not on the default whitelist, so changing it changes uid().
    new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
    self.assertNotEqual(config.uid(), new_config.uid())

  def test_uid_for_deepcopy(self):
    tf_config = {
        "cluster": self._ps_and_worker_cluster(),
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 1
        }
    }

    config = _create_run_config_with_cluster_spec(json.dumps(tf_config))
    expected_uid = config.uid()
    self.assertEqual(tf_config["cluster"], config.cluster_spec.as_dict())

    new_config = copy.deepcopy(config)
    self.assertEqual(tf_config["cluster"], new_config.cluster_spec.as_dict())
    self.assertEqual(expected_uid, new_config.uid())

  def test_uid_for_different_cluster_spec_order(self):
    # Raw JSON strings (rather than json.dumps of dicts) pin the job order in
    # TF_CONFIG independently of dict ordering in the running Python version.
    tf_config_1_str = (
        '{"cluster": {"ps": ["host1:1", "host2:2"], '
        '"worker": ["host3:3", "host4:4", "host5:5"]}}')

    tf_config_2_str = (
        '{"cluster": {"worker": ["host3:3", "host4:4", "host5:5"], '
        '"ps": ["host1:1", "host2:2"]}}')

    # Repeated to surface any order-dependent nondeterminism in uid().
    for _ in range(100):
      uid_1 = _create_run_config_with_cluster_spec(tf_config_1_str).uid()
      uid_2 = _create_run_config_with_cluster_spec(tf_config_2_str).uid()
      self.assertEqual(uid_1, uid_2)

  def test_uid_for_different_cluster_specs(self):
    tf_config_1 = {
        "cluster": self._ps_and_worker_cluster(),
    }

    # Same as tf_config_1 but with one fewer PS task.
    tf_config_2 = {
        "cluster": {
            run_config_lib.TaskType.PS: ["host1:1"],
            run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"]
        },
    }

    uid_1 = _create_run_config_with_cluster_spec(json.dumps(tf_config_1)).uid()
    uid_2 = _create_run_config_with_cluster_spec(json.dumps(tf_config_2)).uid()
    self.assertNotEqual(uid_1, uid_2)

  def test_num_worker_replicas_counts_in_master_too(self):
    tf_config = {
        "cluster": {
            run_config_lib.TaskType.PS: ["host1:1", "host2:2"],
            run_config_lib.TaskType.MASTER: ["host6:6"],
            run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"],
        },
        "task": {
            "type": run_config_lib.TaskType.WORKER,
            "index": 1
        }
    }

    config = _create_run_config_with_cluster_spec(json.dumps(tf_config))
    # Three workers plus one master.
    self.assertEqual(config.num_worker_replicas, 4)
395
396
if __name__ == "__main__":
  # Run the tests through TensorFlow's test runner when executed directly.
  test.main()
399