# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device-related support functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops


def canonicalize(d, default=None):
  """Return the fully-specified string form of device `d`.

  Components missing from `d` are taken first from `default` (when given)
  and otherwise from '/replica:0/task:0/device:CPU:0'. For example:
    d = '/cpu:0', default='/job:worker/task:1' gives
      '/job:worker/replica:0/task:1/device:CPU:0'.
    d = '/cpu:0', default='/job:worker' gives
      '/job:worker/replica:0/task:0/device:CPU:0'.
    d = '/gpu:0', default=None gives
      '/replica:0/task:0/device:GPU:0'.

  Note: When executing eagerly (outside functions), the job additionally
  defaults to "localhost".

  Args:
    d: a device string or tf.config.LogicalDevice.
    default: a device string supplying components that `d` leaves out.

  Returns:
    a canonicalized device string.

  Raises:
    AssertionError: if the parsed device type is not all upper-case.
  """
  # A LogicalDevice carries its device string in `.name`.
  device_str = d.name if isinstance(d, context.LogicalDevice) else d
  requested = tf_device.DeviceSpec.from_string(device_str)

  device_type = requested.device_type
  assert device_type is None or device_type == device_type.upper(), (
      "Device type '%s' must be all-caps." % (device_type,))

  # Lowest-priority layer: the built-in fallback values.
  merged = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Eager execution implies the local job unless something overrides it.
    merged = merged.replace(job="localhost")

  # Middle layer: caller-supplied defaults override the built-in ones.
  if default:
    default_spec = tf_device.DeviceSpec.from_string(default)
    merged = merged.make_merged_spec(default_spec)

  # Top layer: `d` is merged last so that its values take precedence.
  return merged.make_merged_spec(requested).to_string()


def resolve(d):
  """Canonicalize `d`, using the current device to fill in missing parts."""
  fallback = current()
  return canonicalize(d, default=fallback)


class _FakeNodeDef(object):
  """Minimal NodeDef stand-in used by _FakeOperation."""

  def __init__(self):
    # Both fields start out as empty strings.
    self.op = self.name = ""


class _FakeOperation(object):
  """Minimal Operation stand-in that can be handed to device functions."""

  def __init__(self):
    # All string fields start empty; `device` is filled in by the setters.
    self.device = self.type = self.name = ""
    self.node_def = _FakeNodeDef()

  def _set_device(self, device):
    """Record `device` after converting it via `ops._device_string`."""
    self.device = ops._device_string(device)  # pylint: disable=protected-access

  def _set_device_from_string(self, device_str):
    """Record `device_str` unchanged as this operation's device."""
    self.device = device_str


def current():
  """Return the current device as a string, without canonicalizing it."""
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if ops.executing_eagerly_outside_functions():
    return context.context().device_name

  # In graph mode, let the default graph's device functions assign a device
  # to a throwaway operation and read back what they chose.
  probe = _FakeOperation()
  ops.get_default_graph()._apply_device_functions(probe)  # pylint: disable=protected-access
  return probe.device


def get_host_for_device(device):
  """Return the CPU:0 device sharing `device`'s job, replica and task."""
  parsed = tf_device.DeviceSpec.from_string(device)
  host_spec = tf_device.DeviceSpec(
      job=parsed.job,
      replica=parsed.replica,
      task=parsed.task,
      device_type="CPU",
      device_index=0)
  return host_spec.to_string()


def local_devices_from_num_gpus(num_gpus):
  """Return a tuple of local GPU device strings, or the CPU if there are none."""
  gpu_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
  # An empty tuple (no GPUs requested) falls back to the single local CPU.
  return gpu_devices if gpu_devices else ("/device:CPU:0",)