/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_

#include <string>

#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
using stream_executor::port::StatusOr;

inline constexpr absl::string_view kTPUReplicatedHost = "TPU_REPLICATED_HOST";
inline constexpr absl::string_view kNumCoresPerReplicaAttr =
    "num_cores_per_replica";
inline constexpr absl::string_view kTopologyAttr = "topology";
inline constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment";

// A TPU device for execution alongside its associated host CPU device.
struct TPUDeviceAndHost {
  TPUDeviceAndHost() {}
  TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
      : device(device), host(host) {}

  std::string device;
  std::string host;
};

// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
// their associated host CPU devices (for outside compilation). They are
// ordered by `num_replicas` followed by `num_cores_per_replica`.
using TPUDevicesAndHosts =
    llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
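
// Illustrative indexing, following the ordering described above: for a
// `TPUDevicesAndHosts` value `tpu_devices`, `tpu_devices[replica][core]` is
// the TPU execution device and associated host CPU device for logical core
// `core` of replica `replica`.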

// TPU compilation device, execution and associated host devices, and
// optionally execution device IDs. Execution device IDs are populated if
// `topology` and `device_assignment` are provided.
struct TPUDeviceAssignment {
  TPUDeviceAssignment(llvm::StringRef compilation_device,
                      TPUDevicesAndHosts&& tpu_devices)
      : compilation_device(compilation_device),
        tpu_devices(std::move(tpu_devices)) {}

  TPUDeviceAssignment(llvm::StringRef compilation_device,
                      TPUDevicesAndHosts&& tpu_devices,
                      xla::DeviceAssignmentProto&& xla_device_assignment)
      : compilation_device(compilation_device),
        tpu_devices(std::move(tpu_devices)),
        xla_device_assignment(std::move(xla_device_assignment)) {}

  std::string compilation_device;
  TPUDevicesAndHosts tpu_devices;
  llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
};

// Extracts device coordinates from a device assignment attribute on an op.
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr);
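
// Example usage (a sketch; `op` is a hypothetical operation carrying the
// attribute, and error handling is elided):
//
//   auto coordinates = GetDeviceCoordinates(
//       op->getAttrOfType<mlir::ArrayAttr>("device_assignment"));
//   if (coordinates.ok()) {
//     // Flattened topology coordinates for each TPU device in the mesh.
//   }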

// Finds the TPU compilation device and execution devices from `devices` for a
// TPU computation subgraph. The compilation device is determined by looking up
// all TPU_SYSTEM:0 devices and choosing the CPU device associated with the
// first TPU_SYSTEM device, sorted lexicographically by replica and task. The
// execution devices are determined by looking up all TPU devices associated
// with each TPU_SYSTEM:0 device found, alongside the associated
// `topology_attr` and `device_assignment_attr`. If `topology_attr` is not an
// empty string (and is parsable as a TopologyProto), `device_assignment_attr`
// must not be empty either. When `topology_attr` and `device_assignment_attr`
// are both not empty, a general device assignment based on those two
// attributes is used. Otherwise, when both are empty, a full mesh device
// assignment is used instead. A failure is returned if a device assignment
// cannot be determined (e.g. invalid devices or invalid parameters).
//
//
// For example, for `devices`:
//   {
//     /job:localhost/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:0/device:TPU:0,
//     /job:worker/replica:0/task:0/device:TPU:1,
//     /job:worker/replica:0/task:0/device:TPU:2,
//     /job:worker/replica:0/task:0/device:TPU:3,
//     /job:worker/replica:0/task:1/device:CPU:0,
//     /job:worker/replica:0/task:1/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:1/device:TPU:0,
//     /job:worker/replica:0/task:1/device:TPU:1,
//     /job:worker/replica:0/task:1/device:TPU:2,
//     /job:worker/replica:0/task:1/device:TPU:3
//   }
//
//
// With the following parameters (full mesh device assignment):
//   `num_replicas` = 8
//   `num_cores_per_replica` = 1
//   `topology_attr` = ""
//   `device_assignment_attr` = {}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       /job:worker/replica:0/task:0/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:3
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:3
//     }
//   }
//
// and `xla_device_assignment` will not be set.
//
//
// With the following parameters (general device assignment):
//   `num_replicas` = 4
//   `num_cores_per_replica` = 2
//   `topology_attr` (in proto debug string format) =
//     {
//       mesh_shape: 2
//       mesh_shape: 2
//       mesh_shape: 2
//       num_tasks: 2
//       num_tpu_devices_per_task: 4
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//     }
//   `device_assignment_attr` =
//     {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       "/job:worker/replica:0/task:0/device:TPU:0",
//       "/job:worker/replica:0/task:1/device:TPU:3"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:1",
//       "/job:worker/replica:0/task:1/device:TPU:2"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:3",
//       "/job:worker/replica:0/task:1/device:TPU:0"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:2",
//       "/job:worker/replica:0/task:1/device:TPU:1"
//     }
//   }
//
// and `xla_device_assignment` will be:
//   {
//     replica_count: 4
//     computation_count: 2
//     computation_devices {
//       replica_device_ids: 0
//       replica_device_ids: 4
//       replica_device_ids: 2
//       replica_device_ids: 6
//     }
//     computation_devices {
//       replica_device_ids: 1
//       replica_device_ids: 5
//       replica_device_ids: 3
//       replica_device_ids: 7
//     }
//   }
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    llvm::ArrayRef<DeviceNameUtils::ParsedName> devices, int num_replicas,
    int num_cores_per_replica, llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr);
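
// A minimal calling sketch for the full mesh case above (illustrative;
// assumes `devices` has already been collected from the module, and error
// handling is elided):
//
//   auto status_or = GetTPUCompilationAndExecutionDevices(
//       devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1,
//       /*topology_attr=*/"", /*device_assignment_attr=*/{});
//   if (status_or.ok()) {
//     TPUDeviceAssignment assignment = status_or.ConsumeValueOrDie();
//     // assignment.tpu_devices[replica][core], as described above.
//   }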

// A virtual device is used for device assignment when executing ops on a
// specified logical core.
std::string GetDeviceAliasForLogicalCore(int core_index);
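
// For example, `GetDeviceAliasForLogicalCore(0)` yields an alias of the form
// "TPU_REPLICATED_CORE_0" (illustrative; the exact alias string is defined in
// the implementation).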

// Returns true if the cluster contains model parallelism, based on its
// `num_cores_per_replica` attribute. Otherwise returns false.
bool HasModelParallelism(mlir::tf_device::ClusterOp cluster);

// Parses TPU compilation and execution devices from a TPU cluster and returns
// the host device for the head and tail computations. If the TPU computation
// is replicated, kTPUReplicatedHost is returned instead.
mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device);
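
// A usage sketch (illustrative; `cluster` is a `tf_device.cluster` op and
// `devices` holds the module's runtime devices; error handling is elided):
//
//   std::string host_device;
//   if (mlir::failed(
//           GetHostDeviceOutsideComputation(devices, cluster, &host_device)))
//     return mlir::failure();
//   // `host_device` is now a concrete CPU device string, or
//   // kTPUReplicatedHost if the computation is replicated.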

// Checks if a device string is a TPU device.
bool IsTPUDevice(llvm::StringRef device);

// Checks if a device string is a TPU replicated core device.
bool IsTPUReplicatedCore(llvm::StringRef device);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_