# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core DTensor Python API."""

import contextlib
import os
import threading
from typing import Any, Callable, List, Optional, Sequence, Union

from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

_DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
_DT_JOB_NAME = "DTENSOR_JOB_NAME"
_DT_JOBS = "DTENSOR_JOBS"
_DT_HEARTBEAT_ENABLED = "DTENSOR_ENABLE_HEARTBEAT"

_dtensor_singleton = None
_dtensor_singleton_lock = threading.Lock()

# -----------------------------------------------------------------------------
# Main methods to launch DTensor computations.


@tf_export("experimental.dtensor.call_with_layout", v1=[])
def call_with_layout(fn: Callable[...,
                                  Any], layout: Optional[layout_lib.Layout],
                     *args, **kwargs) -> Any:
  """Calls a function in the DTensor device scope if `layout` is not None.

  If `layout` is not None, `fn` consumes DTensor(s) as input and produces a
  DTensor as output; a DTensor is a tf.Tensor with layout-related attributes.

  If `layout` is None, `fn` consumes and produces regular tf.Tensors.

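  For example, a minimal sketch of eager usage (this assumes a `Mesh` named
  `mesh` with a mesh dimension "x" has already been created elsewhere, e.g.
  with `dtensor.create_mesh`):

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # A rank-2 layout sharded over "x" on the first axis.
  layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)
  # tf.zeros takes no DTensor inputs; its result is produced with `layout`.
  zeros = dtensor.call_with_layout(tf.zeros, layout, shape=[4, 8])
  ```
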
  Args:
    fn: A supported TF API function such as tf.zeros.
    layout: Optional, the layout of the output DTensor.
    *args:  Arguments given to `fn`.
    **kwargs: Keyword arguments given to `fn`.

  Returns:
    The return value of `fn` transformed to a DTensor if requested.
  """
  if layout is not None:
    if not context.executing_eagerly():
      # This is a workaround for b/199324097, where functions such as tf.ones
      # could attach an incorrect layout to the tf.const generated under the
      # hood. The op runs successfully in eager mode, but in graph mode, MLIR
      # passes sometimes attach the default layout to a scalar constant.
      # %cst = tf.Const([1])  -- With the given layout
      # %0 = "tf.DTensorLayout"(%cst). -- Fails in MLIR pass since shape for
      #                                -- layout could be different than
      #                                -- shape[0] for %cst.
      # %1 = tf.Fill(%0, 1)
      result = fn(*args, **kwargs)
      return relayout(result, layout)
    else:
      with run_on(layout.mesh):
        with _dtensor_device()._default_layout(layout):  # pylint: disable=protected-access
          return fn(*args, **kwargs)
  return fn(*args, **kwargs)


@tf_export("experimental.dtensor.run_on", v1=[])
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
  """Runs enclosed functions in the DTensor device scope.

  This function returns a scope. All the ops and tf.functions in this scope
  will run on the DTensor device using the mesh provided. This is useful for
  wrapping any tf.function that does not take a DTensor as input but needs to
  produce a DTensor as its result. The scope also ensures that all small
  constants are replicated as DTensors.

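  For example, a minimal sketch (this assumes a `Mesh` named `mesh` has
  already been created elsewhere, e.g. with `dtensor.create_mesh`):

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  with dtensor.run_on(mesh):
    # tf.ones takes no DTensor input; inside the scope its output is a
    # DTensor replicated over `mesh`.
    ones = tf.ones([4, 4])
  ```
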
  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the DTensor device.
  """
  if not isinstance(mesh, layout_lib.Mesh):
    raise ValueError(f"Expect `mesh` to be `Mesh`, got {type(mesh)}")

  with _dtensor_device()._experimental_default_mesh(mesh):  # pylint: disable=protected-access
    with ops.device(device_name()):
      yield


@tf_export("experimental.dtensor.device_name", v1=[])
def device_name() -> str:
  """Returns the singleton DTensor device's name.

  This function can be used in the following way:

  ```python
  import tensorflow as tf

  with tf.device(dtensor.device_name()):
    # ...
  ```
  """
  return _dtensor_device().name


# -----------------------------------------------------------------------------
# Data transfer methods.


@tf_export("experimental.dtensor.copy_to_mesh", v1=[])
def copy_to_mesh(
    tensor: Any,
    layout: layout_lib.Layout,
    source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor:
  """Copies a tf.Tensor onto the DTensor device with the given layout.

  Copies a regular tf.Tensor onto the DTensor device. The mesh attached to
  `layout` is used as the target mesh. This method currently only supports
  replicated layouts. To get a DTensor with a sharded layout, use the `pack`
  method.

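  For example, a minimal sketch (this assumes a `Mesh` named `mesh` has
  already been created elsewhere):

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  t = tf.constant([1.0, 2.0, 3.0, 4.0])
  # A rank-1 replicated layout: every device in `mesh` holds the full tensor.
  layout = dtensor.Layout.replicated(mesh, rank=1)
  d = dtensor.copy_to_mesh(t, layout)
  ```
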
  Args:
    tensor: A regular tf.Tensor to be copied as a DTensor.
    layout: Target layout (and mesh) for the result DTensor.
    source_layout: Source layout of the tensor before copy, used for backward
      passes.

  Returns:
    A DTensor on the DTensor device with the given layout.
  """
  return _dtensor_device().copy_to_mesh(tensor, layout, source_layout)


@tf_export("experimental.dtensor.pack", v1=[])
def pack(tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
  """Packs `tf.Tensor` components into a DTensor.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed on
     each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the content
     value and the layout.

  **Shape, Rank, and Scalars**: The rank of the DTensor is the same as the
  rank of its raw components, i.e., rank is preserved.  This leads to a
  consistent interpretation for packing scalar values into a DTensor. The only
  valid layout for a scalar value is fully replicated, and the individual
  components must be identical scalars.

  Each input `tensors[i]` will be copied to `layout.mesh.local_device[i]`
  if not already on the local device. Non-local components should not be passed
  to `pack`; use `copy_to_mesh` and `relayout` to place tensors on all global
  devices on a mesh.

  It is the caller's responsibility to ensure that the underlying values
  for `pack` adhere to the specified layout, and that only as many values are
  specified as there are local devices. Pack does not move data between clients.
  See examples below for more detail about layouts.

  For example, assume we have a mesh `[X(2), Y(3)]`, which has 6 underlying
  devices in total. Furthermore, assume that the device location mapping is
  the following:

  ```
  device_ID  |  location X, Y
          0     0, 0
          1     0, 1
          2     0, 2
          3     1, 0
          4     1, 1
          5     1, 2
  ```

  1. For a 1-D DTensor with shape `[128]`, layout `[mesh.X]`, and value
     `range(128)`, the raw components will each have shape `[64]`, and
     will be:

     ```
     device_ID  |  raw component
             0     range(0, 64)
             1     range(0, 64)
             2     range(0, 64)
             3     range(64, 128)
             4     range(64, 128)
             5     range(64, 128)
     ```

     This also means that for a 1-D DTensor with shape `[2]` and layout
     `[mesh.X]`, the raw components have shape `[1]` rather than the scalar
     shape `[]`.

  2. For a 2-D DTensor with shape `[2, 3]`, layout `[mesh.X, mesh.Y]`, and
     value `range(6)`, this is a fully sharded DTensor.

     From the global view, the content looks like:
     ```
     [
       [0.0, 1.0, 2.0],
       [3.0, 4.0, 5.0],
     ]
     ```

     The raw components will have shape `[1, 1]` each, and have the following
     content:

     ```
     device_ID  |  raw component
             0     [[0.0]]
             1     [[1.0]]
             2     [[2.0]]
             3     [[3.0]]
             4     [[4.0]]
             5     [[5.0]]
     ```

  3. A scalar DTensor with value `123.0` can only have one legitimate layout,
     `[]` (no dimensions, fully replicated).

     The raw components will have shape `[]` each, and have the following
     content:

     ```
     device_ID  |  raw component
             0     123.0
             1     123.0
             2     123.0
             3     123.0
             4     123.0
             5     123.0
     ```

     Again, the caller of `pack` is expected to provide 6 identical,
     scalar-shaped raw components.

  4. For a 3-D DTensor with shape `[2, 2, 3]`, layout
     `[X, unsharded, unsharded]`, and value `range(12)`:

     From the global view, the content looks like:
     ```
     [
       [
         [0.0, 1.0, 2.0],
         [3.0, 4.0, 5.0],
       ],
       [
         [6.0, 7.0, 8.0],
         [9.0, 10., 11.],
       ],
     ]
     ```

     The raw components will have shape `[1, 2, 3]` each, and have the following
     content:

     ```
     device_ID  |  raw component
             0     range(6).reshape([1, 2, 3])
             1     range(6).reshape([1, 2, 3])
             2     range(6).reshape([1, 2, 3])
             3     range(6, 12).reshape([1, 2, 3])
             4     range(6, 12).reshape([1, 2, 3])
             5     range(6, 12).reshape([1, 2, 3])
     ```

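  As a minimal sketch of the API (this assumes a single-client, 6-device mesh
  `mesh` with dimensions "x" and "y", matching the mesh described above):

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # One raw component of shape [1, 1] per local device, as in example 2.
  components = [tf.constant([[float(i)]]) for i in range(6)]
  d = dtensor.pack(components, dtensor.Layout(["x", "y"], mesh))
  ```
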
  Args:
    tensors: The list of local tensor components to pack into a DTensor.
    layout: The layout of the DTensor to be created.

  Returns:
    A DTensor created from the individual component tensors.

  Raises:
    RuntimeError: When `pack` is not called eagerly.
  """
  return _dtensor_device().pack(tensors, layout)


@tf_export("experimental.dtensor.unpack", v1=[])
def unpack(tensor: Any) -> Sequence[Any]:
  """Unpacks a DTensor into `tf.Tensor` components.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed on
     each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the content
     value and the layout.

  See the documentation for `pack` for more information about how packing and
  unpacking work.

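  For example, a minimal sketch (this assumes `d` is a DTensor created on a
  single-client mesh, e.g. with `dtensor.pack`):

  ```python
  components = dtensor.unpack(d)
  # One regular tf.Tensor per local device, in local device order.
  for component in components:
    print(component.device, component.numpy())
  ```
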
  Args:
    tensor: The DTensor to unpack.

  Returns:
    The individual component tensors of the DTensor. This will include only the
    client-local components, i.e. the components placed on the local devices.

  Raises:
    RuntimeError: When `unpack` is not called eagerly.
  """
  return _dtensor_device().unpack(tensor)


# -----------------------------------------------------------------------------
# Layout-related methods.


@tf_export("experimental.dtensor.fetch_layout", v1=[])
def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout:
  """Fetches the layout of a DTensor.

  Args:
    tensor: The DTensor whose layout is to be fetched.

  Returns:
    The `Layout` of this DTensor.

  Raises:
    RuntimeError: When not called eagerly.
  """
  return _dtensor_device().fetch_layout(tensor)


@tf_export("experimental.dtensor.check_layout", v1=[])
def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None:
  """Asserts that the layout of the DTensor is `layout`.

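  For example, a minimal sketch (this assumes `d` is a DTensor created with
  layout `layout`):

  ```python
  dtensor.check_layout(d, layout)  # Returns silently; raises on mismatch.
  ```
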
  Args:
    tensor: A DTensor whose layout is to be checked.
    layout: The `Layout` to compare against.

  Raises:
    ValueError: If the layout of `tensor` does not match the supplied `layout`.
  """
  if fetch_layout(tensor) != layout:
    raise ValueError("Layout of tensor: " + str(fetch_layout(tensor)) +
                     ", did not match expected layout: " + str(layout))


@tf_export("experimental.dtensor.relayout", v1=[])
def relayout(tensor: ops.Tensor, layout: layout_lib.Layout) -> ops.Tensor:
  """Changes the layout of `tensor`.

  Changes the layout of `tensor` to `layout`. This is used to fine-tune the
  behavior of ops following/connected to `tensor`, such as choosing one SPMD
  expansion pattern over another. This works by forward propagating `layout`
  to connected TensorFlow computation graphs during layout propagation.

  Currently, only converting layouts from replicated to sharded or sharded to
  replicated per mesh dimension is supported. That is, "x, y" -> "unsharded, y"
  is supported, while "x, y" -> "z, y" is not supported.

  We also support a special "match" sharding spec, which instructs the relayout
  to act as an identity operation with respect to any sharding on these
  mesh dimensions.

  Relayout is internally lowered to a set of Split and/or AllToAll ops. When
  tensor layouts are converted from replicated to sharded, the cost is
  comparatively low because we only insert Split ops and no cross-device
  communication is needed. However, when tensor layouts are converted from
  sharded to replicated, cross-device communication may occur, causing potential
  performance impact.

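  For example, a minimal sketch (this assumes a 2-D `Mesh` named `mesh` with
  dimensions "x" and "y", and a rank-2 DTensor `d` currently sharded "x, y"):

  ```python
  # Unshard the first axis while keeping the second axis sharded over "y".
  new_layout = dtensor.Layout([dtensor.UNSHARDED, "y"], mesh)
  d_unsharded = dtensor.relayout(d, new_layout)
  ```
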
  Args:
    tensor: A DTensor to specify a new layout for.
    layout: A Layout object specifying a new sharding spec.

  Returns:
    A DTensor output from the Relayout op.
  """
  layout_str = layout.to_string()
  return gen_dtensor_ops.relayout(tensor, layout_str)


# -----------------------------------------------------------------------------
# Distributed training-related methods.
#
# Most users should use DTensor utility methods to create a mesh. The methods
# here are only for advanced users who want to fully customize their meshes.
# Note that local_devices and num_local_devices reflect the devices actually
# attached to this client, while the other values are set through environment
# variables.
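#
# For example (a sketch with hypothetical host names and ports), a two-client
# run could be configured by setting the following before the program starts,
# with the second client using DTENSOR_CLIENT_ID=1:
#
#   DTENSOR_CLIENT_ID=0
#   DTENSOR_NUM_CLIENTS=2
#   DTENSOR_JOB_NAME=worker
#   DTENSOR_JOBS=host1:port1,host2:port2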


@tf_export("experimental.dtensor.client_id", v1=[])
def client_id() -> int:
  """Returns this client's ID."""
  # If missing, assume running with a single client with client_id of 0.
  client_id_value = int(os.environ.get(_DT_CLIENT_ID, "0"))
  if client_id_value < 0:
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be >= 0, got {client_id_value}.")
  if client_id_value >= num_clients():
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be < {num_clients()}, got {client_id_value}")
  return client_id_value


@tf_export("experimental.dtensor.num_clients", v1=[])
def num_clients() -> int:
  """Returns the number of clients in this DTensor cluster."""
  # If missing, assume running with a single client with num_clients of 1.
  num_clients_value = int(os.environ.get(_DT_NUM_CLIENTS, "1"))
  if num_clients_value <= 0:
    raise ValueError(f"Environment variable {_DT_NUM_CLIENTS} "
                     f"must be > 0, got {num_clients_value}.")

  return num_clients_value


@tf_export("experimental.dtensor.local_devices", v1=[])
def local_devices(
    device_type: str,
    for_client_id: Optional[int] = None) -> List[tf_device.DeviceSpec]:
  """Returns a list of device specs of device_type attached to this client."""
  if device_type.upper() not in ["CPU", "GPU", "TPU"]:
    raise ValueError(f"Device type {device_type} is not CPU, GPU, or TPU.")
  if for_client_id is None:
    for_client_id = client_id()

  logical_devices = [
      tf_device.DeviceSpec.from_string(d.name)
      for d in tf_config.list_logical_devices(device_type)
  ]

  # Get the number of local devices.
  device_count = 0
  for d in logical_devices:
    # d might have a partial name, e.g. /device:TPU:0.
    if (d.job is None or d.job == job_name()) and (d.task is None or
                                                   d.task == for_client_id):
      device_count = device_count + 1

  # Return fully qualified device specs, sorted by increasing device index.
  return [
      tf_device.DeviceSpec(  # pylint: disable=g-complex-comprehension
          job=job_name(),
          replica=0,  # replica is deprecated and mostly hard-coded now.
          task=for_client_id,
          device_type=device_type,
          device_index=i) for i in range(device_count)
  ]


@tf_export("experimental.dtensor.num_local_devices", v1=[])
def num_local_devices(device_type: str) -> int:
  """Returns the number of devices of device_type attached to this client."""
  return len(local_devices(device_type))


@tf_export("experimental.dtensor.num_global_devices", v1=[])
def num_global_devices(device_type: str) -> int:
  """Returns the number of devices of device_type in this DTensor cluster."""
  return num_local_devices(device_type) * num_clients()


@tf_export("experimental.dtensor.job_name", v1=[])
def job_name() -> str:
  """Returns the job name used by all clients in this DTensor cluster."""
  # If missing, assume the program runs locally and use localhost as the job
  # name, per TensorFlow convention.
  return os.environ.get(_DT_JOB_NAME,
                        "localhost" if num_clients() == 1 else "worker")


@tf_export("experimental.dtensor.full_job_name", v1=[])
def full_job_name(task_id: Optional[int] = None) -> str:
  """Returns the fully qualified TF job name for this or another task."""
  # If task_id is None, use this client's ID, which is equal to its task ID.
  if task_id is None:
    task_id = client_id()
  # In local runs and unit tests, there should be exactly one client running
  # on one TF task.
  if num_clients() == 1 and task_id != 0:
    raise ValueError(f"Unexpected task ID {task_id} in local runs")
  return f"{job_name()}/replica:0/task:{task_id}"


def _bns_task_id(job: str) -> Union[int, str]:
  """Tries to extract an integer task ID from a job name.

  For example, for `job` = '/.../tpu_worker/0:port_name', return 0.

  Args:
    job: A job name to extract task ID from.

  Returns:
    The task ID on success, or the original job name on failure.
  """
  maybe_task_id = job.rsplit("/")[-1].rsplit(":")[0]
  try:
    return int(maybe_task_id)
  except ValueError:
    return job


@tf_export("experimental.dtensor.jobs", v1=[])
def jobs() -> List[str]:
  """Returns a list of job names of all clients in this DTensor cluster."""
  d_jobs = os.environ.get(_DT_JOBS)
  if d_jobs is None:
    return []
  d_jobs_list = d_jobs.split(",")

  # Validate ordering for BNS style job names.
  # For definition of BNS, refer to https://research.google/pubs/pub43438/.
  if any([name.startswith("/bns/") for name in d_jobs_list]):
    if d_jobs_list != sorted(d_jobs_list, key=_bns_task_id):
      raise ValueError(
          f"Unexpected DTENSOR_JOBS content {d_jobs}. Sort entries "
          "in DTENSOR_JOBS because cluster construction relies on "
          "the order.")

  return d_jobs_list


@tf_export("experimental.dtensor.heartbeat_enabled", v1=[])
def heartbeat_enabled() -> bool:
  """Returns true if DTensor heartbeat service is enabled."""
  return os.environ.get(_DT_HEARTBEAT_ENABLED, "true").lower() in ("true", "1")


# -----------------------------------------------------------------------------
# Private methods.


def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
  global _dtensor_singleton
  _dtensor_singleton = device


def _dtensor_device() -> dtensor_device.DTensorDevice:
  with _dtensor_singleton_lock:
    if _dtensor_singleton is None:
      _set_dtensor_device(dtensor_device.DTensorDevice(meshes=[]))
  return _dtensor_singleton


def _reset() -> None:
  global _dtensor_singleton
  if _dtensor_singleton is not None:
    _dtensor_singleton.clear_tpu_core_ids()
  with _dtensor_singleton_lock:
    _dtensor_singleton = None