• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to represent a device."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22
23from tensorflow.python import tf2
24from tensorflow.python.framework import device_spec
25
# Select the DeviceSpec class for the active API surface: DeviceSpecV2 when
# `tf2.enabled()` is true, DeviceSpecV1 otherwise.
DeviceSpec = (
    device_spec.DeviceSpecV2 if tf2.enabled() else device_spec.DeviceSpecV1)
30
31
def check_valid(spec):
  """Validates a device spec string by parsing it.

  Args:
    spec: a device spec string.

  Raises:
    Whatever `DeviceSpec.from_string` raises when `spec` is invalid.
  """
  # Parsing is the validation; the parsed result itself is not needed.
  parse = DeviceSpec.from_string
  parse(spec)
43
44
def is_device_spec(obj):
  """Returns whether `obj` is a DeviceSpec of either API version.

  DeviceSpecV2 is the base class, so one isinstance check against it covers
  DeviceSpecV1 instances as well; callers need not know about that detail.
  """
  spec_base = device_spec.DeviceSpecV2
  return isinstance(obj, spec_base)
48
49
def canonical_name(device):
  """Returns a canonical name for the given `DeviceSpec` or device name.

  Args:
    device: a `DeviceSpec`, a device name string, or None.

  Returns:
    "" when `device` is None; otherwise the `to_string()` of the spec, where
    a non-spec argument is first parsed with `DeviceSpec.from_string`.
  """
  if device is None:
    return ""
  parsed = device if is_device_spec(device) else DeviceSpec.from_string(device)
  return parsed.to_string()
59
60
# Performance caches.
# `_cached_mergers` maps a spec to its MergeDevice wrapper and is only touched
# while holding `_cache_lock`. `_string_merge_cache` maps a
# (wrapped spec, node device string) pair to the merged device string and is
# updated without locking.
_cached_mergers = dict()
_string_merge_cache = dict()
_cache_lock = threading.RLock()
65
66
def merge_device(spec):
  """Returns a device function that merges device specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Mergers are cached by `spec` (used as a dict key, so it must be hashable):
  repeated calls with an equal spec return the same MergeDevice instance.

  Args:
    spec: A `DeviceSpec`, a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block, or an existing MergeDevice.

  Returns:
    A MergeDevice object with the above-described behavior. A MergeDevice
    argument is returned unchanged.

  Raises:
    ValueError: if the spec was not valid.
  """
  # Already wrapped: nothing to do.
  if isinstance(spec, MergeDevice):
    return spec

  with _cache_lock:
    cached = _cached_mergers.get(spec)
    if not cached:
      # Cache miss: build the wrapper once and remember it for this spec.
      cached = MergeDevice(spec)
      _cached_mergers[spec] = cached
    return cached
105
106
class MergeDevice(object):
  """Wraps a device specification (DeviceSpec or str) with merge functionality.

  When called, this class will merge a node_def with its own spec. It also
  exposes a `shortcut_string_merge` method which can significantly improve
  performance of device placement.
  """

  __slots__ = ["_spec"]

  def __init__(self, spec):
    """Stores the device specification this object merges with.

    Args:
      spec: A `DeviceSpecV1` (copied), a `DeviceSpecV2` (stored as-is), or a
        device spec string (parsed with `DeviceSpec.from_string`).
    """
    # DeviceSpecV1 derives from DeviceSpecV2, so the more specific V1 check
    # must come first. With the V2 check first, the snapshot branch below was
    # unreachable and a caller's V1 spec was aliased rather than copied.
    if isinstance(spec, device_spec.DeviceSpecV1):
      # Capture a snapshot of spec so that later mutations by the caller do
      # not change this merger's behavior or the keys it writes into
      # `_string_merge_cache`.
      self._spec = spec.__class__.from_string(spec.to_string())
    elif isinstance(spec, device_spec.DeviceSpecV2):
      self._spec = spec
    else:
      self._spec = DeviceSpec.from_string(spec)

  def __call__(self, node_def):
    """Merges `node_def.device` with the wrapped spec.

    Args:
      node_def: An Operation (or Operation-like) object exposing `device`.

    Returns:
      The result of `self._spec.make_merged_spec(...)` applied to the spec
      parsed from `node_def.device` (an empty/None device parses as "").
    """
    # In general a user may create a device function which takes into account
    # arbitrary properties of an op. (For instance dynamically placing ops based
    # on type.) So even though the standard DeviceSpec route only uses the
    # device attribute, we take an entire node_def to maintain a consistent
    # signature with general device functions.
    current_device = DeviceSpec.from_string(node_def.device or "")
    return self._spec.make_merged_spec(current_device)

  def shortcut_string_merge(self, node_def):
    """Merge a node def without materializing a full DeviceSpec object.

    Often a device merge is invoked in order to generate a string which can be
    passed into the c api. In such a case, we can cache the
      node_def.device  ->  merge_result_string

    map, and in most cases avoid:
      - Materializing a copy of self._spec (In the case of DeviceSpecV1)
      - Materializing a DeviceSpec for node_def.device
      - A DeviceSpec.merge_from invocation

    In practice the cache hit rate for this function is very high, because the
    number of invocations when iterating through the device stack is much
    larger than the number of devices.

    Args:
      node_def: An Operation (or Operation-like) to merge device constraints
        with self._spec

    Returns:
      A string containing the merged device specification.
    """
    device = node_def.device or ""

    merge_key = (self._spec, device)
    result = _string_merge_cache.get(merge_key)
    if result is None:
      # This update is not atomic, however because the merge is stateless
      # we don't need to lock when updating the cache.
      result = self.__call__(node_def).to_string()
      _string_merge_cache[merge_key] = result

    return result

  def __repr__(self):
    return "{} (spec: {})".format(
        super(MergeDevice, self).__repr__(), self._spec.to_string())

  @property
  def is_null_merge(self):
    """Indicate whether the wrapped spec is empty.

    In the degenerate case where self._spec is an empty specification, a caller
    may wish to skip a merge step entirely. (However this class does not have
    enough information to make that determination.)

    Returns:
      A boolean indicating whether a device merge will be trivial.
    """
    return not bool(self._spec.to_string())
186