• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class to represent a device."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.util.tf_export import tf_export
22
23
# Device type names accepted in the short "/<TYPE>[:<index>]" spec form. The
# parser compares against this set after upper-casing the token.
_VALID_DEVICE_TYPES = set(("CPU", "GPU", "TPU"))


# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Module-wide memoization tables shared by every DeviceSpec. Nothing in this
# module evicts entries, so every distinct key stays cached for the life of the
# process.
# Maps a raw spec string (or None) -> tuple of unconverted parsed components.
_STRING_TO_COMPONENTS_CACHE = dict()
# Maps a (job, replica, task, device_type, device_index) tuple -> spec string.
_COMPONENTS_TO_STRING_CACHE = dict()
32
33
def _as_str_or_none(inp):
  """Returns `str(inp)`, passing `None` through unchanged."""
  if inp is not None:
    return str(inp)
  return None
36
37
def _as_int_or_none(inp):
  """Returns `int(inp)`, passing `None` through unchanged.

  Any exception raised by `int()` (e.g. `ValueError` for "abc") propagates.
  """
  if inp is not None:
    return int(inp)
  return None
40
41
def _as_device_str_or_none(device_type):
  """Normalizes a device type to a string, passing `None` through unchanged.

  For backwards compatibility only, the exact lowercase spellings "cpu" and
  "gpu" are upper-cased. Every other non-None value is returned as
  `str(device_type)` without any case change.
  """
  if device_type is None:
    return None
  if device_type in ("cpu", "gpu"):
    return device_type.upper()
  return str(device_type)
48
49
@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need to call .to_string().
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  # The five components, plus two values derived from them in __init__: the
  # canonical string form and its hash.
  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU"). The
        lowercase spellings "cpu" and "gpu" are upper-cased.
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.

    Raises:
      ValueError, TypeError: if `replica`, `task` or `device_index` cannot be
        converted with `int()`.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The fields never change after construction, so the string form and hash
    # are computed once here.
    self._as_string = self._components_to_string(
        job=self._job, replica=self._replica, task=self._task,
        device_type=self._device_type, device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    2.x behavior change:
      In TensorFlow 1.x, this function mutates its own state and returns itself.
      In 2.x, DeviceSpecs are immutable, and this function will return a
        DeviceSpec which contains the spec.

      Recommended:
        ```
        # my_spec and my_updated_spec are unrelated.
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will work in 1.x and 2.x (though deprecated in 2.x):
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = my_spec.parse_from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will NOT work in 2.x:
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
        with tf.device(my_spec):
          ...
        ```

      In general, `DeviceSpec.from_string` should completely replace
      `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
      completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` (of the same class as `self`) which combines `self`
      and `dev`.
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.

    Raises:
      TypeError: if `kwargs` contains a name the constructor does not accept.
    """
    init_kwargs = dict(
        job=self.job, replica=self.replica, task=self.task,
        device_type=self.device_type, device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name as a string, or `None` if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index as an int, or `None` if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index as an int, or `None` if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type string (e.g. "CPU"), or `None` if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index as an int, or `None` if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    Every component that is set (not `None`) on `dev` wins; otherwise the
    component of `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized per raw `spec` (including `None`) in
    `_STRING_TO_COMPONENTS_CACHE`; specs that raise are not cached.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec` as a
      (job, replica, task, device_type, device_index) tuple. `job`, `replica`
      and `task` are the raw substrings; they are NOT converted or validated
      here, so the result must be passed through the `DeviceSpec` constructor
      (or the `DeviceSpecV1` attribute setters) rather than used directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a non-integer device index.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    for y in splits:
      ly = len(y)
      if y:
        # NOTE(taylorrobie): these will go through setters later.
        if ly == 2 and y[0] == "job":
          job = y[1]
        elif ly == 2 and y[0] == "replica":
          replica = y[1]
        elif ly == 2 and y[0] == "task":
          task = y[1]
        elif ((ly == 1 or ly == 2) and (y[0].upper() in _VALID_DEVICE_TYPES)):
          # Short form: "/GPU", "/gpu:1" or "/CPU:*".
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[0].upper()
          if ly == 2 and y[1] != "*":
            device_index = int(y[1])
        elif ly == 3 and y[0] == "device":
          # Long form: "/device:<TYPE>:<index or *>".
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[1]
          if y[2] != "*":
            device_index = int(y[2])
        elif ly and y[0] != "":  # pylint: disable=g-explicit-bool-comparison
          # Empty pieces (from a leading or doubled "/") are ignored.
          raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Components that are `None` are omitted. A device type without an index is
    rendered with the wildcard index "*".
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks whether `other` is a DeviceSpec equal to this instance.

    Two specs are equal when `other` is an instance of this spec's class and
    both produce the same `to_string()` output, i.e. all five components match.

    Args:
      other: Another DeviceSpec

    Returns:
      Return `True` if `other` is also a DeviceSpec instance and has same value
      as the current instance.
      Return `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __ne__(self, other):
    """Returns the negation of `__eq__`.

    Python 2 (still targeted by this module's `__future__` imports) does not
    derive `!=` from `__eq__`. Without this method, `!=` would compare object
    identity there. On Python 3 this matches the default behavior.
    """
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __hash__(self):
    return self._hash
394
395
@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # All storage slots are inherited from DeviceSpecV2. Redeclaring the same
  # names here would create duplicate, shadowing slot descriptors, which the
  # Python data model documents as undefined behavior and which wastes memory.
  __slots__ = ()

  # Each setter normalizes the value exactly like DeviceSpecV2.__init__. It then
  # drops the cached string and hash, which to_string() and __hash__ rebuild
  # lazily.
  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # Recomputed after any mutation invalidated the cached value.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # Rebuild the cached string if a setter invalidated it.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job, replica=self.replica, task=self.task,
          device_type=self.device_type, device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Convert every component before touching `self`. Then an invalid value
    # (e.g. the non-integer replica in "/job:a/replica:x") raises while the
    # spec is still unmodified, instead of leaving it partially updated.
    job, replica, task, device_type, device_index = (
        self._string_to_components(spec))
    job = _as_str_or_none(job)
    replica = _as_int_or_none(replica)
    task = _as_int_or_none(task)
    device_type = _as_device_str_or_none(device_type)
    device_index = _as_int_or_none(device_index)

    # The setters re-apply the (idempotent) conversions and invalidate the
    # cached string/hash.
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = (job, replica, task, device_type, device_index)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Every component that is set (not `None`) on `dev` overwrites the
    corresponding component of this spec in place; the method returns `None`.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__
459