/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST";
const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica";
const char* const kTopologyAttr = "topology";
const char* const kDeviceAssignmentAttr = "device_assignment";

// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;

constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
constexpr char kDeviceTPU[] = "TPU";
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
constexpr char kBadIntArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not an int";

using Device = DeviceNameUtils::ParsedName;
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;

namespace {
// Finds matching devices in `devices` based on pattern `spec`.
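// For example (hypothetical name): a spec with only `type = "TPU_SYSTEM"` and
// `id = 0` set matches a fully specified device such as
// "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", while leaving devices of
// other types or ids unmatched.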
void FindMatchingDevices(Devices devices, const Device& spec,
                         llvm::SmallVectorImpl<Device>* matched_devices) {
  for (const auto& device : devices)
    if (DeviceNameUtils::IsCompleteSpecification(spec, device))
      matched_devices->push_back(device);
}

// Creates error message for a conflicting attribute of a device.
template <typename T>
Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) {
  return errors::InvalidArgument("found ", kDeviceTPUSystem,
                                 " devices with conflicting ", attribute, "s '",
                                 a, "' and '", b, "'");
}

// Finds TPU_SYSTEM:0 devices in `devices` and returns them sorted by task so
// the result is deterministic. If no TPU_SYSTEM device is found, or if there
// are multiple TPU_SYSTEM devices with different jobs or replicas, a failure
// is returned.
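// For example (hypothetical names), given TPU_SYSTEM devices on task 1 and
// task 0 of the same job and replica, the result is sorted by task:
//   /job:worker/replica:0/task:0/device:TPU_SYSTEM:0
//   /job:worker/replica:0/task:1/device:TPU_SYSTEM:0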
Status GetTPUSystemDevices(Devices devices,
                           llvm::SmallVectorImpl<Device>* matched_devices) {
  Device spec;
  spec.type = kDeviceTPUSystem;
  spec.has_type = true;
  spec.id = 0;
  spec.has_id = true;

  llvm::SmallVector<Device, 8> system_devices;
  FindMatchingDevices(devices, spec, &system_devices);
  if (system_devices.empty())
    return errors::InvalidArgument("no ", kDeviceTPUSystem, " devices found");

  // Check that all system devices are part of the same job and replica.
  const auto& job = system_devices[0].job;
  auto replica = system_devices[0].replica;
  for (const auto& device : llvm::make_range(std::next(system_devices.begin()),
                                             system_devices.end())) {
    if (device.job != job)
      return MismatchedTPUSystemAttributeErr("job", job, device.job);

    if (device.replica != replica)
      return MismatchedTPUSystemAttributeErr("replica", replica,
                                             device.replica);
  }

  // Sort by task to be deterministic.
  std::sort(system_devices.begin(), system_devices.end(),
            [](const Device& a, const Device& b) { return a.task < b.task; });

  matched_devices->swap(system_devices);

  return Status::OK();
}

// Finds TPU devices associated with each system device, based on spec (e.g.
// from GetTPUSystemDevices). If the number of TPU devices per host does not
// match across hosts, a failure is returned.
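// For example (hypothetical layout): with system devices on tasks 0 and 1 and
// two TPUs per host, `tpu_devices` becomes
//   [[task 0 TPU:0, task 0 TPU:1], [task 1 TPU:0, task 1 TPU:1]],
// with each host's TPUs sorted by id.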
Status GetTPUDevices(
    Devices devices, llvm::ArrayRef<Device> system_devices,
    llvm::SmallVectorImpl<llvm::SmallVector<Device, 8>>* tpu_devices) {
  tpu_devices->reserve(system_devices.size());

  auto lookup = [&devices](Device device_spec) {
    device_spec.has_type = true;
    device_spec.type = kDeviceTPU;
    // Enumerate all the available TPUs.
    device_spec.has_id = false;

    llvm::SmallVector<Device, 8> host_tpu_devices;
    FindMatchingDevices(devices, device_spec, &host_tpu_devices);

    // Sort devices by id.
    std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
              [](const Device& i, const Device& j) { return i.id < j.id; });
    return host_tpu_devices;
  };

  int num_tpus_per_host = 0;
  {
    const auto& device = system_devices[0];
    auto host_tpu_devices = lookup(device);
    num_tpus_per_host = host_tpu_devices.size();
    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  for (const auto& device_spec : llvm::make_range(
           std::next(system_devices.begin()), system_devices.end())) {
    auto host_tpu_devices = lookup(device_spec);
    // Check that the number of TPU devices on this host matches the first
    // host.
    const int64 host_tpu_devices_size = host_tpu_devices.size();
    if (num_tpus_per_host != host_tpu_devices_size)
      return errors::InvalidArgument(
          "expected the number of TPU devices per host to be ",
          num_tpus_per_host, ", got ", host_tpu_devices.size());

    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  return Status::OK();
}

// Finds the compilation device from the system device.
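// For example (hypothetical name): the system device
// "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0" maps to the compilation
// device "/job:worker/replica:0/task:0/device:CPU:0".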
std::string GetTPUCompilationDevice(Device system_device) {
  // TODO(b/110910013) GetTPUSystemDevices parses the spec and returns the
  // TPU_SYSTEM device, which we replace with the CPU device. We do this
  // replacement because we want to place the `tf._TPUCompileMlir` op
  // explicitly on CPU devices of the same job as the TPU_SYSTEM device.
  system_device.type = tensorflow::DEVICE_CPU;
  return DeviceNameUtils::ParsedNameToString(system_device);
}

// Finds the host CPU device for a given TPU device.
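// For example (hypothetical name): "/job:worker/replica:0/task:0/device:TPU:3"
// maps to the host device "/job:worker/replica:0/task:0/device:CPU:0".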
std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
  tpu_device.type = DEVICE_CPU;
  tpu_device.id = 0;
  return DeviceNameUtils::ParsedNameToString(tpu_device);
}

// Determines execution devices when topology and device assignment are not
// defined. This is a special case where a single-core computation is
// replicated to every core in the mesh. TPU devices are simply added to
// `execution_devices` of one replica. `num_replicas` must be 1 or the total
// number of TPU devices available, and `num_cores_per_replica` must be 1.
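// For example (hypothetical layout): with 2 tasks and 2 TPUs per task,
// num_replicas = 4 places replica i on task i / 2, device i % 2.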
StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();
  const int num_tpu_devices = num_tasks * num_tpus_per_task;

  if (num_replicas != 1 && num_replicas != num_tpu_devices)
    return errors::InvalidArgument("'num_replicas' must be equal to 1 or ",
                                   num_tpu_devices, ", got ", num_replicas);

  if (num_cores_per_replica != 1)
    return errors::InvalidArgument(
        "'num_cores_per_replica' must be equal to 1, got ",
        num_cores_per_replica);

  TPUDevicesAndHosts devices_and_hosts;
  devices_and_hosts.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    const int task = i / num_tpus_per_task;
    const int device = i % num_tpus_per_task;
    const auto& tpu_device = tpu_devices[task][device];
    devices_and_hosts.push_back({TPUDeviceAndHost(
        /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
        /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
  }

  return devices_and_hosts;
}

// Helper struct for keeping track of task and device for an associated TPU
// device coordinate.
struct TaskAndDevice {
  TaskAndDevice() {}
  TaskAndDevice(int task, int device) : task(task), device(device) {}

  int task = -1;
  int device = -1;
};

// Checks if device coordinate is outside of topology mesh shape bounds.
bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x,
                                int bound_y, int bound_z, int bound_core) {
  return x < 0 || x >= bound_x || y < 0 || y >= bound_y || z < 0 ||
         z >= bound_z || core < 0 || core >= bound_core;
}

// Creates error message for an out of bound device coordinate.
Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                int z, int core, int bound_x, int bound_y,
                                int bound_z, int bound_core) {
  return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", z,
                                 ", ", core, ") in '", attribute,
                                 "' is outside of mesh shape (", bound_x, ", ",
                                 bound_y, ", ", bound_z, ", ", bound_core, ")");
}

// Creates error message for a duplicate device coordinate.
Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                   int z, int core) {
  return errors::InvalidArgument("'", attribute,
                                 "' has duplicate device coordinate (", x, ", ",
                                 y, ", ", z, ", ", core, ")");
}

// Parses and validates topology (serialized string of TopologyProto), and maps
// device coordinate (x, y, z, core) to task and device (of available TPUs).
// Topology attribute device coordinates are ordered by task then device (major
// to minor).
//
// A valid TopologyProto must have:
//  - a valid mesh shape (rank 4 with positive dimensions)
//  - `num_tasks` and `num_tpu_devices_per_task` that match the number of
//    available TPU hosts and devices per host
//  - device coordinates within the mesh shape
//  - no duplicate device coordinates
//  - a number of device coordinates (in (x, y, z, core) 4-tuples) that matches
//    the number of available TPUs (see the example below)
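// For example (hypothetical values): with 1 task and 2 TPUs per task, a
// device_coordinates list of [0, 0, 0, 0, 0, 0, 0, 1] maps (task 0, device 0)
// to coordinate (0, 0, 0, 0) and (task 0, device 1) to (0, 0, 0, 1).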
StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
    llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) {
  tpu::TopologyProto topology_proto;
  if (!topology_proto.ParseFromString(topology_attr.str()))
    return errors::InvalidArgument("failed to parse '", kTopologyAttr,
                                   "' attribute to TopologyProto");

  if (topology_proto.mesh_shape_size() != kTPUTopologyRank)
    return errors::InvalidArgument(
        "'", kTopologyAttr, "' 'mesh_shape' must be rank ", kTPUTopologyRank,
        ", got rank ", topology_proto.mesh_shape_size());

  for (auto mesh_shape_dim : llvm::enumerate(topology_proto.mesh_shape()))
    if (mesh_shape_dim.value() <= 0)
      return errors::InvalidArgument(
          "'", kTopologyAttr, "' 'mesh_shape' dimension ",
          mesh_shape_dim.index(), " must be positive, got ",
          mesh_shape_dim.value());

  if (topology_proto.num_tasks() != num_tasks)
    return errors::InvalidArgument(
        "number of tasks from available TPU devices must be 'num_tasks' in '",
        kTopologyAttr, "' (", topology_proto.num_tasks(), "), got ", num_tasks);

  if (topology_proto.num_tpu_devices_per_task() != num_tpus_per_task)
    return errors::InvalidArgument(
        "number of TPU devices available per task must be "
        "'num_tpu_devices_per_task' in '",
        kTopologyAttr, "' (", topology_proto.num_tpu_devices_per_task(),
        "), got ", num_tpus_per_task);

  const int expected_device_coordinates_size =
      num_tasks * num_tpus_per_task * kTPUTopologyRank;
  if (topology_proto.device_coordinates_size() !=
      expected_device_coordinates_size)
    return errors::InvalidArgument(
        "length of 'device_coordinates' in '", kTopologyAttr,
        "' must be 'num_tasks' * 'num_tpus_per_task' * ", kTPUTopologyRank,
        " (", num_tasks, " * ", num_tpus_per_task, " * ", kTPUTopologyRank,
        "), got ", topology_proto.device_coordinates_size());

  const int bound_x = topology_proto.mesh_shape(0);
  const int bound_y = topology_proto.mesh_shape(1);
  const int bound_z = topology_proto.mesh_shape(2);
  const int bound_core = topology_proto.mesh_shape(3);

  xla::Array4D<TaskAndDevice> topology(bound_x, bound_y, bound_z, bound_core);
  int pos = 0;
  for (int task = 0; task < num_tasks; ++task) {
    for (int device = 0; device < num_tpus_per_task; ++device) {
      int x = topology_proto.device_coordinates(pos++);
      int y = topology_proto.device_coordinates(pos++);
      int z = topology_proto.device_coordinates(pos++);
      int core = topology_proto.device_coordinates(pos++);
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, z, core, bound_x,
                                        bound_y, bound_z, bound_core);

      auto& task_and_device = topology(x, y, z, core);
      if (task_and_device.task != -1)
        return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, z, core);

      task_and_device = {task, device};
    }
  }

  return topology;
}

// Determines execution devices when topology and device assignment are
// defined. With a topology device coordinate to task and device mapping,
// device assignment device coordinates can then be mapped to task and device
// for TPU devices. The device assignment array is also validated.
//
// A valid device assignment array must have:
//  - device coordinates within the topology mesh shape
//  - no duplicate device coordinates
//  - a number of device coordinates (in (x, y, z, core) 4-tuples) that matches
//    'num_replicas' * 'num_cores_per_replica' (see the example below)
//  - a TPU device associated with each device coordinate
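// For example (hypothetical values): with num_replicas = 2 and
// num_cores_per_replica = 1, a device assignment of [0, 0, 0, 0, 0, 0, 0, 1]
// places replica 0 (logical core 0) at coordinate (0, 0, 0, 0) and replica 1
// (logical core 0) at (0, 0, 0, 1); coordinates are ordered by replica then
// logical core.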
StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
GetGeneralTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();

  TF_ASSIGN_OR_RETURN(auto topology, ParseTopologyAttr(topology_attr, num_tasks,
                                                       num_tpus_per_task));

  const int expected_device_assignment_size =
      num_replicas * num_cores_per_replica * kTPUTopologyRank;
  const int device_assignment_attr_size = device_assignment_attr.size();
  if (device_assignment_attr_size != expected_device_assignment_size)
    return errors::InvalidArgument(
        "length of '", kDeviceAssignmentAttr,
        "' must be 'num_replicas' * 'num_cores_per_replica' * ",
        kTPUTopologyRank, " (", num_replicas, " * ", num_cores_per_replica,
        " * ", kTPUTopologyRank, "), got ", device_assignment_attr.size());

  const int bound_x = topology.n1();
  const int bound_y = topology.n2();
  const int bound_z = topology.n3();
  const int bound_core = topology.n4();

  // TPU XLA device ID is determined by its device coordinate, from major to
  // minor coordinates (z, y, x, core).
  auto location_to_id = [&](int x, int y, int z, int core) {
    return (x + bound_x * (y + bound_y * z)) * bound_core + core;
  };
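  // For example (hypothetical mesh shape (2, 2, 1, 2)): coordinate
  // (x = 1, y = 0, z = 0, core = 1) maps to id
  // (1 + 2 * (0 + 2 * 0)) * 2 + 1 = 3.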

  std::vector<bool> used_device_ids(bound_x * bound_y * bound_z * bound_core,
                                    false);
  TPUDevicesAndHosts devices_and_hosts(
      num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
                        num_cores_per_replica, TPUDeviceAndHost()));
  xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
  int pos = 0;
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int logical_core = 0; logical_core < num_cores_per_replica;
         ++logical_core) {
      int x = device_assignment_attr[pos++];
      int y = device_assignment_attr[pos++];
      int z = device_assignment_attr[pos++];
      int core = device_assignment_attr[pos++];
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core,
                                        bound_x, bound_y, bound_z, bound_core);

      TaskAndDevice task_and_device = topology(x, y, z, core);
      const int task = task_and_device.task;
      const int device = task_and_device.device;
      if (task == -1 || device == -1)
        return errors::InvalidArgument(
            "no TPU device found for '", kDeviceAssignmentAttr,
            "' device coordinate (", x, ", ", y, ", ", z, ", ", core, ")");

      const int device_id = location_to_id(x, y, z, core);
      if (used_device_ids[device_id])
        return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z,
                                           core);

      used_device_ids[device_id] = true;
      device_assignment(replica, logical_core) = device_id;
      auto& device_and_host = devices_and_hosts[replica][logical_core];
      const auto& tpu_device = tpu_devices[task][device];
      device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
      device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
    }
  }

  xla::DeviceAssignmentProto device_assignment_proto;
  TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));

  return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
      std::move(devices_and_hosts), std::move(device_assignment_proto));
}

}  // anonymous namespace
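
// For example (illustrative): an ArrayAttr of integers [0, 0, 0, 1] is
// returned as {0, 0, 0, 1}; a non-integer element at index 2 fails with
// "bad 'device_assignment' attribute at index 2, not an int".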
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr) {
  llvm::SmallVector<int64_t, 8> device_coordinates;
  device_coordinates.reserve(device_assignment_attr.size());

  for (auto device_coordinate_and_idx :
       llvm::enumerate(device_assignment_attr)) {
    auto device_coordinate =
        device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
    if (!device_coordinate)
      return errors::InvalidArgument(
          llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
                        device_coordinate_and_idx.index())
              .str());

    device_coordinates.push_back(device_coordinate.getInt());
  }

  return device_coordinates;
}

StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    Devices devices, int num_replicas, int num_cores_per_replica,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  // Collect TPU_SYSTEM devices.
  llvm::SmallVector<Device, 8> system_devices;
  TF_RETURN_IF_ERROR(GetTPUSystemDevices(devices, &system_devices));

  // Collect TPU devices based on TPU_SYSTEM devices collected earlier.
  llvm::SmallVector<llvm::SmallVector<Device, 8>, 8> tpu_devices;
  TF_RETURN_IF_ERROR(GetTPUDevices(devices, system_devices, &tpu_devices));

  std::string compilation_device = GetTPUCompilationDevice(system_devices[0]);

  if (topology_attr.empty()) {
    if (!device_assignment_attr.empty())
      return errors::InvalidArgument("'", kDeviceAssignmentAttr,
                                     "' must not be set when '", kTopologyAttr,
                                     "' is not set");

    TF_ASSIGN_OR_RETURN(auto execution_devices,
                        GetFullMeshTPUExecutionDeviceAssignment(
                            num_replicas, num_cores_per_replica, tpu_devices));
    return TPUDeviceAssignment(compilation_device,
                               std::move(execution_devices));
  }

  TF_ASSIGN_OR_RETURN(auto devices_and_ids,
                      GetGeneralTPUExecutionDeviceAssignment(
                          num_replicas, num_cores_per_replica, tpu_devices,
                          topology_attr, device_assignment_attr));
  return TPUDeviceAssignment(compilation_device,
                             std::move(devices_and_ids.first),
                             std::move(devices_and_ids.second));
}
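
// For example, GetDeviceAliasForLogicalCore(0) returns "TPU_REPLICATED_CORE_0"
// and GetDeviceAliasForLogicalCore(1) returns "TPU_REPLICATED_CORE_1".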
std::string GetDeviceAliasForLogicalCore(int core_index) {
  return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
}

mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device) {
  auto replicate = cluster->getParentOfType<mlir::tf_device::ReplicateOp>();
  if (replicate) {
    *host_device = tensorflow::kTPUReplicatedHost;
    return mlir::success();
  }

  auto num_cores_per_replica_attr = cluster->getAttrOfType<mlir::IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster.emitOpError(
        "cluster op missing `num_cores_per_replica` attribute");

  if (num_cores_per_replica_attr.getInt() != 1)
    return cluster.emitOpError(
        "outside compilation is not supported with model parallelism.");

  auto topology_attr =
      cluster->getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster.emitOpError("cluster op missing `topology` attribute");

  auto device_assignment_attr = cluster->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
                                             tensorflow::kDeviceAssignmentAttr)
                                   .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);

  if (!status_or_device_coordinates.ok())
    return cluster.emitError()
           << "error in fetching TPU device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices.device_names(), /*num_replicas=*/1,
          /*num_cores_per_replica=*/1, topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();

  *host_device = tpu_device_assignment.tpu_devices[0][0].host;
  return mlir::success();
}
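
// For example (hypothetical name):
// IsTPUDevice("/job:worker/replica:0/task:0/device:TPU:0") returns true, while
// an unparsable name or any non-TPU device type returns false.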
bool IsTPUDevice(llvm::StringRef device) {
  Device parsed_device;
  if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
                                      &parsed_device))
    return false;
  return parsed_device.has_type && parsed_device.type == kDeviceTPU;
}

}  // namespace tensorflow