• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Run Config (deprecated, use tf.estimator.RunConfig instead).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27import json
28import os
29
30import six
31
32from tensorflow.contrib.framework.python.framework import experimental
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.python.estimator import run_config as core_run_config
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.training import server_lib
37from tensorflow.python.util.deprecation import deprecated
38
39
# Names of the `RunConfig` properties that the user is allowed to change
# between runs without affecting the execution framework. When the execution
# framework checks the `uid` of a `RunConfig`, these properties are excluded
# from the uid computation by default (see `RunConfig.uid`).
_DEFAULT_UID_WHITE_LIST = [
    'tf_random_seed',
    'save_summary_steps',
    'save_checkpoints_steps',
    'save_checkpoints_secs',
    'session_config',
    'keep_checkpoint_max',
    'keep_checkpoint_every_n_hours',
    'log_step_count_steps',
]
53
54
class Environment(object):
  """DEPRECATED CLASS.

  Runtime environment identifiers, as they appear in the `environment`
  attribute of the `TF_CONFIG` environment variable.
  """
  # Single-machine (local desktop) runs.
  LOCAL = 'local'
  # General distributed training.
  CLOUD = 'cloud'
  # Google-internal distributed training.
  GOOGLE = 'google'
63
64
class TaskType(object):
  """DEPRECATED CLASS.

  Task type identifiers, as they appear as keys of the `cluster` attribute
  of the `TF_CONFIG` environment variable.
  """
  PS = 'ps'
  WORKER = 'worker'
  MASTER = 'master'
