1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Python interface for creating dataset servers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23# pylint: disable=invalid-import-order,g-bad-import-order, unused-import 24from tensorflow.core.protobuf import service_config_pb2 25from tensorflow.python import pywrap_tensorflow 26from tensorflow.python.data.experimental.service import _pywrap_server_lib 27from tensorflow.python.data.experimental.service import _pywrap_utils 28from tensorflow.python.util.tf_export import tf_export 29 30 31@tf_export("data.experimental.service.DispatcherConfig") 32class DispatcherConfig( 33 collections.namedtuple("DispatcherConfig", [ 34 "port", "protocol", "work_dir", "fault_tolerant_mode", 35 "worker_addresses", "job_gc_check_interval_ms", "job_gc_timeout_ms" 36 ])): 37 """Configuration class for tf.data service dispatchers. 38 39 Fields: 40 port: Specifies the port to bind to. A value of 0 indicates that the server 41 may bind to any available port. 42 protocol: The protocol to use for communicating with the tf.data service, 43 e.g. "grpc". 44 work_dir: A directory to store dispatcher state in. This 45 argument is required for the dispatcher to be able to recover from 46 restarts. 47 fault_tolerant_mode: Whether the dispatcher should write its state to a 48 journal so that it can recover from restarts. Dispatcher state, including 49 registered datasets and created jobs, is synchronously written to the 50 journal before responding to RPCs. If `True`, `work_dir` must also be 51 specified. 52 worker_addresses: If the job uses auto-sharding, it needs to specify a fixed 53 list of worker addresses that will register with the dispatcher. The 54 worker addresses should be in the format `"host"` or `"host:port"`, where 55 `"port"` is an integer, named port, or `%port%` to match any port. 56 job_gc_check_interval_ms: How often the dispatcher should scan through to 57 delete old and unused jobs, in milliseconds. If not set, the runtime will 58 select a reasonable default. A higher value will reduce load on the 59 dispatcher, while a lower value will reduce the time it takes for the 60 dispatcher to garbage collect expired jobs. 61 job_gc_timeout_ms: How long a job needs to be unused before it becomes a 62 candidate for garbage collection, in milliseconds. A value of -1 indicates 63 that jobs should never be garbage collected. If not set, the runtime will 64 select a reasonable default. A higher value will cause jobs to stay around 65 longer with no consumers. This is useful if there is a large gap in 66 time between when consumers read from the job. A lower value will reduce 67 the time it takes to reclaim the resources from expired jobs. 68 """ 69 70 def __new__(cls, 71 port=0, 72 protocol=None, 73 work_dir=None, 74 fault_tolerant_mode=False, 75 worker_addresses=None, 76 job_gc_check_interval_ms=None, 77 job_gc_timeout_ms=None): 78 if protocol is None: 79 protocol = _pywrap_utils.TF_DATA_DefaultProtocol() 80 if job_gc_check_interval_ms is None: 81 job_gc_check_interval_ms = 10 * 60 * 1000 # 10 minutes. 82 if job_gc_timeout_ms is None: 83 job_gc_timeout_ms = 5 * 60 * 1000 # 5 minutes. 84 return super(DispatcherConfig, 85 cls).__new__(cls, port, protocol, work_dir, 86 fault_tolerant_mode, worker_addresses, 87 job_gc_check_interval_ms, job_gc_timeout_ms) 88 89 90@tf_export("data.experimental.service.DispatchServer", v1=[]) 91class DispatchServer(object): 92 """An in-process tf.data service dispatch server. 93 94 A `tf.data.experimental.service.DispatchServer` coordinates a cluster of 95 `tf.data.experimental.service.WorkerServer`s. When the workers start, they 96 register themselves with the dispatcher. 97 98 >>> dispatcher = tf.data.experimental.service.DispatchServer() 99 >>> dispatcher_address = dispatcher.target.split("://")[1] 100 >>> worker = tf.data.experimental.service.WorkerServer( 101 ... tf.data.experimental.service.WorkerConfig( 102 ... dispatcher_address=dispatcher_address)) 103 >>> dataset = tf.data.Dataset.range(10) 104 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 105 ... processing_mode="parallel_epochs", service=dispatcher.target)) 106 >>> print(list(dataset.as_numpy_iterator())) 107 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 108 109 When starting a dedicated tf.data dispatch process, use join() to block 110 indefinitely after starting up the server. 111 112 ``` 113 dispatcher = tf.data.experimental.service.DispatchServer( 114 tf.data.experimental.service.DispatcherConfig(port=5050)) 115 dispatcher.join() 116 ``` 117 118 To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and 119 `fault_tolerant_mode` like below: 120 121 ``` 122 dispatcher = tf.data.experimental.service.DispatchServer( 123 tf.data.experimental.service.DispatcherConfig( 124 port=5050, 125 work_dir="gs://my-bucket/dispatcher/work_dir", 126 fault_tolerant_mode=True)) 127 ``` 128 """ 129 130 def __init__(self, config=None, start=True): 131 """Creates a new dispatch server. 132 133 Args: 134 config: (Optional.) A `tf.data.experimental.service.DispatcherConfig` 135 configration. If `None`, the dispatcher will use default 136 configuration values. 137 start: (Optional.) Boolean, indicating whether to start the server after 138 creating it. Defaults to True. 139 """ 140 config = config or DispatcherConfig() 141 if config.fault_tolerant_mode and not config.work_dir: 142 raise ValueError( 143 "Cannot enable fault tolerant mode without configuring a work_dir") 144 self._config = config 145 config_proto = service_config_pb2.DispatcherConfig( 146 port=config.port, 147 protocol=config.protocol, 148 work_dir=config.work_dir, 149 fault_tolerant_mode=config.fault_tolerant_mode, 150 worker_addresses=config.worker_addresses, 151 job_gc_check_interval_ms=config.job_gc_check_interval_ms, 152 job_gc_timeout_ms=config.job_gc_timeout_ms) 153 self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer( 154 config_proto.SerializeToString()) 155 if start: 156 self._server.start() 157 158 def start(self): 159 """Starts this server. 160 161 >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False) 162 >>> dispatcher.start() 163 164 Raises: 165 tf.errors.OpError: Or one of its subclasses if an error occurs while 166 starting the server. 167 """ 168 self._server.start() 169 170 def join(self): 171 """Blocks until the server has shut down. 172 173 This is useful when starting a dedicated dispatch process. 174 175 ``` 176 dispatcher = tf.data.experimental.service.DispatchServer( 177 tf.data.experimental.service.DispatcherConfig(port=5050)) 178 dispatcher.join() 179 ``` 180 181 Raises: 182 tf.errors.OpError: Or one of its subclasses if an error occurs while 183 joining the server. 184 """ 185 self._server.join() 186 187 @property 188 def target(self): 189 """Returns a target that can be used to connect to the server. 190 191 >>> dispatcher = tf.data.experimental.service.DispatchServer() 192 >>> dataset = tf.data.Dataset.range(10) 193 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 194 ... processing_mode="parallel_epochs", service=dispatcher.target)) 195 196 The returned string will be in the form protocol://address, e.g. 197 "grpc://localhost:5050". 198 """ 199 return "{0}://localhost:{1}".format(self._config.protocol, 200 self._server.bound_port()) 201 202 def _stop(self): 203 """Stops the server. 204 205 Raises: 206 tf.errors.OpError: Or one of its subclasses if an error occurs while 207 stopping the server. 208 """ 209 self._server.stop() 210 211 def __del__(self): 212 self._stop() 213 214 @property 215 def _address(self): 216 """Returns the address of the server. 217 218 The returned string will be in the form address:port, e.g. "localhost:1000". 219 """ 220 return "localhost:{0}".format(self._server.bound_port()) 221 222 def _num_workers(self): 223 """Returns the number of workers registered with the dispatcher.""" 224 return self._server.num_workers() 225 226 227@tf_export("data.experimental.service.WorkerConfig") 228class WorkerConfig( 229 collections.namedtuple("WorkerConfig", [ 230 "dispatcher_address", "worker_address", "port", "protocol", 231 "heartbeat_interval_ms", "dispatcher_timeout_ms" 232 ])): 233 """Configuration class for tf.data service dispatchers. 234 235 Fields: 236 dispatcher_address: Specifies the address of the dispatcher. 237 worker_address: Specifies the address of the worker server. This address is 238 passed to the dispatcher so that the dispatcher can tell clients how to 239 connect to this worker. 240 port: Specifies the port to bind to. A value of 0 indicates that the worker 241 can bind to any available port. 242 protocol: (Optional.) Specifies the protocol to be used by the server, e.g. 243 "grpc". 244 heartbeat_interval_ms: How often the worker should heartbeat to the 245 dispatcher, in milliseconds. If not set, the runtime will select a 246 reasonable default. A higher value will reduce the load on the dispatcher, 247 while a lower value will reduce the time it takes to reclaim resources 248 from finished jobs. 249 dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the 250 dispatcher before giving up and reporting an error. Defaults to 1 hour. 251 """ 252 253 def __new__(cls, 254 dispatcher_address, 255 worker_address=None, 256 port=0, 257 protocol=None, 258 heartbeat_interval_ms=None, 259 dispatcher_timeout_ms=None): 260 if worker_address is None: 261 worker_address = "localhost:%port%" 262 if protocol is None: 263 protocol = _pywrap_utils.TF_DATA_DefaultProtocol() 264 if heartbeat_interval_ms is None: 265 heartbeat_interval_ms = 30 * 1000 # 30 seconds 266 if dispatcher_timeout_ms is None: 267 dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour 268 269 return super(WorkerConfig, 270 cls).__new__(cls, dispatcher_address, worker_address, port, 271 protocol, heartbeat_interval_ms, 272 dispatcher_timeout_ms) 273 274 275@tf_export("data.experimental.service.WorkerServer", v1=[]) 276class WorkerServer(object): 277 """An in-process tf.data service worker server. 278 279 A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset` 280 processing for user-defined datasets, and provides the resulting elements over 281 RPC. A worker is associated with a single 282 `tf.data.experimental.service.DispatchServer`. 283 284 >>> dispatcher = tf.data.experimental.service.DispatchServer() 285 >>> dispatcher_address = dispatcher.target.split("://")[1] 286 >>> worker = tf.data.experimental.service.WorkerServer( 287 ... tf.data.experimental.service.WorkerConfig( 288 ... dispatcher_address=dispatcher_address)) 289 >>> dataset = tf.data.Dataset.range(10) 290 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 291 ... processing_mode="parallel_epochs", service=dispatcher.target)) 292 >>> print(list(dataset.as_numpy_iterator())) 293 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 294 295 When starting a dedicated tf.data worker process, use join() to block 296 indefinitely after starting up the server. 297 298 ``` 299 worker = tf.data.experimental.service.WorkerServer( 300 port=5051, dispatcher_address="localhost:5050") 301 worker.join() 302 ``` 303 """ 304 305 def __init__(self, config, start=True): 306 """Creates a new worker server. 307 308 Args: 309 config: A `tf.data.experimental.service.WorkerConfig` configration. 310 start: (Optional.) Boolean, indicating whether to start the server after 311 creating it. Defaults to True. 312 """ 313 if config.dispatcher_address is None: 314 raise ValueError("must specify a dispatcher_address") 315 if isinstance(config, service_config_pb2.WorkerConfig): 316 config_proto = config 317 else: 318 config_proto = service_config_pb2.WorkerConfig( 319 dispatcher_address=config.dispatcher_address, 320 worker_address=config.worker_address, 321 port=config.port, 322 protocol=config.protocol, 323 heartbeat_interval_ms=config.heartbeat_interval_ms, 324 dispatcher_timeout_ms=config.dispatcher_timeout_ms, 325 data_transfer_protocol=None) 326 self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( 327 config_proto.SerializeToString()) 328 if start: 329 self._server.start() 330 331 def start(self): 332 """Starts this server. 333 334 Raises: 335 tf.errors.OpError: Or one of its subclasses if an error occurs while 336 starting the server. 337 """ 338 self._server.start() 339 340 def join(self): 341 """Blocks until the server has shut down. 342 343 This is useful when starting a dedicated worker process. 344 345 ``` 346 worker_server = tf.data.experimental.service.WorkerServer( 347 port=5051, dispatcher_address="localhost:5050") 348 worker_server.join() 349 ``` 350 351 This method currently blocks forever. 352 353 Raises: 354 tf.errors.OpError: Or one of its subclasses if an error occurs while 355 joining the server. 356 """ 357 self._server.join() 358 359 def _stop(self): 360 """Stops the server. 361 362 Raises: 363 tf.errors.OpError: Or one of its subclasses if an error occurs while 364 stopping the server. 365 """ 366 self._server.stop() 367 368 def __del__(self): 369 self._stop() 370 371 @property 372 def _address(self): 373 """Returns the address of the server. 374 375 The returned string will be in the form address:port, e.g. "localhost:1000". 376 """ 377 return "localhost:{0}".format(self._server.bound_port()) 378 379 def _num_tasks(self): 380 """Returns the number of tasks currently being executed on the worker.""" 381 return self._server.num_tasks() 382