# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core DTensor Python API."""

import contextlib
import os
import threading
from typing import Any, Callable, List, Optional, Sequence, Union

from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

_DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
_DT_JOB_NAME = "DTENSOR_JOB_NAME"
_DT_JOBS = "DTENSOR_JOBS"
_DT_HEARTBEAT_ENABLED = "DTENSOR_ENABLE_HEARTBEAT"

_dtensor_singleton = None
_dtensor_singleton_lock = threading.Lock()

# -----------------------------------------------------------------------------
# Main methods to launch DTensor computations.


@tf_export("experimental.dtensor.call_with_layout", v1=[])
def call_with_layout(fn: Callable[..., Any],
                     layout: Optional[layout_lib.Layout],
                     *args, **kwargs) -> Any:
  """Calls a function in the DTensor device scope if `layout` is not None.

  If `layout` is not None, `fn` consumes DTensor(s) as input and produces a
  DTensor as output; a DTensor is a tf.Tensor with layout-related attributes.

  If `layout` is None, `fn` consumes and produces regular tf.Tensors.

  Args:
    fn: A supported TF API function such as tf.zeros.
    layout: Optional, the layout of the output DTensor.
    *args: Arguments given to `fn`.
    **kwargs: Keyword arguments given to `fn`.

  Returns:
    The return value of `fn` transformed to a DTensor if requested.
  """
  if layout is not None:
    if not context.executing_eagerly():
      # This is a workaround for b/199324097, where functions such as tf.ones
      # could attach an incorrect layout to the tf.const generated under the
      # hood. The op runs successfully in eager mode, but in graph mode, MLIR
      # passes sometimes attach the default layout to a scalar constant.
      # %cst = tf.Const([1])           -- With the given layout
      # %0 = "tf.DTensorLayout"(%cst)  -- Fails in the MLIR pass since the
      #                                -- shape for the layout could be
      #                                -- different from shape[0] of %cst.
      # %1 = tf.Fill(%0, 1)
      result = fn(*args, **kwargs)
      return relayout(result, layout)
    else:
      with run_on(layout.mesh):
        with _dtensor_device()._default_layout(layout):  # pylint: disable=protected-access
          return fn(*args, **kwargs)
  return fn(*args, **kwargs)
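

# A minimal usage sketch for `call_with_layout` (illustrative only, not part
# of this module's API). It assumes a single-client mesh with two logical CPU
# devices, created elsewhere via `tf.experimental.dtensor.create_mesh`:
#
#   import tensorflow as tf
#   from tensorflow.experimental import dtensor
#
#   mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")
#   layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)
#   ones = dtensor.call_with_layout(tf.ones, layout, shape=(4, 4))
#   # `ones` is a DTensor sharded over "x"; with `layout=None` the call would
#   # simply return a regular tf.Tensor.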


@tf_export("experimental.dtensor.run_on", v1=[])
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
  """Runs enclosed functions in the DTensor device scope.

  This function returns a scope. All the ops and tf.functions in this scope
  will run on the DTensor device using the mesh provided. This is useful for
  wrapping any tf.function that doesn't take a DTensor as input but would like
  to produce a DTensor as its result. The scope will also make sure all small
  constants are replicated as DTensors.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the DTensor
    device.
  """
  if not isinstance(mesh, layout_lib.Mesh):
    raise ValueError(f"Expect `mesh` to be `Mesh`, got {type(mesh)}")

  with _dtensor_device()._experimental_default_mesh(mesh):  # pylint: disable=protected-access
    with ops.device(device_name()):
      yield


@tf_export("experimental.dtensor.device_name", v1=[])
def device_name() -> str:
  """Returns the singleton DTensor device's name.

  This function can be used in the following way:

  ```python
  import tensorflow as tf

  with tf.device(dtensor.device_name()):
    # ...
  ```
  """
  return _dtensor_device().name


# -----------------------------------------------------------------------------
# Data transfer methods.


@tf_export("experimental.dtensor.copy_to_mesh", v1=[])
def copy_to_mesh(
    tensor: Any,
    layout: layout_lib.Layout,
    source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor:
  """Copies a tf.Tensor onto the DTensor device with the given layout.

  Copies a regular tf.Tensor onto the DTensor device. The mesh attached to
  `layout` is used as the target mesh. This method currently only supports
  replicated layouts. To get a DTensor with a sharded layout, use the `pack`
  method.

  Args:
    tensor: A regular tf.Tensor to be copied as a DTensor.
    layout: Target layout (and mesh) for the resulting DTensor.
    source_layout: Source layout of the tensor before the copy, used for
      backward passes.

  Returns:
    A DTensor on the DTensor device with the given layout.
  """
  return _dtensor_device().copy_to_mesh(tensor, layout, source_layout)
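

# A minimal usage sketch for `copy_to_mesh` and `run_on` (illustrative only;
# `mesh` is assumed to be created elsewhere with
# `tf.experimental.dtensor.create_mesh`):
#
#   replicated = dtensor.Layout.replicated(mesh, rank=1)
#   d = dtensor.copy_to_mesh(tf.constant([1.0, 2.0]), replicated)
#   with dtensor.run_on(mesh):
#     e = tf.math.exp(d)  # Runs on the DTensor device; `e` is a DTensor.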


@tf_export("experimental.dtensor.pack", v1=[])
def pack(tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
  """Packs `tf.Tensor` components into a DTensor.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed
     on each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  **Shape, Rank, and Scalars**: The rank of the DTensor is the same as the
  rank of its raw components, i.e., rank is preserved. This leads to a
  consistent interpretation for packing scalar values into a DTensor. The only
  valid layout for a scalar value is fully replicated, and the individual
  components must be identical scalars.

  Each input `tensors[i]` will be copied to `layout.mesh.local_devices()[i]`
  if not already on the local device. Non-local components should not be
  passed to `pack`; use `copy_to_mesh` and `relayout` to place tensors on all
  global devices on a mesh.

  It is the caller's responsibility to ensure that the underlying values
  for `pack` adhere to the specified layout, and that only as many values are
  specified as there are local devices. Pack does not move data between
  clients. See the examples below for more detail about layouts.

  For example, assume we have a mesh `[X(2), Y(3)]`, which has 6 underlying
  devices in total. Furthermore, assume that the device location mapping is
  the following:

  ```
  device_ID  |  location X, Y
          0     0, 0
          1     0, 1
          2     0, 2
          3     1, 0
          4     1, 1
          5     1, 2
  ```

  1. For a 1-D DTensor with shape `[128]`, layout `[mesh.X]`, and value
     `range(128)`, the raw components will have shape `[64]` each, with the
     following content:

     ```
     device_ID  |  raw component
             0     range(0, 64)
             1     range(0, 64)
             2     range(0, 64)
             3     range(64, 128)
             4     range(64, 128)
             5     range(64, 128)
     ```

     This also means that for a 1-D DTensor with shape `[2]` and layout
     `[mesh.X]`, the raw components have shape `[1]` rather than the scalar
     shape `[]`.

  2. For a 2-D DTensor with shape `[2, 3]`, layout `[mesh.X, mesh.Y]`, and
     value `range(6)`, this is essentially a fully-sharded DTensor.

     From the global view, the content looks like
     ```
     [
       [0.0, 1.0, 2.0],
       [3.0, 4.0, 5.0],
     ]
     ```

     The raw components will have shape `[1, 1]` each, with the following
     content:

     ```
     device_ID  |  raw component
             0     [[0.0]]
             1     [[1.0]]
             2     [[2.0]]
             3     [[3.0]]
             4     [[4.0]]
             5     [[5.0]]
     ```

  3. A DTensor holding the scalar value `123.0` can only have one legitimate
     layout: `[]` (no dimensions, but fully replicated).

     The raw components will have shape `[]` each, with the following content:

     ```
     device_ID  |  raw component
             0     123.0
             1     123.0
             2     123.0
             3     123.0
             4     123.0
             5     123.0
     ```

     Again, the caller of `pack` is expected to provide 6 identical raw
     components with scalar shapes.

  4. For a 3-D DTensor with shape `[2, 2, 3]`, layout
     `[X, unsharded, unsharded]`, and value `range(12)`:

     From the global view, the content looks like:
     ```
     [
       [
         [0.0, 1.0, 2.0],
         [3.0, 4.0, 5.0],
       ],
       [
         [6.0, 7.0, 8.0],
         [9.0, 10., 11.],
       ],
     ]
     ```

     The raw components will have shape `[1, 2, 3]` each, with the following
     content:

     ```
     device_ID  |  raw component
             0     range(6).reshape([1, 2, 3])
             1     range(6).reshape([1, 2, 3])
             2     range(6).reshape([1, 2, 3])
             3     range(6, 12).reshape([1, 2, 3])
             4     range(6, 12).reshape([1, 2, 3])
             5     range(6, 12).reshape([1, 2, 3])
     ```

  Args:
    tensors: The list of local tensor components to pack into a DTensor.
    layout: The layout of the DTensor to be created.

  Returns:
    A DTensor created from the individual component tensors.

  Raises:
    RuntimeError: When `pack` is not called eagerly.
  """
  return _dtensor_device().pack(tensors, layout)


@tf_export("experimental.dtensor.unpack", v1=[])
def unpack(tensor: Any) -> Sequence[Any]:
  """Unpacks a DTensor into `tf.Tensor` components.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed
     on each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  See the documentation for `pack` for more information about how packing and
  unpacking works.

  Args:
    tensor: The DTensor to unpack.

  Returns:
    The individual component tensors of the DTensor. This will include only
    the client-local components, i.e. the components placed on the local
    devices.

  Raises:
    RuntimeError: When `unpack` is not called eagerly.
  """
  return _dtensor_device().unpack(tensor)
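

# A minimal usage sketch for `pack`/`unpack` (illustrative only; `mesh` is
# assumed to be a single-client, two-device mesh with one dimension "x",
# e.g. from `tf.experimental.dtensor.create_mesh([("x", 2)], device_type="CPU")`):
#
#   layout = dtensor.Layout(["x"], mesh)
#   components = [tf.constant([0.0, 1.0]), tf.constant([2.0, 3.0])]
#   d = dtensor.pack(components, layout)  # Global shape [4], sharded on "x".
#   parts = dtensor.unpack(d)             # Two tensors of shape [2] each.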


# -----------------------------------------------------------------------------
# Layout-related methods.


@tf_export("experimental.dtensor.fetch_layout", v1=[])
def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout:
  """Fetches the layout of a DTensor.

  Args:
    tensor: The DTensor whose layout is to be fetched.

  Returns:
    The `Layout` of this DTensor.

  Raises:
    RuntimeError: When not called eagerly.
  """
  return _dtensor_device().fetch_layout(tensor)


@tf_export("experimental.dtensor.check_layout", v1=[])
def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None:
  """Asserts that the layout of the DTensor is `layout`.

  Args:
    tensor: A DTensor whose layout is to be checked.
    layout: The `Layout` to compare against.

  Raises:
    ValueError: If the layout of `tensor` does not match the supplied `layout`.
  """
  if fetch_layout(tensor) != layout:
    raise ValueError("Layout of tensor: " + str(fetch_layout(tensor)) +
                     ", did not match expected layout: " + str(layout))


@tf_export("experimental.dtensor.relayout", v1=[])
def relayout(tensor: ops.Tensor, layout: layout_lib.Layout) -> ops.Tensor:
  """Changes the layout of `tensor`.

  Changes the layout of `tensor` to `layout`. This is used to fine-tune the
  behavior of ops following/connected to `tensor`, such as choosing one SPMD
  expansion pattern over another. This works by forward propagating `layout`
  to connected TensorFlow computation graphs during layout propagation.

  Currently, only converting layouts from replicated to sharded or sharded to
  replicated per mesh dimension is supported. That is, "x, y" -> "unsharded, y"
  is supported, while "x, y" -> "z, y" is not supported.

  We also support a special "match" sharding spec, which instructs the
  relayout to act as an identity operation with respect to any sharding on
  these mesh dimensions.

  Relayout is internally lowered to a set of Split and/or AllToAll ops. When
  tensor layouts are converted from replicated to sharded, the cost is
  comparatively low because we only insert Split ops and no cross-device
  communication is needed. However, when tensor layouts are converted from
  sharded to replicated, cross-device communication may occur, causing
  potential performance impact.

  Args:
    tensor: A DTensor to specify a new layout for.
    layout: A Layout object specifying a new sharding spec.

  Returns:
    A DTensor output from the Relayout op.
  """
  layout_str = layout.to_string()
  return gen_dtensor_ops.relayout(tensor, layout_str)
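

# A minimal usage sketch for `relayout` (illustrative only): converting a
# DTensor `d` that is sharded on mesh dimension "x" of a rank-2 layout back to
# a fully replicated layout. `d` and `mesh` are assumed to exist already.
#
#   replicated = dtensor.Layout.replicated(mesh, rank=2)
#   d_replicated = dtensor.relayout(d, replicated)
#   dtensor.check_layout(d_replicated, replicated)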


# -----------------------------------------------------------------------------
# Distributed training-related methods.
#
# Most users should use DTensor utility methods to create a mesh. The methods
# here are only for advanced users who want to fully customize their meshes.
# Note that `local_devices` and `num_local_devices` are derived from the
# devices actually attached to this client; the other values are set through
# environment variables.


@tf_export("experimental.dtensor.client_id", v1=[])
def client_id() -> int:
  """Returns this client's ID."""
  # If missing, assume running with a single client with client_id of 0.
  client_id_value = int(os.environ.get(_DT_CLIENT_ID, "0"))
  if client_id_value < 0:
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be >= 0, got {client_id_value}. ")
  if client_id_value >= num_clients():
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be < {num_clients()}, got {client_id_value}")
  return client_id_value


@tf_export("experimental.dtensor.num_clients", v1=[])
def num_clients() -> int:
  """Returns the number of clients in this DTensor cluster."""
  # If missing, assume running with a single client with num_clients of 1.
  num_clients_value = int(os.environ.get(_DT_NUM_CLIENTS, "1"))
  if num_clients_value <= 0:
    raise ValueError(f"Environment variable {_DT_NUM_CLIENTS} "
                     f"must be > 0, got {num_clients_value}.")

  return num_clients_value


@tf_export("experimental.dtensor.local_devices", v1=[])
def local_devices(
    device_type: str,
    for_client_id: Optional[int] = None) -> List[tf_device.DeviceSpec]:
  """Returns a list of device specs of device_type attached to this client."""
  if device_type.upper() not in ["CPU", "GPU", "TPU"]:
    raise ValueError(f"Device type {device_type} is not CPU, GPU, or TPU.")
  if for_client_id is None:
    for_client_id = client_id()

  logical_devices = [
      tf_device.DeviceSpec.from_string(d.name)
      for d in tf_config.list_logical_devices(device_type)
  ]

  # Get the number of local devices.
  device_count = 0
  for d in logical_devices:
    # d might have a partial name, e.g. /device:TPU:0.
    if (d.job is None or d.job == job_name()) and (d.task is None or
                                                   d.task == for_client_id):
      device_count = device_count + 1

  # Return fully qualified device specs, sorted by increasing device index.
  return [
      tf_device.DeviceSpec(  # pylint: disable=g-complex-comprehension
          job=job_name(),
          replica=0,  # replica is deprecated and mostly hard-coded now.
          task=for_client_id,
          device_type=device_type,
          device_index=i) for i in range(device_count)
  ]


@tf_export("experimental.dtensor.num_local_devices", v1=[])
def num_local_devices(device_type: str) -> int:
  """Returns the number of devices of device_type attached to this client."""
  return len(local_devices(device_type))


@tf_export("experimental.dtensor.num_global_devices", v1=[])
def num_global_devices(device_type: str) -> int:
  """Returns the number of devices of device_type in this DTensor cluster."""
  return num_local_devices(device_type) * num_clients()


@tf_export("experimental.dtensor.job_name", v1=[])
def job_name() -> str:
  """Returns the job name used by all clients in this DTensor cluster."""
  # If missing, assume the program runs locally and use "localhost" as the job
  # name, per TensorFlow convention.
  return os.environ.get(_DT_JOB_NAME,
                        "localhost" if num_clients() == 1 else "worker")


@tf_export("experimental.dtensor.full_job_name", v1=[])
def full_job_name(task_id: Optional[int] = None) -> str:
  """Returns the fully qualified TF job name for this or another task."""
  # If task_id is None, use this client's ID, which is equal to its task ID.
  if task_id is None:
    task_id = client_id()
  # In local runs and unit tests, there should be exactly one client running
  # on one TF task.
  if num_clients() == 1 and task_id != 0:
    raise ValueError(f"Unexpected task ID {task_id} in local runs")
  return f"{job_name()}/replica:0/task:{task_id}"


def _bns_task_id(job: str) -> Union[int, str]:
  """Tries to extract an integer task ID from a job name.

  For example, for `job` = '/.../tpu_worker/0:port_name', return 0.

  Args:
    job: A job name to extract the task ID from.

  Returns:
    The task ID on success, or the original job name on failure.
  """
  maybe_task_id = job.rsplit("/")[-1].rsplit(":")[0]
  try:
    return int(maybe_task_id)
  except ValueError:
    return job


@tf_export("experimental.dtensor.jobs", v1=[])
def jobs() -> List[str]:
  """Returns a list of job names of all clients in this DTensor cluster."""
  d_jobs = os.environ.get(_DT_JOBS)
  if d_jobs is None:
    return []
  d_jobs_list = d_jobs.split(",")

  # Validate the ordering for BNS-style job names.
  # For the definition of BNS, refer to https://research.google/pubs/pub43438/.
  if any([name.startswith("/bns/") for name in d_jobs_list]):
    if d_jobs_list != sorted(d_jobs_list, key=_bns_task_id):
      raise ValueError(
          f"Unexpected DTENSOR_JOBS content {d_jobs}. Sort entries "
          "in DTENSOR_JOBS because cluster construction relies on "
          "the order.")

  return d_jobs_list


@tf_export("experimental.dtensor.heartbeat_enabled", v1=[])
def heartbeat_enabled() -> bool:
  """Returns true if the DTensor heartbeat service is enabled."""
  return os.environ.get(_DT_HEARTBEAT_ENABLED, "true").lower() in ("true", "1")
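

# Illustrative sketch of the environment variables read by the functions above
# for a hypothetical two-client cluster; they are typically set by the
# launching process before TensorFlow is initialized. Host names and ports
# below are made up for the example.
#
#   DTENSOR_NUM_CLIENTS=2
#   DTENSOR_CLIENT_ID=0                      # 1 on the second client.
#   DTENSOR_JOB_NAME=worker
#   DTENSOR_JOBS=host-a:9991,host-b:9991     # Same order on every client.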


# -----------------------------------------------------------------------------
# Private methods.


def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
  global _dtensor_singleton
  _dtensor_singleton = device


def _dtensor_device() -> dtensor_device.DTensorDevice:
  with _dtensor_singleton_lock:
    if _dtensor_singleton is None:
      _set_dtensor_device(dtensor_device.DTensorDevice(meshes=[]))
    return _dtensor_singleton


def _reset() -> None:
  global _dtensor_singleton
  if _dtensor_singleton is not None:
    _dtensor_singleton.clear_tpu_core_ids()
  with _dtensor_singleton_lock:
    _dtensor_singleton = None