70
71
class ClusterConfig(object):
  """This class specifies the configurations for a distributed run.

  THIS CLASS IS DEPRECATED. Use tf.estimator.RunConfig instead.

  If you're using an `Estimator`, you should probably use the subclass
  RunConfig instead.
  """

  def __init__(self, master=None, evaluation_master=None):
    """Constructor.

    Sets the properties `cluster_spec`, `is_chief`, `master` (if `None` in the
    args), `num_ps_replicas`, `task_id`, and `task_type` based on the
    `TF_CONFIG` environment variable, if the pertinent information is
    present. The `TF_CONFIG` environment variable is a JSON object with
    attributes: `cluster`, `environment`, and `task`.

    `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from
    `server_lib.py`, mapping task types (usually one of the TaskType enums) to a
    list of task addresses.

    `environment` specifies the runtime environment for the job (usually one of
    the `Environment` enums). Defaults to `LOCAL`.

    `task` has two attributes: `type` and `index`, where `type` can be any of
    the task types in `cluster`. When `TF_CONFIG` contains said information, the
    following properties are set on this class:

    * `task_type` is set to `TF_CONFIG['task']['type']`. Defaults to `None`.
    * `task_id` is set to `TF_CONFIG['task']['index']`. Defaults to 0.
    * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}.
    * `master` is determined by looking up `task_type` and `task_id` in the
      `cluster_spec`. Defaults to ''.
    * `num_ps_replicas` is set by counting the number of nodes listed
      in the `ps` attribute of `cluster_spec`. Defaults to 0.
    * `num_worker_replicas` is set by counting the number of nodes listed
      in the `worker` attribute of `cluster_spec`. Defaults to 0.
    * `is_chief` is determined based on `task_type`, `task_id`, and
      `environment`.

    Example:
    ```
      cluster = {'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'worker', 'index': 1}})
      config = ClusterConfig()
      assert config.master == 'host4:2222'
      assert config.task_id == 1
      assert config.num_ps_replicas == 2
      assert config.num_worker_replicas == 3
      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
      assert config.task_type == 'worker'
      assert not config.is_chief
    ```

    Args:
      master: TensorFlow master. Defaults to empty string for local.
      evaluation_master: The master on which to perform evaluation.
    """
    # If not explicitly specified in the constructor and the TF_CONFIG
    # environment variable is present, load cluster_spec from TF_CONFIG.
    config = json.loads(os.environ.get('TF_CONFIG') or '{}')

    # Set task_type and task_id if the TF_CONFIG environment variable is
    # present.  Otherwise, use the respective default (None / 0).
    task_env = config.get('task', {})
    self._task_type = task_env.get('type', None)
    # get_task_id() re-parses TF_CONFIG rather than reading `task_env` above,
    # so the instance stays consistent with the public static helper.
    self._task_id = self.get_task_id()

    self._cluster_spec = server_lib.ClusterSpec(config.get('cluster', {}))
    # An explicit `master` argument wins; otherwise derive the master from
    # the cluster spec, falling back to '' for local execution.
    self._master = (master if master is not None else
                    _get_master(self._cluster_spec, self._task_type,
                                self._task_id) or '')
    self._num_ps_replicas = _count_ps(self._cluster_spec) or 0
    self._num_worker_replicas = _count_worker(self._cluster_spec) or 0

    # Set is_chief.
    self._environment = config.get('environment', Environment.LOCAL)
    self._is_chief = None
    if self._task_type is None:
      # No task type in TF_CONFIG: task 0 is the chief.
      self._is_chief = (self._task_id == 0)
    elif self._environment == Environment.CLOUD:
      # When the TF_CONFIG environment variable is set, we can set the
      # default of is_chief to 0 when task_type is "master" and task_id is 0.
      self._is_chief = (self._task_type == TaskType.MASTER and
                        self._task_id == 0)
    else:
      # Legacy behavior is that is_chief is None if task_id == 0.
      self._is_chief = (self._task_type == TaskType.WORKER and
                        self._task_id == 0)

    self._evaluation_master = evaluation_master or ''

  @property
  def cluster_spec(self):
    """The `ClusterSpec` parsed from `TF_CONFIG['cluster']`."""
    return self._cluster_spec

  @property
  def environment(self):
    """The runtime environment string (an `Environment` value)."""
    return self._environment

  @property
  def evaluation_master(self):
    """The master on which to perform evaluation ('' if unset)."""
    return self._evaluation_master

  @property
  def is_chief(self):
    """Whether this task is the chief, per the rules in `__init__`."""
    return self._is_chief

  @property
  def master(self):
    """The TensorFlow master target ('' for local execution)."""
    return self._master

  @property
  def num_ps_replicas(self):
    """Number of parameter-server tasks in the cluster spec."""
    return self._num_ps_replicas

  @property
  def num_worker_replicas(self):
    """Number of worker plus master tasks in the cluster spec."""
    return self._num_worker_replicas

  @property
  def task_id(self):
    """The task index parsed from `TF_CONFIG['task']['index']`."""
    return self._task_id

  @property
  def task_type(self):
    """The task type parsed from `TF_CONFIG['task']['type']` (or None)."""
    return self._task_type

  @staticmethod
  def get_task_id():
    """Returns task index from `TF_CONFIG` environmental variable.

    If you have a ClusterConfig instance, you can just access its task_id
    property instead of calling this function and re-parsing the environmental
    variable.

    Returns:
      `TF_CONFIG['task']['index']`. Defaults to 0.
    """
    config = json.loads(os.environ.get('TF_CONFIG') or '{}')
    task_env = config.get('task', {})
    task_index = task_env.get('index')
    # Any falsy index (missing, '', 0) maps to 0; otherwise coerce to int.
    return int(task_index) if task_index else 0
