# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to represent a device."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.tf_export import tf_export


# Device types accepted in the short form of a spec string (e.g. "/GPU:0").
_VALID_DEVICE_TYPES = {"CPU", "GPU", "TPU"}


# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Module-level memoization tables shared by every DeviceSpec instance:
#   spec string -> (job, replica, task, device_type, device_index)
_STRING_TO_COMPONENTS_CACHE = {}
#   (job, replica, task, device_type, device_index) -> spec string
_COMPONENTS_TO_STRING_CACHE = {}


def _as_str_or_none(inp):
  """Returns `str(inp)`, passing `None` through unchanged."""
  return None if inp is None else str(inp)


def _as_int_or_none(inp):
  """Returns `int(inp)`, passing `None` through unchanged."""
  return None if inp is None else int(inp)


def _as_device_str_or_none(device_type):
  """Normalizes a device type to a string, passing `None` through unchanged."""
  # For backwards compatibility only, we support lowercase variants of
  # cpu and gpu but turn them into uppercase here.
  if device_type in ("cpu", "gpu"):
    return device_type.upper()
  return _as_str_or_none(device_type)


@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need of .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form and its hash are computed eagerly, once, from the
    # normalized components.
    self._as_string = self._components_to_string(
        job=self._job, replica=self._replica, task=self._task,
        device_type=self._device_type, device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    2.x behavior change:
      In TensorFlow 1.x, this function mutates its own state and returns itself.
      In 2.x, DeviceSpecs are immutable, and this function will return a
        DeviceSpec which contains the spec.

      Recommended:
        ```
        # my_spec and my_updated_spec are unrelated.
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will work in 1.x and 2.x (though deprecated in 2.x):
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_updated_spec = my_spec.parse_from_string("/GPU:0")
        with tf.device(my_updated_spec):
          ...
        ```

      Will NOT work in 2.x:
        ```
        my_spec = tf.DeviceSpec.from_string("/CPU:0")
        my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
        with tf.device(my_spec):
          ...
        ```

      In general, `DeviceSpec.from_string` should completely replace
      `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
      completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = dict(
        job=self.job, replica=self.replica, task=self.task,
        device_type=self.device_type, device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name as a string, or `None` if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index as an int, or `None` if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index as an int, or `None` if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type as a string, or `None` if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index as an int, or `None` if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs will give priority to dev: each component
    is taken from `dev` unless it is `None` there, in which case the value
    from `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized in `_STRING_TO_COMPONENTS_CACHE`, keyed by the
    original `spec` argument.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec`. Note that the result of this function
      must go through attribute setters of DeviceSpec, and should therefore NOT
      be used directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a non-integer device index.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    # "/job:ps/device:GPU:0" -> [[""], ["job", "ps"], ["device", "GPU", "0"]]
    splits = [x.split(":") for x in spec.split("/")]
    for y in splits:
      ly = len(y)
      if y:
        # NOTE(taylorrobie): these will go through setters later.
        if ly == 2 and y[0] == "job":
          job = y[1]
        elif ly == 2 and y[0] == "replica":
          replica = y[1]
        elif ly == 2 and y[0] == "task":
          task = y[1]
        elif ((ly == 1 or ly == 2) and (y[0].upper() in _VALID_DEVICE_TYPES)):
          # Short form: "<type>" or "<type>:<index>", matched
          # case-insensitively and stored uppercased.
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[0].upper()
          # An index of "*" leaves device_index unspecified.
          if ly == 2 and y[1] != "*":
            device_index = int(y[1])
        elif ly == 3 and y[0] == "device":
          # Long form: "device:<type>:<index>"; the type is kept as written.
          if device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          device_type = y[1]
          if y[2] != "*":
            device_index = int(y[2])
        elif ly and y[0] != "":  # pylint: disable=g-explicit-bool-comparison
          # Empty segments (leading or doubled "/") are silently skipped;
          # anything else is rejected.
          raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Args:
      job: job name string, or `None`.
      replica: replica index, or `None`.
      task: task index, or `None`.
      device_type: device type string, or `None`.
      device_index: device index, or `None`. Only emitted when `device_type`
        is set; `None` is rendered as "*".

    Returns:
      The spec string built from the components that are not `None`; the
      empty string when none are set.
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks if `other` is a DeviceSpec equal to the current instance.

    Two specs are equal when their string representations are equal.

    Args:
      other: Another DeviceSpec

    Returns:
      Return `True` if `other` is also a DeviceSpec instance and has same value
      as the current instance.
      Return `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __hash__(self):
    return self._hash


@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # All storage slots are inherited from DeviceSpecV2. Re-declaring the same
  # names here would create a second, shadowing set of slot descriptors and
  # double the per-instance storage; an empty tuple keeps instances free of a
  # `__dict__` without duplicating anything.
  __slots__ = ()

  # Each setter normalizes the new value and then clears the cached string
  # and hash so that they are lazily recomputed on next use.

  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # Recomputed lazily because the setters above reset `_hash` to None.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # Recomputed lazily because the setters above reset `_as_string` to None.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job, replica=self.replica, task=self.task,
          device_type=self.device_type, device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # 1.x semantics: overwrite this instance's components in place (through
    # the setters) and return the same object.
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__