/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST";
const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica";
const char* const kTopologyAttr = "topology";
const char* const kDeviceAssignmentAttr = "device_assignment";

// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;

constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
constexpr char kDeviceTPU[] = "TPU";
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
constexpr char kBadIntArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not an int";

using Device = DeviceNameUtils::ParsedName;
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;

namespace {
// Finds matching devices in `devices` based on pattern `spec`.
void FindMatchingDevices(Devices devices, const Device& spec,
                         llvm::SmallVectorImpl<Device>* matched_devices) {
  for (const auto& device : devices)
    if (DeviceNameUtils::IsCompleteSpecification(spec, device))
      matched_devices->push_back(device);
}

// Creates error message for a conflicting attribute of a device.
template <typename T>
Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) {
  return errors::InvalidArgument("found ", kDeviceTPUSystem,
                                 " devices with conflicting ", attribute, "s '",
                                 a, "' and '", b, "'");
}

// Finds TPU_SYSTEM:0 devices in `devices` and returns them sorted by task. If
// no TPU_SYSTEM device is found, or if there are multiple TPU_SYSTEM devices
// with different jobs or replicas, a failure is returned.
Status GetTPUSystemDevices(Devices devices,
                           llvm::SmallVectorImpl<Device>* matched_devices) {
  Device spec;
  spec.type = kDeviceTPUSystem;
  spec.has_type = true;
  spec.id = 0;
  spec.has_id = true;

  llvm::SmallVector<Device, 8> system_devices;
  FindMatchingDevices(devices, spec, &system_devices);
  if (system_devices.empty())
    return errors::InvalidArgument("no ", kDeviceTPUSystem, " devices found");

  // Check that all system devices are part of the same job.
  const auto& job = system_devices[0].job;
  auto replica = system_devices[0].replica;
  for (const auto& device : llvm::make_range(std::next(system_devices.begin()),
                                             system_devices.end())) {
    if (device.job != job)
      return MismatchedTPUSystemAttributeErr("job", job, device.job);

    if (device.replica != replica)
      return MismatchedTPUSystemAttributeErr("replica", replica,
                                             device.replica);
  }

  // Sort by task to be deterministic.
  std::sort(system_devices.begin(), system_devices.end(),
            [](const Device& a, const Device& b) { return a.task < b.task; });

  matched_devices->swap(system_devices);

  return Status::OK();
}

// Finds TPU devices associated with the system devices (e.g. from
// GetTPUSystemDevices). If the number of TPU devices per host does not match
// across all hosts, a failure is returned.
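// For example (illustrative names), with system devices on tasks 0 and 1 and
// two TPUs per host, `tpu_devices` is populated as:
//   [[/job:worker/task:0/device:TPU:0, /job:worker/task:0/device:TPU:1],
//    [/job:worker/task:1/device:TPU:0, /job:worker/task:1/device:TPU:1]]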
Status GetTPUDevices(
    Devices devices, llvm::ArrayRef<Device> system_devices,
    llvm::SmallVectorImpl<llvm::SmallVector<Device, 8>>* tpu_devices) {
  tpu_devices->reserve(system_devices.size());

  auto lookup = [&devices](Device device_spec) {
    device_spec.has_type = true;
    device_spec.type = kDeviceTPU;
    // Enumerate all the available TPUs.
    device_spec.has_id = false;

    llvm::SmallVector<Device, 8> host_tpu_devices;
    FindMatchingDevices(devices, device_spec, &host_tpu_devices);

    // Sort devices by id.
    std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
              [](const Device& i, const Device& j) { return i.id < j.id; });
    return host_tpu_devices;
  };

  int num_tpus_per_host = 0;
  {
    const auto& device = system_devices[0];
    auto host_tpu_devices = lookup(device);
    num_tpus_per_host = host_tpu_devices.size();
    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  for (const auto& device_spec : llvm::make_range(
           std::next(system_devices.begin()), system_devices.end())) {
    auto host_tpu_devices = lookup(device_spec);
    // Check number of TPU devices per host all match.
    const int64 host_tpu_devices_size = host_tpu_devices.size();
    if (num_tpus_per_host != host_tpu_devices_size)
      return errors::InvalidArgument(
          "expected the number of TPU devices per host to be ",
          num_tpus_per_host, ", got ", host_tpu_devices.size());

    tpu_devices->push_back(std::move(host_tpu_devices));
  }

  return Status::OK();
}

// Finds the compilation device from system device.
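// For example, /job:worker/replica:0/task:0/device:TPU_SYSTEM:0 (an
// illustrative name) maps to /job:worker/replica:0/task:0/device:CPU:0.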
std::string GetTPUCompilationDevice(Device system_device) {
  // TODO(b/110910013) GetTPUSystemDevices parses the spec and returns the
  // TPU_SYSTEM device, which we replace with the CPU device. We do this
  // replacement because we want to place the `tf._TPUCompileMlir` explicitly on
  // CPU devices of the same job as the TPU_SYSTEM device.
  system_device.type = tensorflow::DEVICE_CPU;
  return DeviceNameUtils::ParsedNameToString(system_device);
}

// Finds the host CPU device for a given TPU device.
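// For example, /job:worker/replica:0/task:0/device:TPU:1 (an illustrative
// name) maps to /job:worker/replica:0/task:0/device:CPU:0.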
std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
  tpu_device.type = DEVICE_CPU;
  tpu_device.id = 0;
  return DeviceNameUtils::ParsedNameToString(tpu_device);
}

// Determines execution devices when topology and device assignment are not
// defined. This is a special case where a single core computation is
// replicated to every core in the mesh. TPU devices are simply added to
// `execution_devices` of one replica. `num_replicas` must be 1 or the total
// number of TPU devices available, and `num_cores_per_replica` must be 1.
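// For example, with 2 tasks and 2 TPUs per task, replica 3 is assigned the
// TPU of task 1 (3 / 2) at device index 1 (3 % 2).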
StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();
  const int num_tpu_devices = num_tasks * num_tpus_per_task;

  if (num_replicas != 1 && num_replicas != num_tpu_devices)
    return errors::InvalidArgument("'num_replicas' must be equal to 1 or ",
                                   num_tpu_devices, ", got ", num_replicas);

  if (num_cores_per_replica != 1)
    return errors::InvalidArgument(
        "'num_cores_per_replica' must be equal to 1, got ",
        num_cores_per_replica);

  TPUDevicesAndHosts devices_and_hosts;
  devices_and_hosts.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    const int task = i / num_tpus_per_task;
    const int device = i % num_tpus_per_task;
    const auto& tpu_device = tpu_devices[task][device];
    devices_and_hosts.push_back({TPUDeviceAndHost(
        /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
        /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
  }

  return devices_and_hosts;
}

// Helper struct for keeping track of task and device for an associated TPU
// device coordinate.
struct TaskAndDevice {
  TaskAndDevice() {}
  TaskAndDevice(int task, int device) : task(task), device(device) {}

  int task = -1;
  int device = -1;
};

// Checks if device coordinate is outside of topology mesh shape bounds.
bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x,
                                int bound_y, int bound_z, int bound_core) {
  return x < 0 || x >= bound_x || y < 0 || y >= bound_y || z < 0 ||
         z >= bound_z || core < 0 || core >= bound_core;
}

// Creates error message for an out of bound device coordinate.
Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                int z, int core, int bound_x, int bound_y,
                                int bound_z, int bound_core) {
  return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", z,
                                 ", ", core, ") in '", attribute,
                                 "' is outside of mesh shape (", bound_x, ", ",
                                 bound_y, ", ", bound_z, ", ", bound_core, ")");
}

// Creates error message for a duplicate device coordinate.
Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y,
                                   int z, int core) {
  return errors::InvalidArgument("'", attribute,
                                 "' has duplicate device coordinate (", x, ", ",
                                 y, ", ", z, ", ", core, ")");
}

// Parses and validates topology (serialized string of TopologyProto), and maps
// device coordinate (x, y, z, core) to task and device (of available TPUs).
// Topology attribute device coordinates are ordered by task then device (major
// to minor).
//
// A valid TopologyProto must have:
//  - a valid mesh shape (rank 4 with positive dimensions)
//  - `num_tasks` and `num_tpu_devices_per_task` must match the number of
//    available TPU hosts and devices per host
//  - device coordinates within the mesh shape
//  - no duplicate device coordinates
//  - number of device coordinates (in (x, y, z, core) 4-tuples) matching the
//    number of available TPUs
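//
// For example, with num_tasks = 2, num_tpus_per_task = 2, and a 2x1x1x2 mesh,
// device_coordinates = [0,0,0,0, 0,0,0,1, 1,0,0,0, 1,0,0,1] maps (task 0,
// device 0) to coordinate (0, 0, 0, 0), (task 0, device 1) to (0, 0, 0, 1),
// and so on.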
StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
    llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) {
  tpu::TopologyProto topology_proto;
  if (!topology_proto.ParseFromString(topology_attr.str()))
    return errors::InvalidArgument("failed to parse '", kTopologyAttr,
                                   "' attribute to TopologyProto");

  if (topology_proto.mesh_shape_size() != kTPUTopologyRank)
    return errors::InvalidArgument(
        "'", kTopologyAttr, "' 'mesh_shape' must be rank ", kTPUTopologyRank,
        ", got rank ", topology_proto.mesh_shape_size());

  for (auto mesh_shape_dim : llvm::enumerate(topology_proto.mesh_shape()))
    if (mesh_shape_dim.value() <= 0)
      return errors::InvalidArgument(
          "'", kTopologyAttr, "' 'mesh_shape' dimension ",
          mesh_shape_dim.index(), " must be positive, got ",
          mesh_shape_dim.value());

  if (topology_proto.num_tasks() != num_tasks)
    return errors::InvalidArgument(
        "number of tasks from available TPU devices must be 'num_tasks' in '",
        kTopologyAttr, "' (", topology_proto.num_tasks(), "), got ", num_tasks);

  if (topology_proto.num_tpu_devices_per_task() != num_tpus_per_task)
    return errors::InvalidArgument(
        "number of TPU devices available per task must be "
        "'num_tpu_devices_per_task' in '",
        kTopologyAttr, "' (", topology_proto.num_tpu_devices_per_task(),
        "), got ", num_tpus_per_task);

  const int expected_device_coordinates_size =
      num_tasks * num_tpus_per_task * kTPUTopologyRank;
  if (topology_proto.device_coordinates_size() !=
      expected_device_coordinates_size)
    return errors::InvalidArgument(
        "length of 'device_coordinates' in '", kTopologyAttr,
        "' must be 'num_tasks' * 'num_tpus_per_task' * ", kTPUTopologyRank,
        " (", num_tasks, " * ", num_tpus_per_task, " * ", kTPUTopologyRank,
        "), got ", topology_proto.device_coordinates_size());

  const int bound_x = topology_proto.mesh_shape(0);
  const int bound_y = topology_proto.mesh_shape(1);
  const int bound_z = topology_proto.mesh_shape(2);
  const int bound_core = topology_proto.mesh_shape(3);

  xla::Array4D<TaskAndDevice> topology(bound_x, bound_y, bound_z, bound_core);
  int pos = 0;
  for (int task = 0; task < num_tasks; ++task) {
    for (int device = 0; device < num_tpus_per_task; ++device) {
      int x = topology_proto.device_coordinates(pos++);
      int y = topology_proto.device_coordinates(pos++);
      int z = topology_proto.device_coordinates(pos++);
      int core = topology_proto.device_coordinates(pos++);
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, z, core, bound_x,
                                        bound_y, bound_z, bound_core);

      auto& task_and_device = topology(x, y, z, core);
      if (task_and_device.task != -1)
        return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, z, core);

      task_and_device = {task, device};
    }
  }

  return topology;
}

// Determines execution devices when topology and device assignment are
// defined. With a topology device coordinate to task and device mapping,
// device assignment device coordinates can then be mapped to task and device
// for TPU devices. The device assignment array is also validated.
//
// A valid device assignment array must have:
//  - device coordinates within the topology mesh shape
//  - no duplicate device coordinates
//  - number of device coordinates (in (x, y, z, core) 4-tuples) matching
//    'num_replicas' * 'num_cores_per_replica'
//  - a TPU device associated with each device coordinate
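//
// For example, with num_replicas = 2 and num_cores_per_replica = 1,
// device_assignment_attr = [0,0,0,0, 0,0,0,1] assigns replica 0 to the TPU at
// coordinate (0, 0, 0, 0) and replica 1 to the TPU at (0, 0, 0, 1).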
StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
GetGeneralTPUExecutionDeviceAssignment(
    int num_replicas, int num_cores_per_replica,
    llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  const int num_tasks = tpu_devices.size();
  const int num_tpus_per_task = tpu_devices[0].size();

  TF_ASSIGN_OR_RETURN(auto topology, ParseTopologyAttr(topology_attr, num_tasks,
                                                       num_tpus_per_task));

  const int expected_device_assignment_size =
      num_replicas * num_cores_per_replica * kTPUTopologyRank;
  const int device_assignment_attr_size = device_assignment_attr.size();
  if (device_assignment_attr_size != expected_device_assignment_size)
    return errors::InvalidArgument(
        "length of '", kDeviceAssignmentAttr,
        "' must be 'num_replicas' * 'num_cores_per_replica' * ",
        kTPUTopologyRank, " (", num_replicas, " * ", num_cores_per_replica,
        " * ", kTPUTopologyRank, "), got ", device_assignment_attr.size());

  const int bound_x = topology.n1();
  const int bound_y = topology.n2();
  const int bound_z = topology.n3();
  const int bound_core = topology.n4();

  // TPU XLA device ID is determined by its device coordinate, from major to
  // minor coordinates (z, y, x, core).
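  // For example, in a 2x2x1x2 mesh, coordinate (x=1, y=1, z=0, core=1) maps
  // to XLA device ID (1 + 2 * (1 + 2 * 0)) * 2 + 1 = 7.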
  auto location_to_id = [&](int x, int y, int z, int core) {
    return (x + bound_x * (y + bound_y * z)) * bound_core + core;
  };

  std::vector<bool> used_device_ids(bound_x * bound_y * bound_z * bound_core,
                                    false);
  TPUDevicesAndHosts devices_and_hosts(
      num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
                        num_cores_per_replica, TPUDeviceAndHost()));
  xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
  int pos = 0;
  for (int replica = 0; replica < num_replicas; ++replica) {
    for (int logical_core = 0; logical_core < num_cores_per_replica;
         ++logical_core) {
      int x = device_assignment_attr[pos++];
      int y = device_assignment_attr[pos++];
      int z = device_assignment_attr[pos++];
      int core = device_assignment_attr[pos++];
      if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z,
                                     bound_core))
        return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core,
                                        bound_x, bound_y, bound_z, bound_core);

      TaskAndDevice task_and_device = topology(x, y, z, core);
      const int task = task_and_device.task;
      const int device = task_and_device.device;
      if (task == -1 || device == -1)
        return errors::InvalidArgument(
            "no TPU device found for '", kDeviceAssignmentAttr,
            "' device coordinate (", x, ", ", y, ", ", z, ", ", core, ")");

      const int device_id = location_to_id(x, y, z, core);
      if (used_device_ids[device_id])
        return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z,
                                           core);

      used_device_ids[device_id] = true;
      device_assignment(replica, logical_core) = device_id;
      auto& device_and_host = devices_and_hosts[replica][logical_core];
      const auto& tpu_device = tpu_devices[task][device];
      device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
      device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
    }
  }

  xla::DeviceAssignmentProto device_assignment_proto;
  TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));

  return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
      std::move(devices_and_hosts), std::move(device_assignment_proto));
}

}  // anonymous namespace

StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr) {
  llvm::SmallVector<int64_t, 8> device_coordinates;
  device_coordinates.reserve(device_assignment_attr.size());

  for (auto device_coordinate_and_idx :
       llvm::enumerate(device_assignment_attr)) {
    auto device_coordinate =
        device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
    if (!device_coordinate)
      return errors::InvalidArgument(
          llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
                        device_coordinate_and_idx.index())
              .str());

    device_coordinates.push_back(device_coordinate.getInt());
  }

  return device_coordinates;
}

StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    Devices devices, int num_replicas, int num_cores_per_replica,
    llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr) {
  // Collect TPU_SYSTEM devices.
  llvm::SmallVector<Device, 8> system_devices;
  TF_RETURN_IF_ERROR(GetTPUSystemDevices(devices, &system_devices));

  // Collect TPU devices based on TPU_SYSTEM devices collected earlier.
  llvm::SmallVector<llvm::SmallVector<Device, 8>, 8> tpu_devices;
  TF_RETURN_IF_ERROR(GetTPUDevices(devices, system_devices, &tpu_devices));

  std::string compilation_device = GetTPUCompilationDevice(system_devices[0]);

  if (topology_attr.empty()) {
    if (!device_assignment_attr.empty())
      return errors::InvalidArgument("'", kDeviceAssignmentAttr,
                                     "' must not be set when '", kTopologyAttr,
                                     "' is not set");

    TF_ASSIGN_OR_RETURN(auto execution_devices,
                        GetFullMeshTPUExecutionDeviceAssignment(
                            num_replicas, num_cores_per_replica, tpu_devices));
    return TPUDeviceAssignment(compilation_device,
                               std::move(execution_devices));
  }

  TF_ASSIGN_OR_RETURN(auto devices_and_ids,
                      GetGeneralTPUExecutionDeviceAssignment(
                          num_replicas, num_cores_per_replica, tpu_devices,
                          topology_attr, device_assignment_attr));
  return TPUDeviceAssignment(compilation_device,
                             std::move(devices_and_ids.first),
                             std::move(devices_and_ids.second));
}

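// Returns a device alias of the form TPU_REPLICATED_CORE_<core_index>, e.g.
// core index 0 yields "TPU_REPLICATED_CORE_0".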
std::string GetDeviceAliasForLogicalCore(int core_index) {
  return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
}

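// Returns, through `host_device`, the host device for outside compilation of
// `cluster`: the TPU_REPLICATED_HOST alias when the cluster is replicated,
// otherwise the CPU host of the cluster's single TPU device.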
mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device) {
  auto replicate = cluster->getParentOfType<mlir::tf_device::ReplicateOp>();
  if (replicate) {
    *host_device = tensorflow::kTPUReplicatedHost;
    return mlir::success();
  }

  auto num_cores_per_replica_attr = cluster->getAttrOfType<mlir::IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster.emitOpError(
        "cluster op missing `num_cores_per_replica` attribute");

  if (num_cores_per_replica_attr.getInt() != 1)
    return cluster.emitOpError(
        "outside compilation is not supported with model parallelism.");

  auto topology_attr =
      cluster->getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster.emitOpError("cluster op missing `topology` attribute");

  auto device_assignment_attr = cluster->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
                                             tensorflow::kDeviceAssignmentAttr)
                                   .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);

  if (!status_or_device_coordinates.ok())
    return cluster.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices.device_names(), /*num_replicas=*/1,
          /*num_cores_per_replica=*/1, topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();

  *host_device = tpu_device_assignment.tpu_devices[0][0].host;
  return mlir::success();
}

bool IsTPUDevice(llvm::StringRef device) {
  Device parsed_device;
  if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device),
                                      &parsed_device))
    return false;
  return parsed_device.has_type && parsed_device.type == kDeviceTPU;
}

}  // namespace tensorflow