/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;

constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
constexpr char kDeviceTPU[] = "TPU";
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
constexpr char kBadIntArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not an int";

using Device = DeviceNameUtils::ParsedName;
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;

namespace {
// Finds matching devices in `devices` based on pattern `spec`.
void FindMatchingDevices(Devices devices, const Device& spec,
                         llvm::SmallVectorImpl<Device>* matched_devices) {
  for (const auto& device : devices)
    if (DeviceNameUtils::IsCompleteSpecification(spec, device))
      matched_devices->push_back(device);
}

// Creates an error message for a conflicting attribute of a device.
template <typename T>
Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) {
  return errors::InvalidArgument("found ", kDeviceTPUSystem,
                                 " devices with conflicting ", attribute, "s '",
                                 a, "' and '", b, "'");
}

// Finds TPU_SYSTEM:0 devices in `devices`. If multiple TPU_SYSTEM devices are
// found, they are returned sorted by task. If no TPU_SYSTEM device is found,
// or if there are TPU_SYSTEM devices with different jobs or replicas, a
// failure will be returned.
Status GetTPUSystemDevices(Devices devices,
                           llvm::SmallVectorImpl<Device>* matched_devices) {
  Device spec;
  spec.type = kDeviceTPUSystem;
  spec.has_type = true;
  spec.id = 0;
  spec.has_id = true;

  llvm::SmallVector<Device, 8> system_devices;
  FindMatchingDevices(devices, spec, &system_devices);
  if (system_devices.empty())
    return errors::InvalidArgument("no ", kDeviceTPUSystem, " devices found");

  // Check that all system devices are part of the same job.
  const auto& job = system_devices[0].job;
  auto replica = system_devices[0].replica;
  for (const auto& device : llvm::make_range(std::next(system_devices.begin()),
                                             system_devices.end())) {
    if (device.job != job)
      return MismatchedTPUSystemAttributeErr("job", job, device.job);

    if (device.replica != replica)
      return MismatchedTPUSystemAttributeErr("replica", replica,
                                             device.replica);
  }

  // Sort by task to be deterministic.
  std::sort(system_devices.begin(), system_devices.end(),
            [](const Device& a, const Device& b) { return a.task < b.task; });

  matched_devices->swap(system_devices);

  return OkStatus();
}
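
// Illustrative example (device names are hypothetical): given the parsed
// devices
//   /job:worker/replica:0/task:1/device:TPU_SYSTEM:0
//   /job:worker/replica:0/task:0/device:TPU_SYSTEM:0
// GetTPUSystemDevices returns both, sorted by task (task 0 first); mixing
// jobs (e.g. 'worker' and 'trainer') or replicas instead yields an
// InvalidArgument error.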

// Finds TPU devices associated with each system device in `system_devices`
// (e.g. from GetTPUSystemDevices). If the number of TPU devices per host does
// not match across hosts, a failure will be returned.
Status GetTPUDevices(
    Devices devices, llvm::ArrayRef<Device> system_devices,
    llvm::SmallVectorImpl<llvm::SmallVector<Device, 8>>* tpu_devices) {
  tpu_devices->reserve(system_devices.size());

  auto lookup = [&devices](Device device_spec) {
    device_spec.has_type = true;
    device_spec.type = kDeviceTPU;
    // Enumerate all the available TPUs.
    device_spec.has_id = false;

    llvm::SmallVector<Device, 8> host_tpu_devices;
    FindMatchingDevices(devices, device_spec, &host_tpu_devices);

    // Sort devices by id.
    std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
              [](const Device& i, const Device& j) { return i.id < j.id; });
    return host_tpu_devices;
  };

  int num_tpus_per_host = 0;
  {
    const auto& device = system_devices[0];
    auto host_tpu_devices = lookup(device);
    num_tpus_per_host = host_tpu_devices.size();
    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  for (const auto& device_spec : llvm::make_range(
           std::next(system_devices.begin()), system_devices.end())) {
    auto host_tpu_devices = lookup(device_spec);
    // Check that the number of TPU devices per host matches.
    const int64_t host_tpu_devices_size = host_tpu_devices.size();
    if (num_tpus_per_host != host_tpu_devices_size)
      return errors::InvalidArgument(
          "expected the number of TPU devices per host to be ",
          num_tpus_per_host, ", got ", host_tpu_devices.size());

    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  return OkStatus();
}
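
// For illustration, continuing the hypothetical two-host setup above with two
// TPUs per host, `tpu_devices` would hold one sorted row per task:
//   [[/job:worker/task:0/device:TPU:0, /job:worker/task:0/device:TPU:1],
//    [/job:worker/task:1/device:TPU:0, /job:worker/task:1/device:TPU:1]]
// If one host instead exposed a different TPU count, GetTPUDevices would
// return an InvalidArgument error.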

// Finds the compilation device from the system device.
std::string GetTPUCompilationDevice(Device system_device) {
  // TODO(b/110910013) GetTPUSystemDevices parses the spec and returns the
  // TPU_SYSTEM device, which we replace with the CPU device. We do this
  // replacement because we want to place `tf._TPUCompileMlir` explicitly on
  // CPU devices of the same job as the TPU_SYSTEM device.
  system_device.type = tensorflow::DEVICE_CPU;
  return DeviceNameUtils::ParsedNameToString(system_device);
}

// Finds the host CPU device for a given TPU device.
std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
  tpu_device.type = DEVICE_CPU;
  tpu_device.id = 0;
  return DeviceNameUtils::ParsedNameToString(tpu_device);
}
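
// For example, the TPU device
//   /job:worker/replica:0/task:0/device:TPU:1
// maps to the host device
//   /job:worker/replica:0/task:0/device:CPU:0
// since only the type and id fields of the parsed name are rewritten.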

// Determines execution devices when topology and device assignment are not
// defined. This is a special case where a single-core computation is
// replicated to every core in the mesh. TPU devices are simply added to
// `execution_devices` of one replica. `num_replicas` must be 1 or the total
// number of TPU devices available, and `num_cores_per_replica` must be 1.
StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();
  const int num_tpu_devices = num_tasks * num_tpus_per_task;

  if (num_replicas != 1 && num_replicas != num_tpu_devices)
    return errors::InvalidArgument("'num_replicas' must be equal to 1 or ",
                                   num_tpu_devices, ", got ", num_replicas);

  if (num_cores_per_replica != 1)
    return errors::InvalidArgument(
        "'num_cores_per_replica' must be equal to 1, got ",
        num_cores_per_replica);

  TPUDevicesAndHosts devices_and_hosts;
  devices_and_hosts.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    const int task = i / num_tpus_per_task;
    const int device = i % num_tpus_per_task;
    const auto& tpu_device = tpu_devices[task][device];
    devices_and_hosts.push_back({TPUDeviceAndHost(
        /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
        /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
  }

  return devices_and_hosts;
}
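
// Worked example (hypothetical mesh): with 2 tasks and 4 TPUs per task,
// num_tpu_devices = 8, so 'num_replicas' must be 1 or 8. With num_replicas = 8,
// replica i maps to task i / 4 and device i % 4; e.g. replica 5 executes on
// /job:worker/task:1/device:TPU:1.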

// Helper struct for keeping track of task and device for an associated TPU
// device coordinate.
struct TaskAndDevice {
  TaskAndDevice() {}
  TaskAndDevice(int task, int device) : task(task), device(device) {}

  int task = -1;
  int device = -1;
};

// Checks if a device coordinate is outside of the topology mesh shape bounds.
bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x,
                                int bound_y, int bound_z, int bound_core) {
  return x < 0 || x >= bound_x || y < 0 || y >= bound_y || z < 0 ||
         z >= bound_z || core < 0 || core >= bound_core;
}

// Creates an error message for an out-of-bounds device coordinate.
Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                int z, int core, int bound_x, int bound_y,
                                int bound_z, int bound_core) {
  return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", z,
                                 ", ", core, ") in '", attribute,
                                 "' is outside of mesh shape (", bound_x, ", ",
                                 bound_y, ", ", bound_z, ", ", bound_core, ")");
}

// Creates an error message for a duplicate device coordinate.
Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                   int z, int core) {
  return errors::InvalidArgument("'", attribute,
                                 "' has duplicate device coordinate (", x, ", ",
                                 y, ", ", z, ", ", core, ")");
}

// Parses and validates the topology (a serialized TopologyProto string), and
// maps each device coordinate (x, y, z, core) to a task and device (of the
// available TPUs). Topology attribute device coordinates are ordered by task
// then device (major to minor).
//
// A valid TopologyProto must have:
//  - a valid mesh shape (rank 4 with positive dimensions)
//  - `num_tasks` and `num_tpu_devices_per_task` that match the number of
//    available TPU hosts and devices per host
//  - device coordinates within the mesh shape
//  - no duplicate device coordinates
//  - a number of (x, y, z, core) coordinate tuples matching the number of
//    available TPUs
StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
    llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) {
  tpu::TopologyProto topology_proto;
  if (!topology_proto.ParseFromString(topology_attr.str()))
    return errors::InvalidArgument("failed to parse '", kTopologyAttr,
                                   "' attribute to TopologyProto");

  if (topology_proto.mesh_shape_size() != kTPUTopologyRank)
    return errors::InvalidArgument(
        "'", kTopologyAttr, "' 'mesh_shape' must be rank ", kTPUTopologyRank,
        ", got rank ", topology_proto.mesh_shape_size());

  for (auto mesh_shape_dim : llvm::enumerate(topology_proto.mesh_shape()))
    if (mesh_shape_dim.value() <= 0)
      return errors::InvalidArgument(
          "'", kTopologyAttr, "' 'mesh_shape' dimension ",
          mesh_shape_dim.index(), " must be positive, got ",
          mesh_shape_dim.value());

  if (topology_proto.num_tasks() != num_tasks)
    return errors::InvalidArgument(
        "number of tasks from available TPU devices must be 'num_tasks' in '",
        kTopologyAttr, "' (", topology_proto.num_tasks(), "), got ", num_tasks);

  if (topology_proto.num_tpu_devices_per_task() != num_tpus_per_task)
    return errors::InvalidArgument(
        "number of TPU devices available per task must be "
        "'num_tpu_devices_per_task' in '",
        kTopologyAttr, "' (", topology_proto.num_tpu_devices_per_task(),
        "), got ", num_tpus_per_task);

  const int expected_device_coordinates_size =
      num_tasks * num_tpus_per_task * kTPUTopologyRank;
  if (topology_proto.device_coordinates_size() !=
      expected_device_coordinates_size)
    return errors::InvalidArgument(
        "length of 'device_coordinates' in '", kTopologyAttr,
        "' must be 'num_tasks' * 'num_tpus_per_task' * ", kTPUTopologyRank,
        " (", num_tasks, " * ", num_tpus_per_task, " * ", kTPUTopologyRank,
        "), got ", topology_proto.device_coordinates_size());

  const int bound_x = topology_proto.mesh_shape(0);
  const int bound_y = topology_proto.mesh_shape(1);
  const int bound_z = topology_proto.mesh_shape(2);
  const int bound_core = topology_proto.mesh_shape(3);

  xla::Array4D<TaskAndDevice> topology(bound_x, bound_y, bound_z, bound_core);
  int pos = 0;
  for (int task = 0; task < num_tasks; ++task) {
    for (int device = 0; device < num_tpus_per_task; ++device) {
      int x = topology_proto.device_coordinates(pos++);
      int y = topology_proto.device_coordinates(pos++);
      int z = topology_proto.device_coordinates(pos++);
      int core = topology_proto.device_coordinates(pos++);
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, z, core, bound_x,
                                        bound_y, bound_z, bound_core);

      auto& task_and_device = topology(x, y, z, core);
      if (task_and_device.task != -1)
        return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, z, core);

      task_and_device = {task, device};
    }
  }

  return topology;
}
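
// Illustrative example (all values hypothetical): a valid TopologyProto for
// 2 tasks with 4 TPU devices each on a 2x2x1 mesh with 2 cores per chip would
// have:
//   mesh_shape: [2, 2, 1, 2]
//   num_tasks: 2
//   num_tpu_devices_per_task: 4
//   device_coordinates: 2 * 4 * 4 = 32 ints, one (x, y, z, core) tuple per
//     TPU device, ordered by task then device.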

// Determines execution devices when topology and device assignment are
// defined. With the topology's mapping from device coordinate to task and
// device, device assignment device coordinates can then be mapped to task and
// device for TPU devices. The device assignment array is also validated.
//
// A valid device assignment array must have:
//  - device coordinates within the topology mesh shape
//  - no duplicate device coordinates
//  - a number of (x, y, z, core) coordinate tuples matching
//    'num_replicas' * 'num_cores_per_replica'
//  - a TPU device associated with each device coordinate
StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
GetGeneralTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();

  TF_ASSIGN_OR_RETURN(auto topology, ParseTopologyAttr(topology_attr, num_tasks,
                                                       num_tpus_per_task));

  const int expected_device_assignment_size =
      num_replicas * num_cores_per_replica * kTPUTopologyRank;
  const int device_assignment_attr_size = device_assignment_attr.size();
  if (device_assignment_attr_size != expected_device_assignment_size)
    return errors::InvalidArgument(
        "length of '", kDeviceAssignmentAttr,
        "' must be 'num_replicas' * 'num_cores_per_replica' * ",
        kTPUTopologyRank, " (", num_replicas, " * ", num_cores_per_replica,
        " * ", kTPUTopologyRank, "), got ", device_assignment_attr.size());

  const int bound_x = topology.n1();
  const int bound_y = topology.n2();
  const int bound_z = topology.n3();
  const int bound_core = topology.n4();

  // The TPU XLA device ID is determined by its device coordinate, from major
  // to minor coordinates (z, y, x, core).
  auto location_to_id = [&](int x, int y, int z, int core) {
    return (x + bound_x * (y + bound_y * z)) * bound_core + core;
  };
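  // Worked example, reusing the hypothetical 2x2x1 mesh with 2 cores per chip
  // (bound_x = 2, bound_y = 2, bound_z = 1, bound_core = 2): coordinate
  // (x = 1, y = 1, z = 0, core = 0) maps to id
  // (1 + 2 * (1 + 2 * 0)) * 2 + 0 = 6.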

  std::vector<bool> used_device_ids(bound_x * bound_y * bound_z * bound_core,
                                    false);
  TPUDevicesAndHosts devices_and_hosts(
      num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
                        num_cores_per_replica, TPUDeviceAndHost()));
  xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
  int pos = 0;
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int logical_core = 0; logical_core < num_cores_per_replica;
         ++logical_core) {
      int x = device_assignment_attr[pos++];
      int y = device_assignment_attr[pos++];
      int z = device_assignment_attr[pos++];
      int core = device_assignment_attr[pos++];
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core,
                                        bound_x, bound_y, bound_z, bound_core);

      TaskAndDevice task_and_device = topology(x, y, z, core);
      const int task = task_and_device.task;
      const int device = task_and_device.device;
      if (task == -1 || device == -1)
        return errors::InvalidArgument(
            "no TPU device found for '", kDeviceAssignmentAttr,
            "' device coordinate (", x, ", ", y, ", ", z, ", ", core, ")");

      const int device_id = location_to_id(x, y, z, core);
      if (used_device_ids[device_id])
        return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z,
                                           core);

      used_device_ids[device_id] = true;
      device_assignment(replica, logical_core) = device_id;
      auto& device_and_host = devices_and_hosts[replica][logical_core];
      const auto& tpu_device = tpu_devices[task][device];
      device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
      device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
    }
  }

  xla::DeviceAssignmentProto device_assignment_proto;
  TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));

  return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
      std::move(devices_and_hosts), std::move(device_assignment_proto));
}

}  // anonymous namespace

StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr) {
  llvm::SmallVector<int64_t, 8> device_coordinates;
  device_coordinates.reserve(device_assignment_attr.size());

  for (auto device_coordinate_and_idx :
       llvm::enumerate(device_assignment_attr)) {
    auto device_coordinate =
        device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
    if (!device_coordinate)
      return errors::InvalidArgument(
          llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
                        device_coordinate_and_idx.index())
              .str());

    device_coordinates.push_back(device_coordinate.getInt());
  }

  return device_coordinates;
}
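
// For example, an ArrayAttr of the integers [0, 0, 0, 1] yields the vector
// {0, 0, 0, 1}, while a non-integer element (e.g. a string attribute) produces
// the kBadIntArrayElementMsg error along with the offending index.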

StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    Devices devices, int num_replicas, int num_cores_per_replica,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  // Collect TPU_SYSTEM devices.
  llvm::SmallVector<Device, 8> system_devices;
  TF_RETURN_IF_ERROR(GetTPUSystemDevices(devices, &system_devices));

  // Collect TPU devices based on TPU_SYSTEM devices collected earlier.
  llvm::SmallVector<llvm::SmallVector<Device, 8>, 8> tpu_devices;
  TF_RETURN_IF_ERROR(GetTPUDevices(devices, system_devices, &tpu_devices));

  std::string compilation_device = GetTPUCompilationDevice(system_devices[0]);

  if (topology_attr.empty()) {
    if (!device_assignment_attr.empty())
      return errors::InvalidArgument("'", kDeviceAssignmentAttr,
                                     "' must not be set when '", kTopologyAttr,
                                     "' is not set");

    TF_ASSIGN_OR_RETURN(auto execution_devices,
                        GetFullMeshTPUExecutionDeviceAssignment(
                            num_replicas, num_cores_per_replica, tpu_devices));
    return TPUDeviceAssignment(compilation_device,
                               std::move(execution_devices));
  }

  TF_ASSIGN_OR_RETURN(auto devices_and_ids,
                      GetGeneralTPUExecutionDeviceAssignment(
                          num_replicas, num_cores_per_replica, tpu_devices,
                          topology_attr, device_assignment_attr));
  return TPUDeviceAssignment(compilation_device,
                             std::move(devices_and_ids.first),
                             std::move(devices_and_ids.second));
}
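
// Usage sketch (illustrative arguments): with an empty topology, e.g.
//   GetTPUCompilationAndExecutionDevices(devices, /*num_replicas=*/8,
//                                        /*num_cores_per_replica=*/1,
//                                        /*topology_attr=*/"",
//                                        /*device_assignment_attr=*/{});
// the full mesh path above is taken; a non-empty `topology_attr` instead
// routes through GetGeneralTPUExecutionDeviceAssignment.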

std::string GetDeviceAliasForLogicalCore(int core_index) {
  return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
}
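
// For example, GetDeviceAliasForLogicalCore(1) returns "TPU_REPLICATED_CORE_1".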

bool HasModelParallelism(mlir::tf_device::ClusterOp cluster) {
  mlir::IntegerAttr num_cores_per_replica_attr =
      cluster->getAttrOfType<mlir::IntegerAttr>(
          tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr) return false;
  return num_cores_per_replica_attr.getInt() != 1;
}

mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device) {
  auto replicate = cluster->getParentOfType<mlir::tf_device::ReplicateOp>();
  if (replicate) {
    *host_device = tensorflow::kTPUReplicatedHost;
    return mlir::success();
  }

  auto topology_attr =
      cluster->getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster.emitOpError("cluster op missing `topology` attribute");

  auto num_cores_per_replica_attr = cluster->getAttrOfType<mlir::IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kNumCoresPerReplicaAttr)
            .str());

  auto device_assignment_attr = cluster->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
                                             tensorflow::kDeviceAssignmentAttr)
                                   .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);

  if (!status_or_device_coordinates.ok())
    return cluster.emitError()
           << "error in fetching TPU device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices.device_names(), /*num_replicas=*/1,
          num_cores_per_replica_attr.getInt(), topology_attr.getValue(),
          std::move(status_or_device_coordinates).value());
  if (!status_or_tpu_device_assignment.ok())
    return cluster.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();

  *host_device = tpu_device_assignment.tpu_devices[0][0].host;
  return mlir::success();
}

bool IsTPUDevice(llvm::StringRef device) {
  Device parsed_device;
  if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
                                      &parsed_device))
    return false;
  return parsed_device.has_type && parsed_device.type == kDeviceTPU;
}

bool IsTPUReplicatedCore(llvm::StringRef device) {
  Device parsed_device;
  if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
                                      &parsed_device))
    return false;
  return parsed_device.has_type && parsed_device.type == kTPUReplicatedCore;
}
}  // namespace tensorflow