# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device-related support functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops


def canonicalize(d, default=None):
  """Canonicalize device string.

  If d has missing components, the rest would be deduced from the `default`
  argument or from '/replica:0/task:0/device:CPU:0'. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: This uses "job:localhost" as the default if executing eagerly.

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.

  Raises:
    AssertionError: if the parsed device type of `d` is not all upper-case.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # NOTE: this check is an `assert`, so it is skipped when Python runs with
  # optimizations enabled (-O).
  assert d.device_type is None or d.device_type == d.device_type.upper(), (
      "Device type '%s' must be all-caps." % (d.device_type,))
  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    host_cpu = tf_device.DeviceSpec.from_string(
        config.list_logical_devices("CPU")[0].name)
    if host_cpu.job:
      result = result.make_merged_spec(host_cpu)
    else:
      # The default job is localhost if eager execution is enabled
      result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply `d` last, so that its values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()


def resolve(d):
  """Canonicalize `d` with current device as default.

  Args:
    d: a device string or tf.config.LogicalDevice, as accepted by
      `canonicalize`.

  Returns:
    The result of `canonicalize(d, default=current())`.
  """
  return canonicalize(d, default=current())


class _FakeNodeDef(object):
  """A fake NodeDef for _FakeOperation."""

  __slots__ = ["op", "name"]

  def __init__(self):
    # Both fields start out as empty strings.
    self.op = ""
    self.name = ""


class _FakeOperation(object):
  """A fake Operation object to pass to device functions.

  `current()` hands an instance to the default graph's device functions and
  then reads back the `device` attribute they set.
  """

  def __init__(self):
    self.device = ""
    self.type = ""
    self.name = ""
    self.node_def = _FakeNodeDef()

  def _set_device(self, device):
    """Stores `device` after converting it with `ops._device_string`."""
    self.device = ops._device_string(device)  # pylint: disable=protected-access

  def _set_device_from_string(self, device_str):
    """Stores `device_str` unchanged as this operation's device."""
    self.device = device_str


def current():
  """Return a string (not canonicalized) for the current device."""
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if ops.executing_eagerly_outside_functions():
    # Eager mode: read the device name from the eager context.
    d = context.context().device_name
  else:
    # Graph mode: let the default graph's device functions assign a device to
    # a placeholder operation and read the result back.
    op = _FakeOperation()
    ops.get_default_graph()._apply_device_functions(op)  # pylint: disable=protected-access
    d = op.device
  return d


def get_host_for_device(device, device_index=0):
  """Returns the corresponding host device for the given device.

  Args:
    device: a device string; its job, replica and task are reused.
    device_index: index of the host CPU device to return. Defaults to 0, which
      matches the behavior before this argument existed.

  Returns:
    A device string with the same job, replica and task as `device`, device
    type "CPU" and the given `device_index`.
  """
  spec = tf_device.DeviceSpec.from_string(device)
  return tf_device.DeviceSpec(
      job=spec.job,
      replica=spec.replica,
      task=spec.task,
      device_type="CPU",
      device_index=device_index).to_string()


def local_devices_from_num_gpus(num_gpus):
  """Returns device strings for local GPUs or CPU.

  Args:
    num_gpus: number of local GPUs to generate device strings for.

  Returns:
    A tuple of "/device:GPU:<i>" strings for i in range(num_gpus), or
    ("/device:CPU:0",) when that tuple would be empty.
  """
  return (tuple("/device:GPU:%d" % i for i in range(num_gpus)) or
          ("/device:CPU:0",))