• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Device-related support functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import device as tf_device
23from tensorflow.python.framework import ops
24
25
def canonicalize(d, default=None):
  """Canonicalize device string.

  If `d` has missing components, they are filled in from the `default`
  argument or, failing that, from '/replica:0/task:0/device:CPU:0'. For
  example (graph mode):
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: This uses "job:localhost" as the default job if executing eagerly.

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.

  Raises:
    AssertionError: if the device type parsed from `d` is not all-caps.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # Checked explicitly instead of with an `assert` statement, which is removed
  # when Python runs with -O. AssertionError is kept so callers observe the
  # same exception type as before.
  if d.device_type is not None and d.device_type != d.device_type.upper():
    raise AssertionError(
        "Device type '%s' must be all-caps." % (d.device_type,))

  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # The default job is localhost if eager execution is enabled.
    result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply `d` last, so that its values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()
68
69
def resolve(d):
  """Canonicalizes `d`, using the current device to fill in missing fields.

  Args:
    d: a device string or tf.config.LogicalDevice, as accepted by
      `canonicalize`.

  Returns:
    The canonicalized device string.
  """
  current_device = current()
  return canonicalize(d, default=current_device)
73
74
75class _FakeNodeDef(object):
76  """A fake NodeDef for _FakeOperation."""
77
78  def __init__(self):
79    self.op = ""
80    self.name = ""
81
82
83class _FakeOperation(object):
84  """A fake Operation object to pass to device functions."""
85
86  def __init__(self):
87    self.device = ""
88    self.type = ""
89    self.name = ""
90    self.node_def = _FakeNodeDef()
91
92  def _set_device(self, device):
93    self.device = ops._device_string(device)  # pylint: disable=protected-access
94
95  def _set_device_from_string(self, device_str):
96    self.device = device_str
97
98
def current():
  """Returns the current device as a string (not canonicalized).

  When executing eagerly outside functions, this is the eager context's
  device name. Otherwise it is the device that the default graph's device
  functions assign to a fake operation.
  """
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if ops.executing_eagerly_outside_functions():
    return context.context().device_name

  # Graph mode: let the graph's device function stack place a probe op, then
  # read back the device it was given.
  probe_op = _FakeOperation()
  ops.get_default_graph()._apply_device_functions(probe_op)  # pylint: disable=protected-access
  return probe_op.device
109
110
def get_host_for_device(device):
  """Returns the host CPU:0 device string for `device`.

  The job, replica and task parsed from `device` are kept. The device type
  and index are replaced with "CPU" and 0.

  Args:
    device: a device string.

  Returns:
    The device string for CPU:0 on the same job/replica/task as `device`.
  """
  parsed = tf_device.DeviceSpec.from_string(device)
  host_spec = parsed.replace(device_type="CPU", device_index=0)
  return host_spec.to_string()
117
118
def local_devices_from_num_gpus(num_gpus):
  """Returns device strings for local GPUs or CPU.

  Args:
    num_gpus: number of local GPUs.

  Returns:
    A tuple of '/device:GPU:<i>' strings for each i in range(num_gpus), or
    ('/device:CPU:0',) when that range is empty (zero or negative count).
  """
  gpu_devices = ["/device:GPU:%d" % gpu_index for gpu_index in range(num_gpus)]
  return tuple(gpu_devices) if gpu_devices else ("/device:CPU:0",)
123