# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import json
import os
import threading

import numpy as np

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib


original_run_std_server = dc._run_std_server  # pylint: disable=protected-access

ASSIGNED_PORTS = set()
lock = threading.Lock()


def pick_unused_port():
  """Returns an unused and unassigned local port."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  global ASSIGNED_PORTS
  with lock:
    while True:
      port = portpicker.pick_unused_port()
      if port > 10000 and port not in ASSIGNED_PORTS:
        ASSIGNED_PORTS.add(port)
        logging.info('Using local port %r', port)
        return port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None):
  """Creates and starts local servers and returns the cluster_spec dict."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name='worker',
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name='ps',
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name='chief',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  return cluster_dict

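# Illustrative example (not part of the original module): calling
# _create_cluster(num_workers=2, num_ps=1, has_chief=True) starts four
# in-process servers and returns a cluster_dict shaped like
#
#   {'worker': ['localhost:12345', 'localhost:12346'],
#    'ps': ['localhost:12347'],
#    'chief': ['localhost:12348']}
#
# where the ports shown are placeholders for whatever pick_unused_port()
# returned.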

def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False):
  """Creates an in-process cluster that consists only of standard servers."""
  # Leave some memory for the CUDA runtime.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # Enable collective ops, which have no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # the collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could instead start each server in a separate process and kill that
  # process to terminate the server, but we avoid multiple processes because:
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  #    parent process, a child process cannot initialize it again and thus
  #    cannot use GPUs (https://stackoverflow.com/questions/22950047).
  return _create_cluster(
      num_workers,
      num_ps=num_ps,
      has_chief=has_chief,
      has_eval=has_eval,
      worker_config=worker_config,
      ps_config=ps_config,
      protocol='grpc')

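# Worked example (illustrative): create_in_process_cluster(num_workers=2,
# num_ps=0, has_chief=True) gives gpu_mem_frac = 0.7 / (2 + 1 + 0) ~= 0.23, so
# each of the three GPU-visible servers may claim roughly 23% of GPU memory,
# leaving the remaining ~30% for the CUDA runtime.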

def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Creates a cluster spec whose tasks are assigned unused local ports."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec

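# Illustrative note (not part of the original module): unlike _create_cluster,
# create_cluster_spec only reserves ports and builds the dict; no servers are
# started. For example, create_cluster_spec(has_chief=True, num_workers=2)
# returns something like
#
#   {'chief': ['localhost:12345'],
#    'worker': ['localhost:12346', 'localhost:12347']}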

class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi-node strategies and datasets."""

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers."""
    cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # The cached session is per thread and is reset for every test, since
    # another test may use a different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._result = 0
    self._lock = threading.Lock()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the master target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the master target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      # Pass `graph` through so the first caller's graph is honored, as the
      # docstring promises; subsequent calls reuse the cached session.
      self._thread_local.cached_session = session.Session(
          graph=graph, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so that we don't inadvertently
    # run GPU ops on the CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
                  **kwargs):
    result = client_fn(task_type, task_id, num_gpus, *args, **kwargs)
    if np.all(result):
      with self._lock:
        self._result += 1

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that accepts `task_type`, `task_id` and `num_gpus`,
        and returns True if it succeeds.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    for t in threads:
      t.join()
    self.assertEqual(self._result, len(threads))

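# Illustrative sketch (not part of the original module) of how a test might use
# MultiWorkerTestBase; `_fn` and `some_op` are hypothetical names:
#
#   class ExampleMultiWorkerTest(MultiWorkerTestBase):
#
#     def test_between_graph(self):
#
#       def _fn(task_type, task_id, num_gpus):
#         # Build the per-client graph here; cached_session() is thread-local,
#         # so each client thread gets its own session.
#         with self.cached_session() as sess:
#           return sess.run(some_op)
#
#       self._run_between_graph_clients(_fn, self._cluster_spec, num_gpus=0)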

class MockOsEnv(collections.Mapping):
  """A class that allows per-thread TF_CONFIG."""

  def __init__(self, *args):
    self._dict = dict()
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__(*args)

  def get(self, key, default=None):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.get(self._thread_local.dict, key, default)
    else:
      return dict.get(self._dict, key, default)

  def __getitem__(self, key):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__getitem__(self._thread_local.dict, key)
    else:
      return dict.__getitem__(self._dict, key)

  def __setitem__(self, key, val):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__setitem__(self._thread_local.dict, key, val)
    else:
      return dict.__setitem__(self._dict, key, val)

  def __iter__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    for x in self._thread_local.dict:
      yield x
    for x in self._dict:
      yield x

  def __len__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict.__len__() + self._dict.__len__()

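# Illustrative note (not part of the original module): MockOsEnv keeps only the
# 'TF_CONFIG' key in thread-local storage, so when it is patched over
# os.environ each simulated task thread sees its own TF_CONFIG while all other
# environment variables remain shared across threads:
#
#   env = MockOsEnv()
#   env['PATH'] = '/usr/bin'    # visible to every thread
#   env['TF_CONFIG'] = '{...}'  # visible only to the thread that set it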

class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers."""

  def _make_mock_run_std_server(self):
    thread_local = threading.local()

    def _mock_run_std_server(*args, **kwargs):
      ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only enter this barrier the first time this function is run
      # for each thread. Note that `self._barrier` is not defined in this base
      # class; it is expected to be set by the test before the mock is used.
      if not getattr(thread_local, 'server_started', False):
        self._barrier.wait()
      thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, *args, **kwargs):
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    if task_type:
      tf_config = {
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      }
    else:
      tf_config = {
          'cluster': cluster_spec,
      }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    # The task_fn is expected to create the std server by itself.
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                     *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    self._coord.join(worker_threads)

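# Illustrative sketch (not part of the original module): a test built on
# IndependentWorkerTestBase might run the following, where `_task_fn` is a
# hypothetical function that starts a std server and executes the worker code:
#
#   cluster_spec = create_cluster_spec(num_workers=2)
#   threads = self.run_multiple_tasks_in_threads(_task_fn, cluster_spec)
#   self.join_independent_workers(threads['worker'])
#
# Each task thread then sees a per-thread TF_CONFIG such as
#
#   {"cluster": {"worker": ["localhost:12345", "localhost:12346"]},
#    "task": {"type": "worker", "index": 0}}
#
# which the helpers below read back out of os.environ.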

def get_tf_config_task():
  return json.loads(os.environ['TF_CONFIG'])['task']


def get_tf_config_cluster_spec():
  return json.loads(os.environ['TF_CONFIG'])['cluster']


def get_task_type():
  return get_tf_config_task()['type']


def get_task_index():
  return get_tf_config_task()['index']


def is_chief():
  return ('chief' not in get_tf_config_cluster_spec()
          and get_task_type() == 'worker'
          and get_task_index() == 0)
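# Illustrative note (not part of the original module): is_chief() encodes the
# convention that when a cluster has no dedicated 'chief' job, worker 0 acts as
# the chief. With the example TF_CONFIG above, the thread running worker 0 gets
# is_chief() == True; in a cluster that does define a 'chief' job, is_chief()
# returns False for every task, including the chief itself.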
454