219
220
class RunConfig(ClusterConfig, core_run_config.RunConfig):
  """This class specifies the configurations for an `Estimator` run.

  This class is a deprecated implementation of `tf.estimator.RunConfig`
  interface.
  """
  # Sentinel meaning "save_checkpoints_secs was not supplied". NOTE(review):
  # because the sentinel is 0, an explicit `save_checkpoints_secs=0` is
  # indistinguishable from the default and triggers the default logic below.
  _USE_DEFAULT = 0

  @deprecated(None, 'When switching to tf.estimator.Estimator, use'
              ' tf.estimator.RunConfig instead.')
  def __init__(self,
               master=None,
               num_cores=0,
               log_device_placement=False,
               gpu_memory_fraction=1,
               tf_random_seed=None,
               save_summary_steps=100,
               save_checkpoints_secs=_USE_DEFAULT,
               save_checkpoints_steps=None,
               keep_checkpoint_max=5,
               keep_checkpoint_every_n_hours=10000,
               log_step_count_steps=100,
               protocol=None,
               evaluation_master='',
               model_dir=None,
               session_config=None):
    """Constructor.

    The superclass `ClusterConfig` may set properties like `cluster_spec`,
    `is_chief`, `master` (if `None` in the args), `num_ps_replicas`, `task_id`,
    and `task_type` based on the `TF_CONFIG` environment variable. See
    `ClusterConfig` for more details.

    N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,
    `keep_checkpoint_max` might need to be adjusted accordingly, especially in
    distributed training. For example, setting `save_checkpoints_secs` as 60
    without adjusting `keep_checkpoint_max` (defaults to 5) leads to situation
    that checkpoint would be garbage collected after 5 minutes. In distributed
    training, the evaluation job starts asynchronously and might fail to load or
    find the checkpoint due to race condition.

    Args:
      master: TensorFlow master. Defaults to empty string for local.
      num_cores: Number of cores to be used. If 0, the system picks an
        appropriate number (default: 0).
      log_device_placement: Log the op placement to devices (default: False).
      gpu_memory_fraction: Fraction of GPU memory used by the process on
        each GPU uniformly on the same machine.
      tf_random_seed: Random seed for TensorFlow initializers.
        Setting this value allows consistency between reruns.
      save_summary_steps: Save summaries every this many steps.
      save_checkpoints_secs: Save checkpoints every this many seconds. Can not
          be specified with `save_checkpoints_steps`.
      save_checkpoints_steps: Save checkpoints every this many steps. Can not be
          specified with `save_checkpoints_secs`.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep. As new files are created, older files are deleted. If None or 0,
        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
        checkpoint files are kept.)
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
      log_step_count_steps: The frequency, in number of global steps, that the
        global step/sec will be logged during training.
      evaluation_master: the master on which to perform evaluation.
      model_dir: directory where model parameters, graph etc are saved. If
        `None`, will use `model_dir` property in `TF_CONFIG` environment
        variable. If both are set, must have same value. If both are `None`, see
        `Estimator` about where the model will be saved.
      session_config: a ConfigProto used to set session parameters, or None.
        Note - using this argument, it is easy to provide settings which break
        otherwise perfectly good models. Use with care.
      protocol: An optional argument which specifies the protocol used when
        starting server. None means default to grpc.
    """
    # Neither parent class calls super().__init__(), so here we have to
    # manually call their __init__() methods.
    ClusterConfig.__init__(
        self, master=master, evaluation_master=evaluation_master)
    # For too long this code didn't call:
    #   core_run_config.RunConfig.__init__(self)
    # so instead of breaking compatibility with that assumption, we
    # just manually initialize this field:
    self._train_distribute = None
    self._eval_distribute = None
    self._device_fn = None

    gpu_options = config_pb2.GPUOptions(
        per_process_gpu_memory_fraction=gpu_memory_fraction)
    self._tf_config = config_pb2.ConfigProto(
        log_device_placement=log_device_placement,
        inter_op_parallelism_threads=num_cores,
        intra_op_parallelism_threads=num_cores,
        gpu_options=gpu_options)

    self._tf_random_seed = tf_random_seed
    self._save_summary_steps = save_summary_steps
    self._save_checkpoints_secs = save_checkpoints_secs
    self._log_step_count_steps = log_step_count_steps
    self._protocol = protocol
    self._session_config = session_config
    # Resolve the _USE_DEFAULT sentinel: default to saving every 600 seconds
    # unless the user opted into step-based checkpointing instead.
    if save_checkpoints_secs == RunConfig._USE_DEFAULT:
      if save_checkpoints_steps is None:
        self._save_checkpoints_secs = 600
      else:
        self._save_checkpoints_secs = None
    self._save_checkpoints_steps = save_checkpoints_steps

    # TODO(weiho): Remove these after ModelFn refactoring, when users can
    # create Scaffold and Saver in their model_fn to set these.
    self._keep_checkpoint_max = keep_checkpoint_max
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._model_dir = _get_model_dir(model_dir)

  @experimental
  def uid(self, whitelist=None):
    """Generates a 'Unique Identifier' based on all internal fields.

    Caller should use the uid string to check `RunConfig` instance integrity
    in one session use, but should not rely on the implementation details, which
    is subject to change.

    Args:
      whitelist: A list of the string names of the properties uid should not
        include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
        includes most properties user allows to change.

    Returns:
      A uid string.
    """
    if whitelist is None:
      whitelist = _DEFAULT_UID_WHITE_LIST

    state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
    # Pop out the keys in whitelist.
    for k in whitelist:
      state.pop('_' + k, None)

    ordered_state = collections.OrderedDict(
        sorted(state.items(), key=lambda t: t[0]))
    # For class instance without __repr__, some special cares are required.
    # Otherwise, the object address will be used.
    if '_cluster_spec' in ordered_state:
      ordered_state['_cluster_spec'] = collections.OrderedDict(
          sorted(ordered_state['_cluster_spec'].as_dict().items(),
                 key=lambda t: t[0]))
    return ', '.join(
        '%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state))

  @property
  def model_dir(self):
    """Directory where model parameters, graph, etc. are saved."""
    return self._model_dir

  @property
  def tf_config(self):
    """The `ConfigProto` built from the constructor arguments."""
    return self._tf_config

  @property
  def tf_random_seed(self):
    """Random seed for TensorFlow initializers (or None)."""
    return self._tf_random_seed

  @property
  def save_summary_steps(self):
    """Save summaries every this many steps."""
    return self._save_summary_steps

  @property
  def save_checkpoints_secs(self):
    """Save checkpoints every this many seconds (or None)."""
    return self._save_checkpoints_secs

  @property
  def save_checkpoints_steps(self):
    """Save checkpoints every this many steps (or None)."""
    return self._save_checkpoints_steps

  @property
  def session_config(self):
    """The user-supplied session `ConfigProto`, or None."""
    return self._session_config

  @property
  def keep_checkpoint_max(self):
    """Maximum number of recent checkpoint files to keep."""
    return self._keep_checkpoint_max

  @property
  def keep_checkpoint_every_n_hours(self):
    """Number of hours between each kept checkpoint."""
    return self._keep_checkpoint_every_n_hours

  @property
  def log_step_count_steps(self):
    """Frequency, in global steps, of step/sec logging."""
    return self._log_step_count_steps
409
410
411def _count_ps(cluster_spec):
412  """Counts the number of parameter servers in cluster_spec."""
413  return len(cluster_spec.as_dict().get('ps', [])) if cluster_spec else 0
414
415
416def _count_worker(cluster_spec):
417  """Counts the number of workers in cluster_spec.
418
419  Workers with TaskType.WORKER and TaskType.MASTER are included in the return
420  value.
421
422  Args:
423    cluster_spec: a ClusterSpec instance that describes current deployment.
424
425  Returns:
426    The total number of eligible workers.
427
428    If 'cluster_spec' was None, then 0 is returned.
429  """
430  return (len(cluster_spec.as_dict().get('worker', [])) +
431          len(cluster_spec.as_dict().get('master', []))) if cluster_spec else 0
432
433
434def _get_master(cluster_spec, task_type, task_id):
435  """Returns the appropriate string for the TensorFlow master."""
436  if not cluster_spec:
437    return ''
438
439  # If there is only one node in the cluster, do things locally.
440  jobs = cluster_spec.jobs
441  if len(jobs) == 1 and len(cluster_spec.job_tasks(jobs[0])) == 1:
442    return ''
443
444  # Lookup the master in cluster_spec using task_type and task_id,
445  # if possible.
446  if task_type:
447    if task_type not in jobs:
448      raise ValueError(
449          '%s is not a valid task_type in the cluster_spec:\n'
450          '%s\n\n'
451          'Note that these values may be coming from the TF_CONFIG environment '
452          'variable.' % (task_type, cluster_spec))
453    addresses = cluster_spec.job_tasks(task_type)
454    if task_id >= len(addresses) or task_id < 0:
455      raise ValueError(
456          '%d is not a valid task_id for task_type %s in the '
457          'cluster_spec:\n'
458          '%s\n\n'
459          'Note that these value may be coming from the TF_CONFIG environment '
460          'variable.' % (task_id, task_type, cluster_spec))
461    return 'grpc://' + addresses[task_id]
462
463  # For backwards compatibility, we return empty string if task_type was
464  # not set (task_type did not previously exist).
465  return ''
466
467
468def _get_model_dir(model_dir):
469  """Returns `model_dir` based user provided `model_dir` or `TF_CONFIG`."""
470
471  model_dir_in_tf_config = json.loads(
472      os.environ.get('TF_CONFIG') or '{}').get('model_dir', None)
473  if model_dir_in_tf_config is not None:
474    if model_dir is not None and model_dir_in_tf_config != model_dir:
475      raise ValueError(
476          '`model_dir` provided in RunConfig construct, if set, '
477          'must have the same value as the model_dir in TF_CONFIG. '
478          'model_dir: {}\nTF_CONFIG["model_dir"]: {}.\n'.format(
479              model_dir, model_dir_in_tf_config))
480
481    logging.info('Using model_dir in TF_CONFIG: %s', model_dir_in_tf_config)
482
483  return model_dir or model_dir_in_tf_config
484