1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Python interface for creating dataset servers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23# pylint: disable=invalid-import-order,g-bad-import-order, unused-import 24from tensorflow.core.protobuf import service_config_pb2 25from tensorflow.python import pywrap_tensorflow 26from tensorflow.python.data.experimental.service import _pywrap_server_lib 27from tensorflow.python.util.tf_export import tf_export 28 29 30@tf_export("data.experimental.service.DispatcherConfig") 31class DispatcherConfig( 32 collections.namedtuple("DispatcherConfig", [ 33 "port", "protocol", "work_dir", "fault_tolerant_mode", 34 "job_gc_check_interval_ms", "job_gc_timeout_ms" 35 ])): 36 """Configuration class for tf.data service dispatchers. 37 38 Fields: 39 port: Specifies the port to bind to. A value of 0 indicates that the server 40 may bind to any available port. 41 protocol: The protocol to use for communicating with the tf.data service. 42 Defaults to `"grpc"`. 43 work_dir: A directory to store dispatcher state in. This 44 argument is required for the dispatcher to be able to recover from 45 restarts. 46 fault_tolerant_mode: Whether the dispatcher should write its state to a 47 journal so that it can recover from restarts. Dispatcher state, including 48 registered datasets and created jobs, is synchronously written to the 49 journal before responding to RPCs. If `True`, `work_dir` must also be 50 specified. 51 job_gc_check_interval_ms: How often the dispatcher should scan through to 52 delete old and unused jobs, in milliseconds. If not set, the runtime will 53 select a reasonable default. A higher value will reduce load on the 54 dispatcher, while a lower value will reduce the time it takes for the 55 dispatcher to garbage collect expired jobs. 56 job_gc_timeout_ms: How long a job needs to be unused before it becomes a 57 candidate for garbage collection, in milliseconds. If not set, the runtime 58 will select a reasonable default. A higher value will cause jobs to stay 59 around longer with no consumers. This is useful if there is a large gap in 60 time between when consumers read from the job. A lower value will reduce 61 the time it takes to reclaim the resources from expired jobs. 62 """ 63 64 def __new__(cls, 65 port=0, 66 protocol="grpc", 67 work_dir=None, 68 fault_tolerant_mode=False, 69 job_gc_check_interval_ms=None, 70 job_gc_timeout_ms=None): 71 if job_gc_check_interval_ms is None: 72 job_gc_check_interval_ms = 10 * 60 * 1000 # 10 minutes. 73 if job_gc_timeout_ms is None: 74 job_gc_timeout_ms = 5 * 60 * 1000 # 5 minutes. 75 return super(DispatcherConfig, 76 cls).__new__(cls, port, protocol, work_dir, 77 fault_tolerant_mode, job_gc_check_interval_ms, 78 job_gc_timeout_ms) 79 80 81@tf_export("data.experimental.service.DispatchServer", v1=[]) 82class DispatchServer(object): 83 """An in-process tf.data service dispatch server. 84 85 A `tf.data.experimental.service.DispatchServer` coordinates a cluster of 86 `tf.data.experimental.service.WorkerServer`s. When the workers start, they 87 register themselves with the dispatcher. 88 89 >>> dispatcher = tf.data.experimental.service.DispatchServer() 90 >>> dispatcher_address = dispatcher.target.split("://")[1] 91 >>> worker = tf.data.experimental.service.WorkerServer( 92 ... tf.data.experimental.service.WorkerConfig( 93 ... dispatcher_address=dispatcher_address)) 94 >>> dataset = tf.data.Dataset.range(10) 95 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 96 ... processing_mode="parallel_epochs", service=dispatcher.target)) 97 >>> print(list(dataset.as_numpy_iterator())) 98 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 99 100 When starting a dedicated tf.data dispatch process, use join() to block 101 indefinitely after starting up the server. 102 103 ``` 104 dispatcher = tf.data.experimental.service.DispatchServer( 105 tf.data.experimental.service.DispatcherConfig(port=5050)) 106 dispatcher.join() 107 ``` 108 109 To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and 110 `fault_tolerant_mode` like below: 111 112 ``` 113 dispatcher = tf.data.experimental.service.DispatchServer( 114 tf.data.experimental.service.DispatcherConfig( 115 port=5050, 116 work_dir="gs://my-bucket/dispatcher/work_dir", 117 fault_tolerant_mode=True)) 118 ``` 119 """ 120 121 def __init__(self, config=None, start=True): 122 """Creates a new dispatch server. 123 124 Args: 125 config: (Optional.) A `tf.data.experimental.service.DispatcherConfig` 126 configration. If `None`, the dispatcher will use default 127 configuration values. 128 start: (Optional.) Boolean, indicating whether to start the server after 129 creating it. Defaults to True. 130 """ 131 config = config or DispatcherConfig() 132 if config.fault_tolerant_mode and not config.work_dir: 133 raise ValueError( 134 "Cannot enable fault tolerant mode without configuring a work_dir") 135 self._config = config 136 config_proto = service_config_pb2.DispatcherConfig( 137 port=config.port, 138 protocol=config.protocol, 139 work_dir=config.work_dir, 140 fault_tolerant_mode=config.fault_tolerant_mode, 141 job_gc_check_interval_ms=config.job_gc_check_interval_ms, 142 job_gc_timeout_ms=config.job_gc_timeout_ms) 143 self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer( 144 config_proto.SerializeToString()) 145 if start: 146 self._server.start() 147 148 def start(self): 149 """Starts this server. 150 151 >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False) 152 >>> dispatcher.start() 153 154 Raises: 155 tf.errors.OpError: Or one of its subclasses if an error occurs while 156 starting the server. 157 """ 158 self._server.start() 159 160 def join(self): 161 """Blocks until the server has shut down. 162 163 This is useful when starting a dedicated dispatch process. 164 165 ``` 166 dispatcher = tf.data.experimental.service.DispatchServer( 167 tf.data.experimental.service.DispatcherConfig(port=5050)) 168 dispatcher.join() 169 ``` 170 171 Raises: 172 tf.errors.OpError: Or one of its subclasses if an error occurs while 173 joining the server. 174 """ 175 self._server.join() 176 177 @property 178 def target(self): 179 """Returns a target that can be used to connect to the server. 180 181 >>> dispatcher = tf.data.experimental.service.DispatchServer() 182 >>> dataset = tf.data.Dataset.range(10) 183 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 184 ... processing_mode="parallel_epochs", service=dispatcher.target)) 185 186 The returned string will be in the form protocol://address, e.g. 187 "grpc://localhost:5050". 188 """ 189 return "{0}://localhost:{1}".format(self._config.protocol, 190 self._server.bound_port()) 191 192 def _stop(self): 193 """Stops the server. 194 195 Raises: 196 tf.errors.OpError: Or one of its subclasses if an error occurs while 197 stopping the server. 198 """ 199 self._server.stop() 200 201 def __del__(self): 202 self._stop() 203 204 @property 205 def _address(self): 206 """Returns the address of the server. 207 208 The returned string will be in the form address:port, e.g. "localhost:1000". 209 """ 210 return "localhost:{0}".format(self._server.bound_port()) 211 212 def _num_workers(self): 213 """Returns the number of workers registered with the dispatcher.""" 214 return self._server.num_workers() 215 216 217@tf_export("data.experimental.service.WorkerConfig") 218class WorkerConfig( 219 collections.namedtuple("WorkerConfig", [ 220 "dispatcher_address", "worker_address", "port", "protocol", 221 "heartbeat_interval_ms", "dispatcher_timeout_ms" 222 ])): 223 """Configuration class for tf.data service dispatchers. 224 225 Fields: 226 dispatcher_address: Specifies the address of the dispatcher. 227 worker_address: Specifies the address of the worker server. This address is 228 passed to the dispatcher so that the dispatcher can tell clients how to 229 connect to this worker. 230 port: Specifies the port to bind to. A value of 0 indicates that the worker 231 can bind to any available port. 232 protocol: (Optional.) Specifies the protocol to be used by the server. 233 Defaults to `"grpc"`. 234 heartbeat_interval_ms: How often the worker should heartbeat to the 235 dispatcher, in milliseconds. If not set, the runtime will select a 236 reasonable default. A higher value will reduce the load on the dispatcher, 237 while a lower value will reduce the time it takes to reclaim resources 238 from finished jobs. 239 dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the 240 dispatcher before giving up and reporting an error. Defaults to 1 hour. 241 """ 242 243 def __new__(cls, 244 dispatcher_address, 245 worker_address=None, 246 port=0, 247 protocol="grpc", 248 heartbeat_interval_ms=None, 249 dispatcher_timeout_ms=None): 250 if worker_address is None: 251 worker_address = "localhost:%port%" 252 if heartbeat_interval_ms is None: 253 heartbeat_interval_ms = 30 * 1000 # 30 seconds 254 if dispatcher_timeout_ms is None: 255 dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour 256 257 return super(WorkerConfig, 258 cls).__new__(cls, dispatcher_address, worker_address, port, 259 protocol, heartbeat_interval_ms, 260 dispatcher_timeout_ms) 261 262 263@tf_export("data.experimental.service.WorkerServer", v1=[]) 264class WorkerServer(object): 265 """An in-process tf.data service worker server. 266 267 A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset` 268 processing for user-defined datasets, and provides the resulting elements over 269 RPC. A worker is associated with a single 270 `tf.data.experimental.service.DispatchServer`. 271 272 >>> dispatcher = tf.data.experimental.service.DispatchServer() 273 >>> dispatcher_address = dispatcher.target.split("://")[1] 274 >>> worker = tf.data.experimental.service.WorkerServer( 275 ... tf.data.experimental.service.WorkerConfig( 276 ... dispatcher_address=dispatcher_address)) 277 >>> dataset = tf.data.Dataset.range(10) 278 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 279 ... processing_mode="parallel_epochs", service=dispatcher.target)) 280 >>> print(list(dataset.as_numpy_iterator())) 281 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 282 283 When starting a dedicated tf.data worker process, use join() to block 284 indefinitely after starting up the server. 285 286 ``` 287 worker = tf.data.experimental.service.WorkerServer( 288 port=5051, dispatcher_address="grpc://localhost:5050") 289 worker.join() 290 ``` 291 """ 292 293 def __init__(self, config, start=True): 294 """Creates a new worker server. 295 296 Args: 297 config: A `tf.data.experimental.service.WorkerConfig` configration. 298 start: (Optional.) Boolean, indicating whether to start the server after 299 creating it. Defaults to True. 300 """ 301 if config.dispatcher_address is None: 302 raise ValueError("must specify a dispatcher_address") 303 self._config = config 304 config_proto = service_config_pb2.WorkerConfig( 305 dispatcher_address=config.dispatcher_address, 306 worker_address=config.worker_address, 307 port=config.port, 308 protocol=config.protocol, 309 heartbeat_interval_ms=config.heartbeat_interval_ms, 310 dispatcher_timeout_ms=config.dispatcher_timeout_ms, 311 data_transfer_protocol=None) 312 self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( 313 config_proto.SerializeToString()) 314 if start: 315 self._server.start() 316 317 def start(self): 318 """Starts this server. 319 320 Raises: 321 tf.errors.OpError: Or one of its subclasses if an error occurs while 322 starting the server. 323 """ 324 self._server.start() 325 326 def join(self): 327 """Blocks until the server has shut down. 328 329 This is useful when starting a dedicated worker process. 330 331 ``` 332 worker_server = tf.data.experimental.service.WorkerServer( 333 port=5051, dispatcher_address="grpc://localhost:5050") 334 worker_server.join() 335 ``` 336 337 This method currently blocks forever. 338 339 Raises: 340 tf.errors.OpError: Or one of its subclasses if an error occurs while 341 joining the server. 342 """ 343 self._server.join() 344 345 def _stop(self): 346 """Stops the server. 347 348 Raises: 349 tf.errors.OpError: Or one of its subclasses if an error occurs while 350 stopping the server. 351 """ 352 self._server.stop() 353 354 def __del__(self): 355 self._stop() 356 357 @property 358 def _address(self): 359 """Returns the address of the server. 360 361 The returned string will be in the form address:port, e.g. "localhost:1000". 362 """ 363 return "localhost:{0}".format(self._server.bound_port()) 364 365 def _num_tasks(self): 366 """Returns the number of tasks currently being executed on the worker.""" 367 return self._server.num_tasks() 368