# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple("DispatcherConfig", [
        "port", "protocol", "work_dir", "fault_tolerant_mode",
        "worker_addresses", "job_gc_check_interval_ms", "job_gc_timeout_ms"
    ])):
37  """Configuration class for tf.data service dispatchers.
38
39  Fields:
40    port: Specifies the port to bind to. A value of 0 indicates that the server
41      may bind to any available port.
42    protocol: The protocol to use for communicating with the tf.data service,
43      e.g. "grpc".
44    work_dir: A directory to store dispatcher state in. This
45      argument is required for the dispatcher to be able to recover from
46      restarts.
47    fault_tolerant_mode: Whether the dispatcher should write its state to a
48      journal so that it can recover from restarts. Dispatcher state, including
49      registered datasets and created jobs, is synchronously written to the
50      journal before responding to RPCs. If `True`, `work_dir` must also be
51      specified.
52    worker_addresses: If the job uses auto-sharding, it needs to specify a fixed
53      list of worker addresses that will register with the dispatcher. The
54      worker addresses should be in the format `"host"` or `"host:port"`, where
55      `"port"` is an integer, named port, or `%port%` to match any port.
56    job_gc_check_interval_ms: How often the dispatcher should scan through to
57      delete old and unused jobs, in milliseconds. If not set, the runtime will
58      select a reasonable default. A higher value will reduce load on the
59      dispatcher, while a lower value will reduce the time it takes for the
60      dispatcher to garbage collect expired jobs.
61    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
62      candidate for garbage collection, in milliseconds. A value of -1 indicates
63      that jobs should never be garbage collected. If not set, the runtime will
64      select a reasonable default. A higher value will cause jobs to stay around
65      longer with no consumers. This is useful if there is a large gap in
66      time between when consumers read from the job. A lower value will reduce
67      the time it takes to reclaim the resources from expired jobs.
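
  For example, a dispatcher configured to recover from restarts, with a fixed
  set of auto-sharding workers, might look like the following sketch (the
  work directory and worker addresses here are purely illustrative):

  ```
  dispatcher_config = tf.data.experimental.service.DispatcherConfig(
      port=5050,
      work_dir="gs://my-bucket/dispatcher/work_dir",
      fault_tolerant_mode=True,
      worker_addresses=["worker-0.example.com:%port%",
                        "worker-1.example.com:5051"])
  ```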
68  """
69
70  def __new__(cls,
71              port=0,
72              protocol=None,
73              work_dir=None,
74              fault_tolerant_mode=False,
75              worker_addresses=None,
76              job_gc_check_interval_ms=None,
77              job_gc_timeout_ms=None):
78    if protocol is None:
79      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
80    if job_gc_check_interval_ms is None:
81      job_gc_check_interval_ms = 10 * 60 * 1000  # 10 minutes.
82    if job_gc_timeout_ms is None:
83      job_gc_timeout_ms = 5 * 60 * 1000  # 5 minutes.
84    return super(DispatcherConfig,
85                 cls).__new__(cls, port, protocol, work_dir,
86                              fault_tolerant_mode, worker_addresses,
87                              job_gc_check_interval_ms, job_gc_timeout_ms)
88
89
90@tf_export("data.experimental.service.DispatchServer", v1=[])
91class DispatchServer(object):
92  """An in-process tf.data service dispatch server.
93
94  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
95  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
96  register themselves with the dispatcher.
97
98  >>> dispatcher = tf.data.experimental.service.DispatchServer()
99  >>> dispatcher_address = dispatcher.target.split("://")[1]
100  >>> worker = tf.data.experimental.service.WorkerServer(
101  ...     tf.data.experimental.service.WorkerConfig(
102  ...     dispatcher_address=dispatcher_address))
103  >>> dataset = tf.data.Dataset.range(10)
104  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
105  ...     processing_mode="parallel_epochs", service=dispatcher.target))
106  >>> print(list(dataset.as_numpy_iterator()))
107  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
108
109  When starting a dedicated tf.data dispatch process, use join() to block
110  indefinitely after starting up the server.
111
112  ```
113  dispatcher = tf.data.experimental.service.DispatchServer(
114      tf.data.experimental.service.DispatcherConfig(port=5050))
115  dispatcher.join()
116  ```
117
118  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
119  `fault_tolerant_mode` like below:
120
121  ```
122  dispatcher = tf.data.experimental.service.DispatchServer(
123      tf.data.experimental.service.DispatcherConfig(
124          port=5050,
125          work_dir="gs://my-bucket/dispatcher/work_dir",
126          fault_tolerant_mode=True))
127  ```
128  """
129
130  def __init__(self, config=None, start=True):
131    """Creates a new dispatch server.
132
133    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
        configuration values.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
    """
    config = config or DispatcherConfig()
    if config.fault_tolerant_mode and not config.work_dir:
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work_dir")
    self._config = config
    config_proto = service_config_pb2.DispatcherConfig(
        port=config.port,
        protocol=config.protocol,
        work_dir=config.work_dir,
        fault_tolerant_mode=config.fault_tolerant_mode,
        worker_addresses=config.worker_addresses,
        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
        job_gc_timeout_ms=config.job_gc_timeout_ms)
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()


@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms"
    ])):
233  """Configuration class for tf.data service dispatchers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher.
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
    protocol: (Optional.) Specifies the protocol to be used by the server, e.g.
      "grpc".
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set, the runtime will select a
      reasonable default. A higher value will reduce the load on the dispatcher,
      while a lower value will reduce the time it takes to reclaim resources
      from finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
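
  For example, a worker that heartbeats to its dispatcher once a minute could
  be configured as in the following sketch (the addresses are illustrative):

  ```
  worker_config = tf.data.experimental.service.WorkerConfig(
      dispatcher_address="dispatcher.example.com:5050",
      port=5051,
      heartbeat_interval_ms=60 * 1000)
  ```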
251  """
252
253  def __new__(cls,
254              dispatcher_address,
255              worker_address=None,
256              port=0,
257              protocol=None,
258              heartbeat_interval_ms=None,
259              dispatcher_timeout_ms=None):
260    if worker_address is None:
261      worker_address = "localhost:%port%"
262    if protocol is None:
263      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
264    if heartbeat_interval_ms is None:
265      heartbeat_interval_ms = 30 * 1000  # 30 seconds
266    if dispatcher_timeout_ms is None:
267      dispatcher_timeout_ms = 60 * 60 * 1000  # 1 hour
268
269    return super(WorkerConfig,
270                 cls).__new__(cls, dispatcher_address, worker_address, port,
271                              protocol, heartbeat_interval_ms,
272                              dispatcher_timeout_ms)
273
274
275@tf_export("data.experimental.service.WorkerServer", v1=[])
276class WorkerServer(object):
277  """An in-process tf.data service worker server.
278
279  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
280  processing for user-defined datasets, and provides the resulting elements over
281  RPC. A worker is associated with a single
282  `tf.data.experimental.service.DispatchServer`.
283
284  >>> dispatcher = tf.data.experimental.service.DispatchServer()
285  >>> dispatcher_address = dispatcher.target.split("://")[1]
286  >>> worker = tf.data.experimental.service.WorkerServer(
287  ...     tf.data.experimental.service.WorkerConfig(
288  ...         dispatcher_address=dispatcher_address))
289  >>> dataset = tf.data.Dataset.range(10)
290  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
291  ...     processing_mode="parallel_epochs", service=dispatcher.target))
292  >>> print(list(dataset.as_numpy_iterator()))
293  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
294
295  When starting a dedicated tf.data worker process, use join() to block
296  indefinitely after starting up the server.
297
298  ```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
  worker.join()
  ```
  """

  def __init__(self, config, start=True):
    """Creates a new worker server.

    Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
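
    For example, a worker can be created without starting it right away; this
    is a sketch and the dispatcher address is illustrative:

    ```
    worker = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address="localhost:5050"),
        start=False)
    worker.start()
    ```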
312    """
313    if config.dispatcher_address is None:
314      raise ValueError("must specify a dispatcher_address")
315    if isinstance(config, service_config_pb2.WorkerConfig):
316      config_proto = config
317    else:
318      config_proto = service_config_pb2.WorkerConfig(
319          dispatcher_address=config.dispatcher_address,
320          worker_address=config.worker_address,
321          port=config.port,
322          protocol=config.protocol,
323          heartbeat_interval_ms=config.heartbeat_interval_ms,
324          dispatcher_timeout_ms=config.dispatcher_timeout_ms,
325          data_transfer_protocol=None)
326    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
327        config_proto.SerializeToString())
328    if start:
329      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address="localhost:5050", port=5051))
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()