• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to represent a device."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22
23from tensorflow.python import tf2
24from tensorflow.python.framework import device_spec
25
# Pick the DeviceSpec implementation once, at import time: the V2 class when
# tf2.enabled() reports true, the V1 class otherwise.
DeviceSpec = (
    device_spec.DeviceSpecV2 if tf2.enabled() else device_spec.DeviceSpecV1)
30
31
def check_valid(spec):
  """Validates a device spec string by attempting to parse it.

  Args:
    spec: The device spec string to validate.

  Raises:
    Any exception that `DeviceSpec.from_string` raises while parsing `spec`.
  """
  # Parsing is the whole validation; the resulting DeviceSpec is deliberately
  # thrown away.
  _ = DeviceSpec.from_string(spec)
43
44
def is_device_spec(obj):
  """Returns True if `obj` is a device spec object.

  The check is made against `device_spec.DeviceSpecV2`, the base class, so
  callers do not have to name that class themselves.

  Args:
    obj: Any object.

  Returns:
    A boolean: whether `obj` is an instance of `device_spec.DeviceSpecV2`.
  """
  spec_base_class = device_spec.DeviceSpecV2
  return isinstance(obj, spec_base_class)
48
49
def canonical_name(device):
  """Returns a canonical name for the given `DeviceSpec` or device name.

  Args:
    device: A device spec object, a device name string, or None.

  Returns:
    The empty string when `device` is None; otherwise the `to_string()` form
    of the spec, parsing `device` with `DeviceSpec.from_string` first when it
    is not already a device spec object.
  """
  if device is None:
    return ""

  # Spec objects are used as-is; anything else is parsed into one first.
  spec = device if is_device_spec(device) else DeviceSpec.from_string(device)
  return spec.to_string()
59
60
61# Performance caches
62_cached_mergers = {}
63_cache_lock = threading.RLock()
64_string_merge_cache = {}
65
66
def merge_device(spec):
  """Returns a device function that merges device specifications.

  Partial device specifications can be layered with the returned function.
  When several are nested, the innermost setting for each device field wins.
  For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Results are memoized per `spec`, so repeated calls with the same argument
  return the same object.

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block. A `MergeDevice` is returned
      unchanged.

  Returns:
    A MergeDevice object with the above-described behavior.

  Raises:
    ValueError: if the spec was not valid.
  """
  # Already a merger: nothing to build and nothing to cache.
  if isinstance(spec, MergeDevice):
    return spec

  # Look up and (if needed) insert under one lock acquisition.
  with _cache_lock:
    cached = _cached_mergers.get(spec)
    if not cached:
      cached = MergeDevice(spec)
      _cached_mergers[spec] = cached
    return cached
105
106
class MergeDevice(object):
  """Wraps a device specification (DeviceSpec or str) with merge functionality.

  When called, this class will merge a node_def with its own spec. It also
  exposes a `shortcut_string_merge` method which can significantly improve
  performance of device placement.
  """

  def __init__(self, spec):
    """Stores the specification that later merges are performed against.

    Args:
      spec: A `DeviceSpecV1`, a `DeviceSpecV2`, or a device spec string.
    """
    # DeviceSpecV1 must be tested before DeviceSpecV2. DeviceSpecV2 is the
    # base class (see `is_device_spec` in this module), so testing it first
    # would also match every DeviceSpecV1 instance and the snapshot branch
    # below could never run.
    if isinstance(spec, device_spec.DeviceSpecV1):
      # Capture a snapshot of spec: rebuild it from its string form so this
      # merger holds its own object rather than the caller's.
      self._spec = spec.__class__.from_string(spec.to_string())
    elif isinstance(spec, device_spec.DeviceSpecV2):
      self._spec = spec
    else:
      # Anything else is treated as a device string and parsed.
      self._spec = DeviceSpec.from_string(spec)

  def __call__(self, node_def):
    """Merges the wrapped spec with the device already set on `node_def`.

    Args:
      node_def: An object with a `device` attribute; a falsy value is treated
        as the empty device string.

    Returns:
      The result of `make_merged_spec` on the wrapped spec, given the spec
      parsed from `node_def.device`.
    """
    # In general a user may create a device function which takes into account
    # arbitrary properties of an op. (For instance dynamically placing ops based
    # on type.) So even though the standard DeviceSpec route only uses the
    # device attribute, we take an entire node_def to maintain a consistent
    # signature with general device functions.
    current_device = DeviceSpec.from_string(node_def.device or "")
    return self._spec.make_merged_spec(current_device)

  def shortcut_string_merge(self, node_def):
    """Merge a node def without materializing a full DeviceSpec object.

    Often a device merge is invoked in order to generate a string which can be
    passed into the c api. In such a case, we can cache the
      node_def.device  ->  merge_result_string

    map, and in most cases avoid:
      - Materializing a copy of self._spec (In the case of DeviceSpecV1)
      - Materializing a DeviceSpec for node_def.device
      - A DeviceSpec.merge_from invocation

    In practice the cache hit rate for this function is very high, because the
    number of invocations when iterating through the device stack is much
    larger than the number of devices.

    Args:
      node_def: An Operation (or Operation-like) to merge device constraints
        with self._spec

    Returns:
      A string containing the merged device specification.
    """
    device = node_def.device or ""

    # The cache is module-wide, so the key carries the wrapped spec as well as
    # the node's device string.
    merge_key = (self._spec, device)
    result = _string_merge_cache.get(merge_key)
    if result is None:
      # This update is not atomic, however because the merge is stateless
      # we don't need to lock when updating the cache.
      result = self(node_def).to_string()
      _string_merge_cache[merge_key] = result

    return result

  def __repr__(self):
    """Returns the default object repr followed by the wrapped spec string."""
    return "{} (spec: {})".format(
        super(MergeDevice, self).__repr__(), self._spec.to_string())

  @property
  def is_null_merge(self):
    """Indicate whether the wrapped spec is empty.

    In the degenerate case where self._spec is an empty specification, a caller
    may wish to skip a merge step entirely. (However this class does not have
    enough information to make that determination.)

    Returns:
      A boolean indicating whether a device merge will be trivial.
    """
    # An empty string form means the wrapped spec is empty.
    return not self._spec.to_string()
184