# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to represent a device."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.tf_export import tf_export
from tensorflow.python import pywrap_tfe

# EPU stands for TPU embedding for now. Subject to change in future.
_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM", "EPU"})


# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Module-level memoization tables. Keys are the raw spec string (or None) and
# the 5-tuple of components respectively; entries are never evicted.
_STRING_TO_COMPONENTS_CACHE = {}
_COMPONENTS_TO_STRING_CACHE = {}


def _as_str_or_none(inp):
  """Returns `str(inp)`, or None when `inp` is None."""
  return None if inp is None else str(inp)


def _as_int_or_none(inp):
  """Returns `int(inp)`, or None when `inp` is None."""
  return None if inp is None else int(inp)


def _as_device_str_or_none(device_type):
  """Normalizes a device type to a string, or None when it is None."""
  # For backwards compatibility only, we support lowercase variants of
  # cpu and gpu but turn them into uppercase here.
  if device_type in ("cpu", "gpu"):
    return device_type.upper()
  return _as_str_or_none(device_type)


@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need of .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form and its hash are computed once up front; instances of
    # this class expose no setters, so they never need to be refreshed.
    self._as_string = self._components_to_string(
        job=self._job, replica=self._replica, task=self._task,
        device_type=self._device_type, device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    **2.x behavior change**:

    In TensorFlow 1.x, this function mutates its own state and returns itself.
    In 2.x, DeviceSpecs are immutable, and this function will return a
      DeviceSpec which contains the spec.

    * Recommended:

      ```
      # my_spec and my_updated_spec are unrelated.
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will work in 1.x and 2.x (though deprecated in 2.x):

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = my_spec.parse_from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will NOT work in 2.x:

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
      with tf.device(my_spec):
        ...
      ```

    In general, `DeviceSpec.from_string` should completely replace
    `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
    completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor
        (`job`, `replica`, `task`, `device_type`, `device_index`).

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = dict(
        job=self.job, replica=self.replica, task=self.task,
        device_type=self.device_type, device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name as a string, or None if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type string, or None if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index as an int, or None if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs gives priority to `dev`: each component of
    `dev` that is not None overrides the corresponding component of `self`.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _get_valid_device_types():
    """Returns the built-in device types plus any pluggable device types.

    Each name returned by `TF_ListPluggablePhysicalDevices` is decoded and its
    second ":"-separated field is taken as the device type.
    """
    physical_devices = pywrap_tfe.TF_ListPluggablePhysicalDevices()
    pluggable_device_types = {
        device.decode().split(":")[1] for device in physical_devices
    }
    return pluggable_device_types | _VALID_DEVICE_TYPES

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec`, as a tuple of
      (job, replica, task, device_type, device_index). Note that the result of
      this function must go through attribute setters of DeviceSpec, and should
      therefore NOT be used directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a device index that `int()` rejects.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    valid_device_types = DeviceSpecV2._get_valid_device_types()
    for y in splits:
      # `str.split` never returns an empty list, so `y[0]` always exists.
      ly = len(y)
      # NOTE(taylorrobie): these will go through setters later.
      if ly == 2 and y[0] == "job":
        job = y[1]
      elif ly == 2 and y[0] == "replica":
        replica = y[1]
      elif ly == 2 and y[0] == "task":
        task = y[1]
      elif ly in (1, 2) and y[0].upper() in valid_device_types:
        # Short form: "/GPU" or "/GPU:0".
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = y[0].upper()
        if ly == 2 and y[1] != "*":
          device_index = int(y[1])
      elif ly == 3 and y[0] == "device":
        # Long form: "/device:GPU:0".
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = y[1]
        if y[2] != "*":
          device_index = int(y[2])
      elif y[0]:
        # Segments whose first field is empty (e.g. from a leading "/") are
        # skipped; anything else unrecognized is an error.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching)."""
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks if `other` is a DeviceSpec equal to the current instance.

    Two specs are compared through their string representations.

    Args:
      other: Another DeviceSpec

    Returns:
      Return `True` if `other` is an instance of this object's class and has
      the same string representation as the current instance.
      Return `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __hash__(self):
    """Returns the hash of the string representation, computed at init."""
    return self._hash


@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  __slots__ = DeviceSpecV2.__slots__

  # Each setter normalizes its value and then clears the cached string and
  # hash so that `to_string` and `__hash__` recompute them on next use.

  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    """Returns the hash of the string form, recomputing it if invalidated."""
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job, replica=self.replica, task=self.task,
          device_type=self.device_type, device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Unlike the V2 implementation, this assigns through the setters and
    # therefore mutates this instance in place.
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__