1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <iterator>
21 #include <string>
22 #include <type_traits>
23 #include <utility>
24
25 #include "absl/strings/string_view.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
33 #include "tensorflow/compiler/mlir/utils/string_container_utils.h"
34 #include "tensorflow/compiler/xla/array4d.h"
35 #include "tensorflow/compiler/xla/service/computation_placer.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
40 #include "tensorflow/core/util/device_name_utils.h"
41 #include "tensorflow/stream_executor/lib/statusor.h"
42
43 namespace tensorflow {
44
45 // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
46 // topology.
47 constexpr int kTPUTopologyRank = 4;
48
49 constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
50 constexpr char kDeviceTPU[] = "TPU";
51 constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
52 constexpr char kBadIntArrayElementMsg[] =
53 "bad '{0}' attribute at index {1}, not an int";
54
55 using Device = DeviceNameUtils::ParsedName;
56 using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;
57
58 namespace {
59 // Finds matching devices in `devices` based on pattern `spec`.
FindMatchingDevices(Devices devices,const Device & spec,llvm::SmallVectorImpl<Device> * matched_devices)60 void FindMatchingDevices(Devices devices, const Device& spec,
61 llvm::SmallVectorImpl<Device>* matched_devices) {
62 for (const auto& device : devices)
63 if (DeviceNameUtils::IsCompleteSpecification(spec, device))
64 matched_devices->push_back(device);
65 }
66
67 // Creates error message for a conflicting attribute of a device.
68 template <typename T>
MismatchedTPUSystemAttributeErr(absl::string_view attribute,T a,T b)69 Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) {
70 return errors::InvalidArgument("found ", kDeviceTPUSystem,
71 " devices with conflicting ", attribute, "s '",
72 a, "' and '", b, "'");
73 }
74
75 // Finds TPU_SYSTEM:0 devices in `devices`. If multiple TPU_SYSTEM devices are
76 // found, the first one lexicographically is returned. If no TPU_SYSTEM device
77 // is found or if there are multiple TPU_SYSTEM devices with different jobs or
78 // replicas, a failure will be returned.
GetTPUSystemDevices(Devices devices,llvm::SmallVectorImpl<Device> * matched_devices)79 Status GetTPUSystemDevices(Devices devices,
80 llvm::SmallVectorImpl<Device>* matched_devices) {
81 Device spec;
82 spec.type = kDeviceTPUSystem;
83 spec.has_type = true;
84 spec.id = 0;
85 spec.has_id = true;
86
87 llvm::SmallVector<Device, 8> system_devices;
88 FindMatchingDevices(devices, spec, &system_devices);
89 if (system_devices.empty())
90 return errors::InvalidArgument("no ", kDeviceTPUSystem, " devices found");
91
92 // Check that all system devices are part of the same job.
93 const auto& job = system_devices[0].job;
94 auto replica = system_devices[0].replica;
95 for (const auto& device : llvm::make_range(std::next(system_devices.begin()),
96 system_devices.end())) {
97 if (device.job != job)
98 return MismatchedTPUSystemAttributeErr("job", job, device.job);
99
100 if (device.replica != replica)
101 return MismatchedTPUSystemAttributeErr("replica", replica,
102 device.replica);
103 }
104
105 // Sort by task to be deterministic.
106 std::sort(system_devices.begin(), system_devices.end(),
107 [](const Device& a, const Device& b) { return a.task < b.task; });
108
109 matched_devices->swap(system_devices);
110
111 return OkStatus();
112 }
113
114 // Finds TPU devices associated to system device based on spec (e.g. from
115 // GetTPUSystemDevices). If the number of TPU devices per host do not match for
116 // every host, a failure will be returned.
GetTPUDevices(Devices devices,llvm::ArrayRef<Device> system_devices,llvm::SmallVectorImpl<llvm::SmallVector<Device,8>> * tpu_devices)117 Status GetTPUDevices(
118 Devices devices, llvm::ArrayRef<Device> system_devices,
119 llvm::SmallVectorImpl<llvm::SmallVector<Device, 8>>* tpu_devices) {
120 tpu_devices->reserve(system_devices.size());
121
122 auto lookup = [&devices](Device device_spec) {
123 device_spec.has_type = true;
124 device_spec.type = kDeviceTPU;
125 // Enumerate all the available TPUs.
126 device_spec.has_id = false;
127
128 llvm::SmallVector<Device, 8> host_tpu_devices;
129 FindMatchingDevices(devices, device_spec, &host_tpu_devices);
130
131 // Sort devices by id.
132 std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
133 [](const Device& i, const Device& j) { return i.id < j.id; });
134 return host_tpu_devices;
135 };
136
137 int num_tpus_per_host = 0;
138 {
139 const auto& device = system_devices[0];
140 auto host_tpu_devices = lookup(device);
141 num_tpus_per_host = host_tpu_devices.size();
142 tpu_devices->push_back(std::move(host_tpu_devices));
143 }
144
145 for (const auto& device_spec : llvm::make_range(
146 std::next(system_devices.begin()), system_devices.end())) {
147 auto host_tpu_devices = lookup(device_spec);
148 // Check number of TPU devices per host all match.
149 const int64_t host_tpu_devices_size = host_tpu_devices.size();
150 if (num_tpus_per_host != host_tpu_devices_size)
151 return errors::InvalidArgument(
152 "expected the number of TPU devices per host to be ",
153 num_tpus_per_host, ", got ", host_tpu_devices.size());
154
155 tpu_devices->push_back(std::move(host_tpu_devices));
156 }
157
158 return OkStatus();
159 }
160
161 // Finds the compilation device from system device.
GetTPUCompilationDevice(Device system_device)162 std::string GetTPUCompilationDevice(Device system_device) {
163 // TODO(b/110910013) GetTPUSystemDevices parses the spec and returns the
164 // TPU_SYSTEM device, which we replace with the CPU device. We do this
165 // replacement because we want to place the `tf._TPUCompileMlir` explicitly on
166 // CPU devices of the same job as the TPU_SYSTEM device.
167 system_device.type = tensorflow::DEVICE_CPU;
168 return DeviceNameUtils::ParsedNameToString(system_device);
169 }
170
171 // Finds the host CPU device for a given TPU device.
GetCPUHostDeviceForTPUDevice(Device tpu_device)172 std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
173 tpu_device.type = DEVICE_CPU;
174 tpu_device.id = 0;
175 return DeviceNameUtils::ParsedNameToString(tpu_device);
176 }
177
178 // Determines execution devices when topology and device assignment are not
179 // defined. This is a special case where a single core computation is replicated
180 // to every core in the mesh. TPU devices are simply added to
181 // `execution_devices` of one replica. `num_replicas` must be 1 or the total
182 // number of TPU devices available, and `num_cores_per_replica` must be 1.
GetFullMeshTPUExecutionDeviceAssignment(int num_replicas,int num_cores_per_replica,llvm::ArrayRef<llvm::SmallVector<Device,8>> tpu_devices)183 StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
184 int num_replicas, int num_cores_per_replica,
185 llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
186 const int num_tasks = tpu_devices.size();
187 const int num_tpus_per_task = tpu_devices[0].size();
188 const int num_tpu_devices = num_tasks * num_tpus_per_task;
189
190 if (num_replicas != 1 && num_replicas != num_tpu_devices)
191 return errors::InvalidArgument("'num_replicas' must be equal to 1 or ",
192 num_tpu_devices, ", got ", num_replicas);
193
194 if (num_cores_per_replica != 1)
195 return errors::InvalidArgument(
196 "'num_cores_per_replica' must be equal to 1, got ",
197 num_cores_per_replica);
198
199 TPUDevicesAndHosts devices_and_hosts;
200 devices_and_hosts.reserve(num_replicas);
201 for (int i = 0; i < num_replicas; ++i) {
202 const int task = i / num_tpus_per_task;
203 const int device = i % num_tpus_per_task;
204 const auto& tpu_device = tpu_devices[task][device];
205 devices_and_hosts.push_back({TPUDeviceAndHost(
206 /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
207 /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
208 }
209
210 return devices_and_hosts;
211 }
212
// Pairs a task index with a device index for one TPU device coordinate. Both
// fields are -1 in a default-constructed value; the topology parsing code in
// this file treats `task == -1` as "no device recorded for this coordinate".
struct TaskAndDevice {
  int task = -1;
  int device = -1;

  TaskAndDevice() = default;
  TaskAndDevice(int task, int device) : task{task}, device{device} {}
};
222
// Returns true if any of `x`, `y`, `z`, `core` lies outside the half-open range
// [0, bound) given by the matching `bound_*` argument.
bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x,
                                int bound_y, int bound_z, int bound_core) {
  const auto in_range = [](int value, int bound) {
    return 0 <= value && value < bound;
  };
  const bool inside_mesh = in_range(x, bound_x) && in_range(y, bound_y) &&
                           in_range(z, bound_z) && in_range(core, bound_core);
  return !inside_mesh;
}
229
230 // Creates error message for an out of bound device coordinate.
DeviceCoordinateErrorMsg(absl::string_view attribute,int x,int y,int z,int core,int bound_x,int bound_y,int bound_z,int bound_core)231 Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y,
232 int z, int core, int bound_x, int bound_y,
233 int bound_z, int bound_core) {
234 return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", z,
235 ", ", core, ") in '", attribute,
236 "' is outside of mesh shape (", bound_x, ", ",
237 bound_y, ", ", bound_z, ", ", bound_core, ")");
238 }
239
240 // Creates error message for a duplicate device coordinate.
DuplicateCoordinateErrorMsg(absl::string_view attribute,int x,int y,int z,int core)241 Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y,
242 int z, int core) {
243 return errors::InvalidArgument("'", attribute,
244 "' has duplicate device coordinate (", x, ", ",
245 y, ", ", z, ", ", core, ")");
246 }
247
248 // Parses and validates topology (serialized string of TopologyProto), and maps
249 // device coordinate (x, y, z, core) to task and device (of available TPUs).
250 // Topology attribute device coordinates are ordered by task then device (major
251 // to minor).
252 //
253 // A valid TopologyProto must have:
254 // - a valid mesh shape (rank 4 with positive dimensions)
255 // - `num_tasks` and `num_tpu_devices_per_task` must match the number of
256 // available TPU hosts and devices per host
257 // - device coordinates within the mesh shape
258 // - no duplicate device coordinates
259 // - number of device coordinates (in tuple 3) match number of availabe TPUs
ParseTopologyAttr(llvm::StringRef topology_attr,int num_tasks,int num_tpus_per_task)260 StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
261 llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) {
262 tpu::TopologyProto topology_proto;
263 if (!topology_proto.ParseFromString(topology_attr.str()))
264 return errors::InvalidArgument("failed to parse '", kTopologyAttr,
265 "' attribute to TopologyProto");
266
267 if (topology_proto.mesh_shape_size() != kTPUTopologyRank)
268 return errors::InvalidArgument(
269 "'", kTopologyAttr, "' 'mesh_shape' must be rank ", kTPUTopologyRank,
270 ", got rank ", topology_proto.mesh_shape_size());
271
272 for (auto mesh_shape_dim : llvm::enumerate(topology_proto.mesh_shape()))
273 if (mesh_shape_dim.value() <= 0)
274 return errors::InvalidArgument(
275 "'", kTopologyAttr, "' 'mesh_shape' dimension ",
276 mesh_shape_dim.index(), " must be positive, got ",
277 mesh_shape_dim.value());
278
279 if (topology_proto.num_tasks() != num_tasks)
280 return errors::InvalidArgument(
281 "number of tasks from available TPU devices must be 'num_tasks' in '",
282 kTopologyAttr, "' (", topology_proto.num_tasks(), "), got ", num_tasks);
283
284 if (topology_proto.num_tpu_devices_per_task() != num_tpus_per_task)
285 return errors::InvalidArgument(
286 "number of TPU devices available per task must be "
287 "'num_tpu_devices_per_task' in '",
288 kTopologyAttr, "' (", topology_proto.num_tpu_devices_per_task(),
289 "), got ", num_tpus_per_task);
290
291 const int expected_device_coordinates_size =
292 num_tasks * num_tpus_per_task * kTPUTopologyRank;
293 if (topology_proto.device_coordinates_size() !=
294 expected_device_coordinates_size)
295 return errors::InvalidArgument(
296 "length of 'device_coordinates' in '", kTopologyAttr,
297 "' must be 'num_tasks' * 'num_tpus_per_task' * ", kTPUTopologyRank,
298 " (", num_tasks, " * ", num_tpus_per_task, " * ", kTPUTopologyRank,
299 "), got ", topology_proto.device_coordinates_size());
300
301 const int bound_x = topology_proto.mesh_shape(0);
302 const int bound_y = topology_proto.mesh_shape(1);
303 const int bound_z = topology_proto.mesh_shape(2);
304 const int bound_core = topology_proto.mesh_shape(3);
305
306 xla::Array4D<TaskAndDevice> topology(bound_x, bound_y, bound_z, bound_core);
307 int pos = 0;
308 for (int task = 0; task < num_tasks; ++task) {
309 for (int device = 0; device < num_tpus_per_task; ++device) {
310 int x = topology_proto.device_coordinates(pos++);
311 int y = topology_proto.device_coordinates(pos++);
312 int z = topology_proto.device_coordinates(pos++);
313 int core = topology_proto.device_coordinates(pos++);
314 if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
315 bound_core))
316 return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, z, core, bound_x,
317 bound_y, bound_z, bound_core);
318
319 auto& task_and_device = topology(x, y, z, core);
320 if (task_and_device.task != -1)
321 return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, z, core);
322
323 task_and_device = {task, device};
324 }
325 }
326
327 return topology;
328 }
329
330 // Determines execution devices when topology and device assignment are defined.
331 // With a topology device coordinate to task and device mapping, device
332 // assignment device coordinates can then be mapped to task and device for TPU
333 // devices. The device assignment array is also validated.
334 //
335 // A valid device assignment array must have:
336 // - device coordinates within the topology mesh shape
337 // - no duplicate device coordinates
338 // - number of device coordinates (in tuple 3) match number 'num_replicas' *
339 // 'num_cores_per_replica'
340 // - a TPU device associated with each device coordinate
341 StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
GetGeneralTPUExecutionDeviceAssignment(int num_replicas,int num_cores_per_replica,llvm::ArrayRef<llvm::SmallVector<Device,8>> tpu_devices,llvm::StringRef topology_attr,llvm::ArrayRef<int64_t> device_assignment_attr)342 GetGeneralTPUExecutionDeviceAssignment(
343 int num_replicas, int num_cores_per_replica,
344 llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
345 llvm::StringRef topology_attr,
346 llvm::ArrayRef<int64_t> device_assignment_attr) {
347 const int num_tasks = tpu_devices.size();
348 const int num_tpus_per_task = tpu_devices[0].size();
349
350 TF_ASSIGN_OR_RETURN(auto topology, ParseTopologyAttr(topology_attr, num_tasks,
351 num_tpus_per_task));
352
353 const int expected_device_assignment_size =
354 num_replicas * num_cores_per_replica * kTPUTopologyRank;
355 const int device_assignment_attr_size = device_assignment_attr.size();
356 if (device_assignment_attr_size != expected_device_assignment_size)
357 return errors::InvalidArgument(
358 "length of '", kDeviceAssignmentAttr,
359 "' must be 'num_replicas' * 'num_cores_per_replica' * ",
360 kTPUTopologyRank, " (", num_replicas, " * ", num_cores_per_replica,
361 " * ", kTPUTopologyRank, "), got ", device_assignment_attr.size());
362
363 const int bound_x = topology.n1();
364 const int bound_y = topology.n2();
365 const int bound_z = topology.n3();
366 const int bound_core = topology.n4();
367
368 // TPU XLA device ID is determined by its device coordinate, from major to
369 // minor coordinates (z, y, x, core).
370 auto location_to_id = [&](int x, int y, int z, int core) {
371 return (x + bound_x * (y + bound_y * z)) * bound_core + core;
372 };
373
374 std::vector<bool> used_device_ids(bound_x * bound_y * bound_z * bound_core,
375 false);
376 TPUDevicesAndHosts devices_and_hosts(
377 num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
378 num_cores_per_replica, TPUDeviceAndHost()));
379 xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
380 int pos = 0;
381 for (int replica = 0; replica < num_replicas; ++replica) {
382 for (int logical_core = 0; logical_core < num_cores_per_replica;
383 ++logical_core) {
384 int x = device_assignment_attr[pos++];
385 int y = device_assignment_attr[pos++];
386 int z = device_assignment_attr[pos++];
387 int core = device_assignment_attr[pos++];
388 if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
389 bound_core))
390 return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core,
391 bound_x, bound_y, bound_z, bound_core);
392
393 TaskAndDevice task_and_device = topology(x, y, z, core);
394 const int task = task_and_device.task;
395 const int device = task_and_device.device;
396 if (task == -1 || device == -1)
397 return errors::InvalidArgument(
398 "no TPU device found for '", kDeviceAssignmentAttr,
399 "' device coordinate (", x, ", ", y, ", ", z, ", ", core, ")");
400
401 const int device_id = location_to_id(x, y, z, core);
402 if (used_device_ids[device_id])
403 return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z,
404 core);
405
406 used_device_ids[device_id] = true;
407 device_assignment(replica, logical_core) = device_id;
408 auto& device_and_host = devices_and_hosts[replica][logical_core];
409 const auto& tpu_device = tpu_devices[task][device];
410 device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
411 device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
412 }
413 }
414
415 xla::DeviceAssignmentProto device_assignment_proto;
416 TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));
417
418 return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
419 std::move(devices_and_hosts), std::move(device_assignment_proto));
420 }
421
422 } // anonymous namespace
423
GetDeviceCoordinates(mlir::ArrayAttr device_assignment_attr)424 StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
425 mlir::ArrayAttr device_assignment_attr) {
426 llvm::SmallVector<int64_t, 8> device_coordinates;
427 device_coordinates.reserve(device_assignment_attr.size());
428
429 for (auto device_coordinate_and_idx :
430 llvm::enumerate(device_assignment_attr)) {
431 auto device_coordinate =
432 device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
433 if (!device_coordinate)
434 return errors::InvalidArgument(
435 llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
436 device_coordinate_and_idx.index())
437 .str());
438
439 device_coordinates.push_back(device_coordinate.getInt());
440 }
441
442 return device_coordinates;
443 }
444
GetTPUCompilationAndExecutionDevices(Devices devices,int num_replicas,int num_cores_per_replica,llvm::StringRef topology_attr,llvm::ArrayRef<int64_t> device_assignment_attr)445 StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
446 Devices devices, int num_replicas, int num_cores_per_replica,
447 llvm::StringRef topology_attr,
448 llvm::ArrayRef<int64_t> device_assignment_attr) {
449 // Collect TPU_SYSTEM devices.
450 llvm::SmallVector<Device, 8> system_devices;
451 TF_RETURN_IF_ERROR(GetTPUSystemDevices(devices, &system_devices));
452
453 // Collect TPU devices based on TPU_SYSTEM devices collected earlier.
454 llvm::SmallVector<llvm::SmallVector<Device, 8>, 8> tpu_devices;
455 TF_RETURN_IF_ERROR(GetTPUDevices(devices, system_devices, &tpu_devices));
456
457 std::string compilation_device = GetTPUCompilationDevice(system_devices[0]);
458
459 if (topology_attr.empty()) {
460 if (!device_assignment_attr.empty())
461 return errors::InvalidArgument("'", kDeviceAssignmentAttr,
462 "' must not be set when '", kTopologyAttr,
463 "' is not set");
464
465 TF_ASSIGN_OR_RETURN(auto execution_devices,
466 GetFullMeshTPUExecutionDeviceAssignment(
467 num_replicas, num_cores_per_replica, tpu_devices));
468 return TPUDeviceAssignment(compilation_device,
469 std::move(execution_devices));
470 }
471
472 TF_ASSIGN_OR_RETURN(auto devices_and_ids,
473 GetGeneralTPUExecutionDeviceAssignment(
474 num_replicas, num_cores_per_replica, tpu_devices,
475 topology_attr, device_assignment_attr));
476 return TPUDeviceAssignment(compilation_device,
477 std::move(devices_and_ids.first),
478 std::move(devices_and_ids.second));
479 }
480
GetDeviceAliasForLogicalCore(int core_index)481 std::string GetDeviceAliasForLogicalCore(int core_index) {
482 return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
483 }
484
HasModelParallelism(mlir::tf_device::ClusterOp cluster)485 bool HasModelParallelism(mlir::tf_device::ClusterOp cluster) {
486 mlir::IntegerAttr num_cores_per_replica_attr =
487 cluster->getAttrOfType<mlir::IntegerAttr>(
488 tensorflow::kNumCoresPerReplicaAttr);
489 if (!num_cores_per_replica_attr) return false;
490 return num_cores_per_replica_attr.getInt() != 1;
491 }
492
GetHostDeviceOutsideComputation(mlir::TF::RuntimeDevices devices,mlir::tf_device::ClusterOp cluster,std::string * host_device)493 mlir::LogicalResult GetHostDeviceOutsideComputation(
494 mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
495 std::string* host_device) {
496 auto replicate = cluster->getParentOfType<mlir::tf_device::ReplicateOp>();
497 if (replicate) {
498 *host_device = tensorflow::kTPUReplicatedHost;
499 return mlir::success();
500 }
501
502 auto topology_attr =
503 cluster->getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
504 if (!topology_attr)
505 return cluster.emitOpError("cluster op missing `topology` attribute");
506
507 auto num_cores_per_replica_attr = cluster->getAttrOfType<mlir::IntegerAttr>(
508 tensorflow::kNumCoresPerReplicaAttr);
509 if (!num_cores_per_replica_attr)
510 return cluster.emitOpError(
511 llvm::formatv("requires attribute '{0}'",
512 tensorflow::kNumCoresPerReplicaAttr)
513 .str());
514
515 auto device_assignment_attr = cluster->getAttrOfType<mlir::ArrayAttr>(
516 tensorflow::kDeviceAssignmentAttr);
517 if (!device_assignment_attr)
518 return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
519 tensorflow::kDeviceAssignmentAttr)
520 .str());
521
522 auto status_or_device_coodinates =
523 tensorflow::GetDeviceCoordinates(device_assignment_attr);
524
525 if (!status_or_device_coodinates.ok())
526 return cluster.emitError()
527 << "error in fetching tpu device coordinates: "
528 << status_or_device_coodinates.status().error_message();
529
530 // Determine compilation and execution devices.
531 auto status_or_tpu_device_assignment =
532 tensorflow::GetTPUCompilationAndExecutionDevices(
533 devices.device_names(), /*num_replicas=*/1,
534 num_cores_per_replica_attr.getInt(), topology_attr.getValue(),
535 std::move(status_or_device_coodinates).value());
536 if (!status_or_tpu_device_assignment.ok())
537 return cluster.emitError()
538 << "error in fetching TPU compilation/execution devices: "
539 << status_or_tpu_device_assignment.status().error_message();
540 auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
541
542 *host_device = tpu_device_assignment.tpu_devices[0][0].host;
543 return mlir::success();
544 }
545
IsTPUDevice(llvm::StringRef device)546 bool IsTPUDevice(llvm::StringRef device) {
547 Device parsed_device;
548 if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
549 &parsed_device))
550 return false;
551 return parsed_device.has_type && parsed_device.type == kDeviceTPU;
552 }
553
IsTPUReplicatedCore(llvm::StringRef device)554 bool IsTPUReplicatedCore(llvm::StringRef device) {
555 Device parsed_device;
556 if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
557 &parsed_device))
558 return false;
559 return parsed_device.has_type && parsed_device.type == kTPUReplicatedCore;
560 }
561 } // namespace tensorflow
562