# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import json
import os
import threading
import numpy as np

# portpicker is an optional dependency: remember the import error and
# re-raise it lazily, only when a port is actually requested.
_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib


# Keep a reference to the real implementation so tests can wrap it while still
# delegating to it (see IndependentWorkerTestBase._make_mock_run_std_server).
original_run_std_server = dc._run_std_server  # pylint: disable=protected-access

# Ports already handed out by pick_unused_port(); guarded by `lock` because
# tests spawn threads that may request ports concurrently.
ASSIGNED_PORTS = set()
lock = threading.Lock()


def pick_unused_port():
  """Returns an unused and unassigned local port.

  Raises:
    ImportError: (deferred) if the optional `portpicker` dependency was not
      importable at module load time.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  global ASSIGNED_PORTS
  with lock:
    while True:
      port = portpicker.pick_unused_port()
      # Ports <= 10000 are rejected -- presumably to stay clear of
      # well-known/low port ranges; TODO confirm the exact rationale.
      if port > 10000 and port not in ASSIGNED_PORTS:
        ASSIGNED_PORTS.add(port)
        logging.info('Using local port %r', port)
        return port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None):
  """Creates and starts local servers and returns the cluster_spec dict.

  Args:
    num_workers: number of `worker` servers to start.
    num_ps: number of `ps` (parameter server) servers to start.
    has_chief: whether to also start a `chief` server.
    has_eval: whether to also start an `evaluator` server.
    protocol: communication protocol passed to each `server_lib.Server`.
    worker_config: optional `config_pb2.ConfigProto` used for the worker,
      chief and evaluator servers.
    ps_config: optional `config_pb2.ConfigProto` used for the ps servers.

  Returns:
    A dict mapping job names to lists of `'localhost:<port>'` strings.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name='worker',
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name='ps',
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name='chief',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  return cluster_dict


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False):
  """Create an in-process cluster that consists of only standard server."""
  # Leave some memory for cuda runtime; the 0.7 budget is split evenly across
  # every GPU-using task (workers, chief, evaluator).
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # Enable collective ops which has no impact on non-collective ops.
  # TODO(yuefengz, tucker): removing this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  # Parameter servers get no GPU devices at all.
  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  # Create in-process servers. Once an in-process tensorflow server is created,
  # there is no way to terminate it. So we create one cluster per test process.
  # We could've started the server in another process, we could then kill that
  # process to terminate the server. The reasons why we don't want multiple
  # processes are
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  # parent process, the child process cannot initialize it again and thus cannot
  # use GPUs (https://stackoverflow.com/questions/22950047).
  return _create_cluster(
      num_workers,
      num_ps=num_ps,
      has_chief=has_chief,
      has_eval=has_eval,
      worker_config=worker_config,
      ps_config=ps_config,
      protocol='grpc')


def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Create a cluster spec with tasks with unused local ports.

  Unlike `create_in_process_cluster`, this only builds the spec dict; it does
  not start any servers.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec


class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi node strategy and dataset."""

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers."""
    cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    # Number of clients that reported success; checked against the thread
    # count in _run_between_graph_clients.
    self._result = 0
    self._lock = threading.Lock()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    # One cached session per thread: between-graph clients run in separate
    # threads and must not share a session object.
    if getattr(self._thread_local, 'cached_session', None) is None:
      self._thread_local.cached_session = session.Session(
          graph=None, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    """Returns a copy of `config` (or a default proto) with optimizations off."""
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      # Deep-copy so the caller's proto is never mutated.
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # gpu ops on cpu
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
                  **kwargs):
    # A client counts as successful only when every element of its result is
    # truthy; the success counter is shared across threads, hence the lock.
    result = client_fn(task_type, task_id, num_gpus, *args, **kwargs)
    if np.all(result):
      with self._lock:
        self._result += 1

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`,
        `num_gpus` and returns True if it succeeds.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    # Only chief and worker tasks run a client in between-graph replication.
    for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    for t in threads:
      t.join()
    # Every client thread must have reported success via _run_client.
    self.assertEqual(self._result, len(threads))


class MockOsEnv(collections.Mapping):
  """A class that allows per-thread TF_CONFIG.

  Behaves like a plain mapping, except that the 'TF_CONFIG' key is stored in
  thread-local storage so every test thread can carry its own task spec while
  sharing all other keys.
  """

  def __init__(self, *args):
    # Shared storage for every key except 'TF_CONFIG'.
    self._dict = dict()
    # Per-thread storage; each thread lazily gets its own `dict` attribute.
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__(*args)

  def get(self, key, default=None):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.get(self._thread_local.dict, key, default)
    else:
      return dict.get(self._dict, key, default)

  def __getitem__(self, key):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__getitem__(self._thread_local.dict, key)
    else:
      return dict.__getitem__(self._dict, key)

  def __setitem__(self, key, val):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__setitem__(self._thread_local.dict, key, val)
    else:
      return dict.__setitem__(self._dict, key, val)

  def __iter__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    # Yields this thread's keys first, then the shared keys.
    for x in self._thread_local.dict:
      yield x
    for x in self._dict:
      yield x

  def __len__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict.__len__() + self._dict.__len__()


class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers."""

  def _make_mock_run_std_server(self):
    """Returns a wrapper of `dc._run_std_server` that synchronizes startup.

    NOTE(review): the wrapper waits on `self._barrier`, which is never
    assigned in this class -- presumably a subclass or the test body must set
    it before the mock runs; verify against callers.
    """
    thread_local = threading.local()

    def _mock_run_std_server(*args, **kwargs):
      ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only call this barrier the first time this function is run for
      # each thread.
      if not getattr(thread_local, 'server_started', False):
        self._barrier.wait()
      thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    # Replace os.environ with a per-thread mock so each task thread can have
    # its own TF_CONFIG (see _task_thread).
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, *args, **kwargs):
    # Exceptions raised by task_fn are captured by the coordinator and
    # re-raised in join_independent_workers.
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Starts `task_fn` in a daemonless thread with its own TF_CONFIG."""
    if task_type:
      tf_config = {
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      }
    else:
      # No task section: the thread acts as a plain client of the cluster.
      tf_config = {
          'cluster': cluster_spec,
      }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    """Runs `task_fn` once per task in `cluster_spec`, each in its own thread.

    Returns:
      A dict mapping task type to the list of started threads for that job.
    """
    # The task_fn should create std_server by itself.
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                     *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    """Joins the given threads, re-raising any exception they recorded."""
    self._coord.join(worker_threads)


def get_tf_config_task():
  """Returns the 'task' section of the TF_CONFIG env var as a dict."""
  return json.loads(os.environ['TF_CONFIG'])['task']


def get_tf_config_cluster_spec():
  """Returns the 'cluster' section of the TF_CONFIG env var as a dict."""
  return json.loads(os.environ['TF_CONFIG'])['cluster']


def get_task_type():
  """Returns this task's type (e.g. 'worker') from TF_CONFIG."""
  return get_tf_config_task()['type']


def get_task_index():
  """Returns this task's index within its job from TF_CONFIG."""
  return get_tf_config_task()['index']


def is_chief():
  """Returns True iff worker 0 acts as chief (no dedicated 'chief' job)."""
  return ('chief' not in get_tf_config_cluster_spec()
          and get_task_type() == 'worker'
          and get_task_index() == 0)