1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Class to represent a device.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import threading 22 23from tensorflow.python import tf2 24from tensorflow.python.framework import device_spec 25 26if tf2.enabled(): 27 DeviceSpec = device_spec.DeviceSpecV2 28else: 29 DeviceSpec = device_spec.DeviceSpecV1 30 31 32def check_valid(spec): 33 """Check that a device spec is valid. 34 35 Args: 36 spec: a string. 37 38 Raises: 39 An exception if the spec is invalid. 40 """ 41 # Construct a DeviceSpec. It will assert a failure if spec is invalid. 42 DeviceSpec.from_string(spec) 43 44 45def is_device_spec(obj): 46 """Abstract away the fact that DeviceSpecV2 is the base class.""" 47 return isinstance(obj, device_spec.DeviceSpecV2) 48 49 50def canonical_name(device): 51 """Returns a canonical name for the given `DeviceSpec` or device name.""" 52 if device is None: 53 return "" 54 if is_device_spec(device): 55 return device.to_string() 56 else: 57 device = DeviceSpec.from_string(device) 58 return device.to_string() 59 60 61# Performance caches 62_cached_mergers = {} 63_cache_lock = threading.RLock() 64_string_merge_cache = {} 65 66 67def merge_device(spec): 68 """Returns a device function that merges devices specifications. 69 70 This can be used to merge partial specifications of devices. The 71 innermost setting for a device field takes precedence. For example: 72 73 with tf.device(merge_device("/device:GPU:0")) 74 # Nodes created here have device "/device:GPU:0" 75 with tf.device(merge_device("/job:worker")): 76 # Nodes created here have device "/job:worker/device:GPU:0" 77 with tf.device(merge_device("/device:CPU:0")): 78 # Nodes created here have device "/job:worker/device:CPU:0" 79 with tf.device(merge_device("/job:ps")): 80 # Nodes created here have device "/job:ps/device:CPU:0" 81 82 Args: 83 spec: A `DeviceSpec` or a device spec string (partially) describing the 84 device that should be used for all nodes created in the scope of 85 the returned device function's with block. 86 87 Returns: 88 A MergeDevice object with the above-described behavior. 89 90 Raises: 91 ValueError: if the spec was not valid. 92 """ 93 94 if isinstance(spec, MergeDevice): 95 return spec 96 97 with _cache_lock: 98 merger = _cached_mergers.get(spec) 99 if merger: 100 return merger 101 102 merger = MergeDevice(spec) 103 _cached_mergers[spec] = merger 104 return merger 105 106 107class MergeDevice(object): 108 """Wraps a device specification (DeviceSpec or str) with merge functionality. 109 110 When called, this class will merge a node_def with its own spec. It also 111 exposes a `shortcut_string_merge` method which can significantly improve 112 performance of device placement. 113 """ 114 115 __slots__ = ["_spec"] 116 117 def __init__(self, spec): 118 if isinstance(spec, device_spec.DeviceSpecV2): 119 self._spec = spec 120 elif isinstance(spec, device_spec.DeviceSpecV1): 121 # Capture a snapshot of spec. 122 self._spec = spec.__class__.from_string(spec.to_string()) 123 else: 124 self._spec = DeviceSpec.from_string(spec) 125 126 def __call__(self, node_def): 127 # In general a user may create a device function which takes into account 128 # arbitrary properties of an op. (For instance dynamically placing ops based 129 # on type.) So even though the standard DeviceSpec route only uses the 130 # device attribute, we take an entire node_def to maintain a consistent 131 # signature with general device functions. 132 current_device = DeviceSpec.from_string(node_def.device or "") 133 return self._spec.make_merged_spec(current_device) 134 135 def shortcut_string_merge(self, node_def): 136 """Merge a node def without materializing a full DeviceSpec object. 137 138 Often a device merge is invoked in order to generate a string which can be 139 passed into the c api. In such a case, we can cache the 140 node_def.device -> merge_result_string 141 142 map, and in most cases avoid: 143 - Materializing a copy of self._spec (In the case of DeviceSpecV1) 144 - Materializing a DeviceSpec for node_def.device 145 - A DeviceSpec.merge_from invocation 146 147 In practice the cache hit rate for this function is very high, because the 148 number of invocations when iterating through the device stack is much 149 larger than the number of devices. 150 151 Args: 152 node_def: An Operation (or Operation-like) to merge device constraints 153 with self._spec 154 155 Returns: 156 A string containing the merged device specification. 157 """ 158 device = node_def.device or "" 159 160 merge_key = (self._spec, device) 161 result = _string_merge_cache.get(merge_key) 162 if result is None: 163 # This update is not atomic, however because the merge is stateless 164 # we don't need to lock when updating the cache. 165 result = self.__call__(node_def).to_string() 166 _string_merge_cache[merge_key] = result 167 168 return result 169 170 def __repr__(self): 171 return "{} (spec: {})".format( 172 super(MergeDevice, self).__repr__(), self._spec.to_string()) 173 174 @property 175 def is_null_merge(self): 176 """Indicate whether the wrapped spec is empty. 177 178 In the degenerate case where self._spec is an empty specification, a caller 179 may wish to skip a merge step entirely. (However this class does not have 180 enough information to make that determination.) 181 182 Returns: 183 A boolean indicating whether a device merge will be trivial. 184 """ 185 return not bool(self._spec.to_string()) 186