• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/collectives_common.h"
17 
18 #include <string>
19 
20 namespace tensorflow {
21 namespace dtensor {
22 
23 // A map from a unique set of kept mesh dimension values (a partition) to
24 // IDs of devices in that partition.
25 //
26 // Users will typically ignore the key, but use the map values as the group
27 // assignment for collective operations. This is intentionally a
28 // std::map instead of absl::flat_hash_map to guarantee all hosts in
29 // a multi-host cluster will generate the same grouping, and therefore the same
30 // XLA program fingerprint, independently. std::map guarantees the same
31 // iteration order.
32 using AllReducePartitions = std::map<DeviceLocation, std::vector<int32>>;
33 
34 // Computes AllReduce partitions using reduced mesh dimension names.
35 //
36 // Reduction groups are formed across all _non_-reduced dimensions. For example,
37 // in the following scenario:
38 //
39 // output_layout.dims() = [a, b]
40 // output_layout.mesh() = [(x, 8), (y, 4)]
41 // reduced_dims = `x`
42 //
43 // We first reduce over `a` locally on each device, producing 32 local
44 // reductions. We then AllReduce within each of the 4 partitions. Each partition
45 // corresponds to one unique value of `y` and has 8 devices. The end result is
46 // sharded over the y mesh dimension and replicated 8 times.
47 //
48 // The returned map should have four entries with key values from [0] to [3]
49 // (unique values of `y`). Each key maps to IDs of devices with that `y` value.
GetAllReducePartitionsFromReducedDims(const dtensor::Layout & output_layout,const absl::flat_hash_set<std::string> & reduced_dims)50 StatusOr<AllReducePartitions> GetAllReducePartitionsFromReducedDims(
51     const dtensor::Layout& output_layout,
52     const absl::flat_hash_set<std::string>& reduced_dims) {
53   AllReducePartitions partitions;
54   for (int64 device = 0; device < output_layout.num_devices(); ++device) {
55     TF_ASSIGN_OR_RETURN(const DeviceLocation device_loc,
56                         output_layout.device_location(device));
57     DeviceLocation kept_dims;
58     for (int64 dim_idx = 0; dim_idx < device_loc.size(); ++dim_idx) {
59       if (!reduced_dims.contains(output_layout.mesh().dim_name(dim_idx))) {
60         kept_dims.push_back(device_loc[dim_idx]);
61       }
62     }
63     partitions[kept_dims].push_back(device);
64   }
65   return partitions;
66 }
67 
68 // Use the first device in the mesh to extract the device name. For example:
69 //
70 // device_path = "/job:localhost/replica:0/task:0/device:TPU:0"
71 // device_type = "/job:localhost/replica:0/task:0/device:TPU"
72 // device_id = 0
73 //
74 // The device ID can be obtained through DeviceId as a runtime input. We may
75 // need it in the future to enable device ID-based branch divergence.
DeviceTypeFromMesh(const Mesh & mesh)76 StatusOr<std::string> DeviceTypeFromMesh(const Mesh& mesh) {
77   std::string device_path =
78       mesh.is_remote() ? mesh.global_devices()[0] : mesh.local_devices()[0];
79   size_t device_path_pos = device_path.find_last_of(':');
80   if (device_path_pos == std::string::npos) {
81     return errors::InvalidArgument("Unexpected device path: ", device_path);
82   }
83   return device_path.substr(0, device_path_pos);
84 }
85 
86 }  // namespace dtensor
87 }  // namespace tensorflow
88