• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
17 
18 #include <cstdint>
19 #include <tuple>
20 
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 using Device = DeviceNameUtils::ParsedName;
35 
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)36 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
37                               llvm::SmallVectorImpl<Device>* parsed_devices) {
38   parsed_devices->reserve(device_names.size());
39   for (const auto& device_name : device_names) {
40     Device parsed_name;
41     if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
42       return false;
43 
44     parsed_devices->push_back(parsed_name);
45   }
46   return true;
47 }
48 
49 using DeviceNames = llvm::SmallVector<std::string, 8>;
50 
51 struct ParameterizedDeviceSetTest
52     : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
53 
TEST_P(ParameterizedDeviceSetTest,BadDeviceSet)54 TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
55   llvm::SmallVector<Device, 8> devices;
56   ASSERT_TRUE(DeviceNamesToParsedNames(std::get<0>(GetParam()), &devices));
57   std::string topology_attr;
58   std::vector<int64_t> device_assignment_attr;
59 
60   auto status_or = GetTPUCompilationAndExecutionDevices(
61       devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, topology_attr,
62       device_assignment_attr);
63   ASSERT_FALSE(status_or.ok());
64   EXPECT_EQ(status_or.status().error_message(), std::get<1>(GetParam()));
65 }
66 
67 INSTANTIATE_TEST_SUITE_P(
68     BadDeviceSet, ParameterizedDeviceSetTest,
69     ::testing::Values(
70         std::make_tuple<DeviceNames, std::string>(
71             {"/job:localhost/replica:0/task:0/device:CPU:0"},
72             "no TPU_SYSTEM devices found"),
73         std::make_tuple<DeviceNames, std::string>(
74             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
75              "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
76             "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
77             "'worker'"),
78         std::make_tuple<DeviceNames, std::string>(
79             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
80              "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
81             "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
82         std::make_tuple<DeviceNames, std::string>(
83             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
84              "/job:localhost/replica:0/task:0/device:TPU:0",
85              "/job:localhost/replica:0/task:0/device:TPU:1",
86              "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
87              "/job:localhost/replica:0/task:1/device:TPU:0"},
88             "expected the number of TPU devices per host to be 2, got 1")));
89 
90 struct ParameterizedMetadataTest
91     : ::testing::TestWithParam<std::tuple<int, int, std::string,
92                                           std::vector<int64_t>, std::string>> {
93 };
94 
TEST_P(ParameterizedMetadataTest,BadMetadata)95 TEST_P(ParameterizedMetadataTest, BadMetadata) {
96   llvm::SmallVector<Device, 8> devices;
97   ASSERT_TRUE(DeviceNamesToParsedNames(
98       {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
99        "/job:worker/replica:0/task:0/device:TPU:0",
100        "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
101        "/job:worker/replica:0/task:1/device:TPU:0"},
102       &devices));
103   std::string compilation_device;
104   llvm::SmallVector<llvm::SmallVector<std::string, 8>, 8> execution_devices;
105   llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
106 
107   auto status_or = GetTPUCompilationAndExecutionDevices(
108       devices, std::get<0>(GetParam()), std::get<1>(GetParam()),
109       std::get<2>(GetParam()), std::get<3>(GetParam()));
110   ASSERT_FALSE(status_or.ok());
111   EXPECT_EQ(status_or.status().error_message(), std::get<4>(GetParam()));
112 }
113 
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)114 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
115   tpu::TopologyProto topology_proto;
116   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
117   return topology_proto.SerializeAsString();
118 }
119 
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)120 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
121                                           int num_tasks,
122                                           int num_tpu_devices_per_task) {
123   tpu::TopologyProto topology_proto;
124   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
125   topology_proto.set_num_tasks(num_tasks);
126   topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
127   return topology_proto.SerializeAsString();
128 }
129 
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)130 std::string TopologyWithDeviceCoordinates(
131     llvm::ArrayRef<int> device_coordinates) {
132   tpu::TopologyProto topology_proto;
133   topology_proto.add_mesh_shape(2);
134   topology_proto.add_mesh_shape(1);
135   topology_proto.add_mesh_shape(1);
136   topology_proto.add_mesh_shape(1);
137   topology_proto.set_num_tasks(2);
138   topology_proto.set_num_tpu_devices_per_task(1);
139   for (int device_coordinate : device_coordinates)
140     topology_proto.add_device_coordinates(device_coordinate);
141   return topology_proto.SerializeAsString();
142 }
143 
144 INSTANTIATE_TEST_SUITE_P(
145     BadFullMeshMetadata, ParameterizedMetadataTest,
146     ::testing::Values(
147         std::make_tuple(
148             2, 1, "", std::vector<int64_t>{0},
149             "'device_assignment' must not be set when 'topology' is not set"),
150         std::make_tuple(8, 1, "", std::vector<int64_t>(),
151                         "'num_replicas' must be equal to 1 or 2, got 8"),
152         std::make_tuple(2, 2, "", std::vector<int64_t>(),
153                         "'num_cores_per_replica' must be equal to 1, got 2")));
154 
155 INSTANTIATE_TEST_SUITE_P(
156     BadGeneralTopologyMetadata, ParameterizedMetadataTest,
157     ::testing::Values(
158         std::make_tuple(
159             2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
160             "failed to parse 'topology' attribute to TopologyProto"),
161         std::make_tuple(4, 2, TopologyWithMeshShape({0}),
162                         std::vector<int64_t>(),
163                         "'topology' 'mesh_shape' must be rank 4, got rank 1"),
164         std::make_tuple(
165             2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
166             "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
167         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
168                         std::vector<int64_t>(),
169                         "number of tasks from available TPU devices must be "
170                         "'num_tasks' in 'topology' (1), got 2"),
171         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
172                         std::vector<int64_t>(),
173                         "number of TPU devices available per task must be "
174                         "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
175         std::make_tuple(
176             2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
177             "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
178             "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
179         std::make_tuple(
180             2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
181             std::vector<int64_t>(),
182             "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
183             "of mesh shape (2, 1, 1, 1)"),
184         std::make_tuple(
185             2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
186             std::vector<int64_t>(),
187             "device coordinate (2, 0, 0, 0) in 'topology' is outside "
188             "of mesh shape (2, 1, 1, 1)"),
189         std::make_tuple(
190             2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
191             std::vector<int64_t>(),
192             "device coordinate (0, -1, 0, 0) in 'topology' is outside "
193             "of mesh shape (2, 1, 1, 1)"),
194         std::make_tuple(
195             2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
196             std::vector<int64_t>(),
197             "device coordinate (0, 1, 0, 0) in 'topology' is outside "
198             "of mesh shape (2, 1, 1, 1)"),
199         std::make_tuple(
200             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
201             std::vector<int64_t>(),
202             "device coordinate (0, 0, 0, -1) in 'topology' is outside "
203             "of mesh shape (2, 1, 1, 1)"),
204         std::make_tuple(
205             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
206             std::vector<int64_t>(),
207             "device coordinate (0, 0, 0, 1) in 'topology' is outside "
208             "of mesh shape (2, 1, 1, 1)"),
209         std::make_tuple(
210             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
211             std::vector<int64_t>(),
212             "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
213 
214 INSTANTIATE_TEST_SUITE_P(
215     BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
216     ::testing::Values(
217         std::make_tuple(2, 1,
218                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
219                         std::vector<int64_t>(),
220                         "length of 'device_assignment' must be 'num_replicas' "
221                         "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
222         std::make_tuple(
223             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
224             std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
225             "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
226             "is outside of mesh shape (2, 1, 1, 1)"),
227         std::make_tuple(
228             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
229             std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
230             "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
231             "outside of mesh shape (2, 1, 1, 1)"),
232         std::make_tuple(
233             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
234             std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
235             "device coordinate (0, -1, 0, 0) in 'device_assignment' "
236             "is outside of mesh shape (2, 1, 1, 1)"),
237         std::make_tuple(
238             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
239             std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
240             "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
241             "outside of mesh shape (2, 1, 1, 1)"),
242         std::make_tuple(
243             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
244             std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
245             "device coordinate (0, 0, 0, -1) in 'device_assignment' "
246             "is outside of mesh shape (2, 1, 1, 1)"),
247         std::make_tuple(
248             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
249             std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
250             "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
251             "outside of mesh shape (2, 1, 1, 1)"),
252         std::make_tuple(2, 1,
253                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
254                         std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
255                         "'device_assignment' has duplicate device coordinate "
256                         "(0, 0, 0, 0)")));
257 
// Builds the device-name list used by the "valid" tests. It starts with one
// CPU under /job:localhost. Then, for each of `num_tasks` worker tasks (in
// task order), it adds that task's CPU:0, TPU_SYSTEM:0 and TPU:0 through
// TPU:`num_devices_per_task - 1`. Callers are expected to pass non-negative
// counts.
std::vector<std::string> MakeDeviceSet(int num_tasks,
                                       int num_devices_per_task) {
  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  // Each task contributes a CPU, a TPU_SYSTEM and `num_devices_per_task` TPUs,
  // on top of the single localhost CPU above. The previous estimate left out
  // one device per task and forced a reallocation.
  devices.reserve(1 + num_tasks * (num_devices_per_task + 2));

  for (int task = 0; task < num_tasks; ++task) {
    // Shared prefix for every device of this task.
    const std::string task_prefix =
        "/job:worker/replica:0/task:" + std::to_string(task) + "/device:";
    devices.push_back(task_prefix + "CPU:0");
    devices.push_back(task_prefix + "TPU_SYSTEM:0");
    for (int device = 0; device < num_devices_per_task; ++device)
      devices.push_back(task_prefix + "TPU:" + std::to_string(device));
  }

  return devices;
}
281 
// A 2x1x1x1 mesh in which the topology lists only coordinate (0, 0, 0, 0) for
// the single TPU in the device set. Assigning the lone replica to
// (1, 0, 0, 0) must fail because no TPU device has that coordinate.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1), &devices));

  tpu::TopologyProto topology;
  for (const int dim : {2, 1, 1, 1}) topology.add_mesh_shape(dim);
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(1);
  for (const int coordinate : {0, 0, 0, 0})
    topology.add_device_coordinates(coordinate);
  std::string topology_attr = topology.SerializeAsString();
  std::vector<int64_t> device_assignment_attr{1, 0, 0, 0};

  auto result = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, topology_attr,
      device_assignment_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate "
            "(1, 0, 0, 0)");
}
315 
TEST(TPURewriteDeviceUtilTest,ValidFullMeshDeviceAssignment)316 TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
317   llvm::SmallVector<Device, 8> devices;
318   std::vector<std::string> device_names =
319       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
320   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
321   std::string topology_attr;
322   std::vector<int64_t> device_assignment_attr;
323 
324   auto status_or = GetTPUCompilationAndExecutionDevices(
325       devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1, topology_attr,
326       device_assignment_attr);
327 
328   TF_ASSERT_OK(status_or.status());
329 
330   const auto& tpu_device_assignment = status_or.ValueOrDie();
331   EXPECT_EQ(tpu_device_assignment.compilation_device,
332             "/job:worker/replica:0/task:0/device:CPU:0");
333   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
334   ASSERT_EQ(tpu_devices.size(), 8);
335   for (const auto& replica_tpu_devices : tpu_devices)
336     ASSERT_EQ(replica_tpu_devices.size(), 1);
337 
338   EXPECT_EQ(tpu_devices[0][0].device,
339             "/job:worker/replica:0/task:0/device:TPU:0");
340   EXPECT_EQ(tpu_devices[0][0].host,
341             "/job:worker/replica:0/task:0/device:CPU:0");
342   EXPECT_EQ(tpu_devices[1][0].device,
343             "/job:worker/replica:0/task:0/device:TPU:1");
344   EXPECT_EQ(tpu_devices[1][0].host,
345             "/job:worker/replica:0/task:0/device:CPU:0");
346   EXPECT_EQ(tpu_devices[2][0].device,
347             "/job:worker/replica:0/task:0/device:TPU:2");
348   EXPECT_EQ(tpu_devices[2][0].host,
349             "/job:worker/replica:0/task:0/device:CPU:0");
350   EXPECT_EQ(tpu_devices[3][0].device,
351             "/job:worker/replica:0/task:0/device:TPU:3");
352   EXPECT_EQ(tpu_devices[3][0].host,
353             "/job:worker/replica:0/task:0/device:CPU:0");
354   EXPECT_EQ(tpu_devices[4][0].device,
355             "/job:worker/replica:0/task:1/device:TPU:0");
356   EXPECT_EQ(tpu_devices[4][0].host,
357             "/job:worker/replica:0/task:1/device:CPU:0");
358   EXPECT_EQ(tpu_devices[5][0].device,
359             "/job:worker/replica:0/task:1/device:TPU:1");
360   EXPECT_EQ(tpu_devices[5][0].host,
361             "/job:worker/replica:0/task:1/device:CPU:0");
362   EXPECT_EQ(tpu_devices[6][0].device,
363             "/job:worker/replica:0/task:1/device:TPU:2");
364   EXPECT_EQ(tpu_devices[6][0].host,
365             "/job:worker/replica:0/task:1/device:CPU:0");
366   EXPECT_EQ(tpu_devices[7][0].device,
367             "/job:worker/replica:0/task:1/device:TPU:3");
368   EXPECT_EQ(tpu_devices[7][0].host,
369             "/job:worker/replica:0/task:1/device:CPU:0");
370 
371   EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
372 }
373 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh2x2x2)374 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
375   tpu::TopologyProto topology_proto;
376   {
377     topology_proto.add_mesh_shape(2);
378     topology_proto.add_mesh_shape(2);
379     topology_proto.add_mesh_shape(1);
380     topology_proto.add_mesh_shape(2);
381     topology_proto.set_num_tasks(2);
382     topology_proto.set_num_tpu_devices_per_task(4);
383     topology_proto.add_device_coordinates(0);
384     topology_proto.add_device_coordinates(0);
385     topology_proto.add_device_coordinates(0);
386     topology_proto.add_device_coordinates(0);
387     topology_proto.add_device_coordinates(0);
388     topology_proto.add_device_coordinates(1);
389     topology_proto.add_device_coordinates(0);
390     topology_proto.add_device_coordinates(0);
391     topology_proto.add_device_coordinates(1);
392     topology_proto.add_device_coordinates(1);
393     topology_proto.add_device_coordinates(0);
394     topology_proto.add_device_coordinates(0);
395     topology_proto.add_device_coordinates(1);
396     topology_proto.add_device_coordinates(0);
397     topology_proto.add_device_coordinates(0);
398     topology_proto.add_device_coordinates(0);
399     topology_proto.add_device_coordinates(1);
400     topology_proto.add_device_coordinates(0);
401     topology_proto.add_device_coordinates(0);
402     topology_proto.add_device_coordinates(1);
403     topology_proto.add_device_coordinates(1);
404     topology_proto.add_device_coordinates(1);
405     topology_proto.add_device_coordinates(0);
406     topology_proto.add_device_coordinates(1);
407     topology_proto.add_device_coordinates(0);
408     topology_proto.add_device_coordinates(1);
409     topology_proto.add_device_coordinates(0);
410     topology_proto.add_device_coordinates(1);
411     topology_proto.add_device_coordinates(0);
412     topology_proto.add_device_coordinates(0);
413     topology_proto.add_device_coordinates(0);
414     topology_proto.add_device_coordinates(1);
415   }
416 
417   std::string topology_attr = topology_proto.SerializeAsString();
418   std::vector<int64_t> device_assignment_attr{0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
419                                               0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,
420                                               0, 1, 1, 1, 0, 0, 1, 1, 0, 1};
421 
422   llvm::SmallVector<Device, 8> devices;
423   std::vector<std::string> device_names =
424       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
425   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
426 
427   auto status_or = GetTPUCompilationAndExecutionDevices(
428       devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2, topology_attr,
429       device_assignment_attr);
430 
431   TF_ASSERT_OK(status_or.status());
432 
433   const auto& tpu_device_assignment = status_or.ValueOrDie();
434   EXPECT_EQ(tpu_device_assignment.compilation_device,
435             "/job:worker/replica:0/task:0/device:CPU:0");
436   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
437   ASSERT_EQ(tpu_devices.size(), 4);
438   for (const auto& replica_tpu_devices : tpu_devices)
439     ASSERT_EQ(replica_tpu_devices.size(), 2);
440 
441   EXPECT_EQ(tpu_devices[0][0].device,
442             "/job:worker/replica:0/task:0/device:TPU:0");
443   EXPECT_EQ(tpu_devices[0][0].host,
444             "/job:worker/replica:0/task:0/device:CPU:0");
445   EXPECT_EQ(tpu_devices[0][1].device,
446             "/job:worker/replica:0/task:1/device:TPU:3");
447   EXPECT_EQ(tpu_devices[0][1].host,
448             "/job:worker/replica:0/task:1/device:CPU:0");
449   EXPECT_EQ(tpu_devices[1][0].device,
450             "/job:worker/replica:0/task:0/device:TPU:1");
451   EXPECT_EQ(tpu_devices[1][0].host,
452             "/job:worker/replica:0/task:0/device:CPU:0");
453   EXPECT_EQ(tpu_devices[1][1].device,
454             "/job:worker/replica:0/task:1/device:TPU:2");
455   EXPECT_EQ(tpu_devices[1][1].host,
456             "/job:worker/replica:0/task:1/device:CPU:0");
457   EXPECT_EQ(tpu_devices[2][0].device,
458             "/job:worker/replica:0/task:0/device:TPU:3");
459   EXPECT_EQ(tpu_devices[2][0].host,
460             "/job:worker/replica:0/task:0/device:CPU:0");
461   EXPECT_EQ(tpu_devices[2][1].device,
462             "/job:worker/replica:0/task:1/device:TPU:0");
463   EXPECT_EQ(tpu_devices[2][1].host,
464             "/job:worker/replica:0/task:1/device:CPU:0");
465   EXPECT_EQ(tpu_devices[3][0].device,
466             "/job:worker/replica:0/task:0/device:TPU:2");
467   EXPECT_EQ(tpu_devices[3][0].host,
468             "/job:worker/replica:0/task:0/device:CPU:0");
469   EXPECT_EQ(tpu_devices[3][1].device,
470             "/job:worker/replica:0/task:1/device:TPU:1");
471   EXPECT_EQ(tpu_devices[3][1].host,
472             "/job:worker/replica:0/task:1/device:CPU:0");
473 
474   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
475   ASSERT_TRUE(xla_device_assignment.hasValue());
476   EXPECT_EQ(xla_device_assignment->replica_count(), 4);
477   EXPECT_EQ(xla_device_assignment->computation_count(), 2);
478   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);
479   const auto& computation_device_0 =
480       xla_device_assignment->computation_devices(0);
481   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 4);
482   const auto& computation_device_1 =
483       xla_device_assignment->computation_devices(1);
484   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 4);
485 
486   EXPECT_EQ(computation_device_0.replica_device_ids(0), 0);
487   EXPECT_EQ(computation_device_0.replica_device_ids(1), 4);
488   EXPECT_EQ(computation_device_0.replica_device_ids(2), 2);
489   EXPECT_EQ(computation_device_0.replica_device_ids(3), 6);
490   EXPECT_EQ(computation_device_1.replica_device_ids(0), 1);
491   EXPECT_EQ(computation_device_1.replica_device_ids(1), 5);
492   EXPECT_EQ(computation_device_1.replica_device_ids(2), 3);
493   EXPECT_EQ(computation_device_1.replica_device_ids(3), 7);
494 }
495 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh1x2x1x3)496 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
497   tpu::TopologyProto topology_proto;
498   {
499     topology_proto.add_mesh_shape(1);
500     topology_proto.add_mesh_shape(2);
501     topology_proto.add_mesh_shape(1);
502     topology_proto.add_mesh_shape(3);
503     topology_proto.set_num_tasks(3);
504     topology_proto.set_num_tpu_devices_per_task(2);
505     topology_proto.add_device_coordinates(0);
506     topology_proto.add_device_coordinates(0);
507     topology_proto.add_device_coordinates(0);
508     topology_proto.add_device_coordinates(0);
509     topology_proto.add_device_coordinates(0);
510     topology_proto.add_device_coordinates(1);
511     topology_proto.add_device_coordinates(0);
512     topology_proto.add_device_coordinates(0);
513     topology_proto.add_device_coordinates(0);
514     topology_proto.add_device_coordinates(1);
515     topology_proto.add_device_coordinates(0);
516     topology_proto.add_device_coordinates(1);
517     topology_proto.add_device_coordinates(0);
518     topology_proto.add_device_coordinates(0);
519     topology_proto.add_device_coordinates(0);
520     topology_proto.add_device_coordinates(1);
521     topology_proto.add_device_coordinates(0);
522     topology_proto.add_device_coordinates(0);
523     topology_proto.add_device_coordinates(0);
524     topology_proto.add_device_coordinates(2);
525     topology_proto.add_device_coordinates(0);
526     topology_proto.add_device_coordinates(1);
527     topology_proto.add_device_coordinates(0);
528     topology_proto.add_device_coordinates(2);
529   }
530 
531   std::string topology_attr = topology_proto.SerializeAsString();
532   std::vector<int64_t> device_assignment_attr{
533       0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};
534 
535   llvm::SmallVector<Device, 8> devices;
536   std::vector<std::string> device_names =
537       MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
538   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
539 
540   auto status_or = GetTPUCompilationAndExecutionDevices(
541       devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3, topology_attr,
542       device_assignment_attr);
543 
544   TF_ASSERT_OK(status_or.status());
545 
546   auto& tpu_device_assignment = status_or.ValueOrDie();
547   EXPECT_EQ(tpu_device_assignment.compilation_device,
548             "/job:worker/replica:0/task:0/device:CPU:0");
549 
550   auto& tpu_devices = tpu_device_assignment.tpu_devices;
551   ASSERT_EQ(tpu_devices.size(), 2);
552   for (const auto& replica_tpu_devices : tpu_devices)
553     ASSERT_EQ(replica_tpu_devices.size(), 3);
554 
555   EXPECT_EQ(tpu_devices[0][0].device,
556             "/job:worker/replica:0/task:1/device:TPU:1");
557   EXPECT_EQ(tpu_devices[0][0].host,
558             "/job:worker/replica:0/task:1/device:CPU:0");
559   EXPECT_EQ(tpu_devices[0][1].device,
560             "/job:worker/replica:0/task:1/device:TPU:0");
561   EXPECT_EQ(tpu_devices[0][1].host,
562             "/job:worker/replica:0/task:1/device:CPU:0");
563   EXPECT_EQ(tpu_devices[0][2].device,
564             "/job:worker/replica:0/task:2/device:TPU:0");
565   EXPECT_EQ(tpu_devices[0][2].host,
566             "/job:worker/replica:0/task:2/device:CPU:0");
567   EXPECT_EQ(tpu_devices[1][0].device,
568             "/job:worker/replica:0/task:2/device:TPU:1");
569   EXPECT_EQ(tpu_devices[1][0].host,
570             "/job:worker/replica:0/task:2/device:CPU:0");
571   EXPECT_EQ(tpu_devices[1][1].device,
572             "/job:worker/replica:0/task:0/device:TPU:0");
573   EXPECT_EQ(tpu_devices[1][1].host,
574             "/job:worker/replica:0/task:0/device:CPU:0");
575   EXPECT_EQ(tpu_devices[1][2].device,
576             "/job:worker/replica:0/task:0/device:TPU:1");
577   EXPECT_EQ(tpu_devices[1][2].host,
578             "/job:worker/replica:0/task:0/device:CPU:0");
579 
580   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
581   ASSERT_TRUE(xla_device_assignment.hasValue());
582   EXPECT_EQ(xla_device_assignment->replica_count(), 2);
583   EXPECT_EQ(xla_device_assignment->computation_count(), 3);
584   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);
585   const auto& computation_device_0 =
586       xla_device_assignment->computation_devices(0);
587   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 2);
588   const auto& computation_device_1 =
589       xla_device_assignment->computation_devices(1);
590   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 2);
591   const auto& computation_device_2 =
592       xla_device_assignment->computation_devices(2);
593   ASSERT_EQ(computation_device_2.replica_device_ids_size(), 2);
594 
595   EXPECT_EQ(computation_device_0.replica_device_ids(0), 1);
596   EXPECT_EQ(computation_device_0.replica_device_ids(1), 5);
597   EXPECT_EQ(computation_device_1.replica_device_ids(0), 4);
598   EXPECT_EQ(computation_device_1.replica_device_ids(1), 0);
599   EXPECT_EQ(computation_device_2.replica_device_ids(0), 2);
600   EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
601 }
602 
// An i64 array attribute is converted element by element into device
// coordinates.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  // TF_ASSERT_OK reports the failing status message, unlike ASSERT_TRUE(ok()).
  TF_ASSERT_OK(status_or_device_coordinates.status());
  auto device_coordinates = status_or_device_coordinates.ConsumeValueOrDie();
  // Check the length first so the element checks below never index out of
  // bounds.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
615 
// A device_assignment attribute holding non-integer (f32) elements is rejected,
// and the error names the first offending index.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  ASSERT_FALSE(status_or_device_coordinates.ok());
  EXPECT_EQ(status_or_device_coordinates.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
626 
// A result-less tf_device.cluster with none of the TPU attributes set, paired
// with a default-constructed RuntimeDevices, cannot be mapped to a host device.
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
642 
// A cluster requesting 5 cores per replica (model parallelism), with an empty
// topology string, an empty device assignment and no runtime devices, must
// not produce a host device.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  auto five_cores = builder.getIntegerAttr(builder.getIntegerType(64), 5);
  auto empty_topology = builder.getStringAttr("");
  auto empty_assignment = builder.getArrayAttr({});
  cluster->setAttr(kNumCoresPerReplicaAttr, five_cores);
  cluster->setAttr(kTopologyAttr, empty_topology);
  cluster->setAttr(kDeviceAssignmentAttr, empty_assignment);

  mlir::TF::RuntimeDevices no_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
663 
// A single-core cluster that has a device assignment but no topology
// attribute must not produce a host device.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // kTopologyAttr is intentionally left unset.
  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices no_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
683 
// A single-core cluster that has a topology but no device-assignment
// attribute must not produce a host device.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // kDeviceAssignmentAttr is intentionally left unset.
  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));

  mlir::TF::RuntimeDevices no_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
703 
// A device-assignment attribute holding a string instead of integers must
// cause the host device lookup to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  const llvm::StringRef bad_assignment[] = {"bad_device_assigment"};
  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr,
                   builder.getStrArrayAttr(bad_assignment));

  mlir::TF::RuntimeDevices no_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
726 
// A module whose "tf.devices" list holds an unparseable device name should be
// rejected by GetDevicesFromOp, and the subsequent host device lookup for an
// otherwise well-formed single-core cluster must fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(
                    llvm::ArrayRef<llvm::StringRef>({"bad_device_name"})));

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  // Check the premise of this test instead of discarding it with (void): the
  // bad device name must be rejected while collecting runtime devices.
  EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &runtime_devices)));
  std::string host_device;
  EXPECT_TRUE(mlir::failed(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
}
752 
// A cluster nested inside a tf_device.replicate op resolves its host device to
// kTPUReplicatedHost, even with no cluster attributes and no runtime devices.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  // Two replicas, no explicit device mapping, no replicated/packed inputs and
  // no results.
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      replica_devices;
  auto replicate = builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, replica_devices,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Place the cluster at the start of the replicate body.
  mlir::Block& replica_body = replicate.body().front();
  builder.setInsertionPointToStart(&replica_body);
  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  mlir::TF::RuntimeDevices no_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::succeeded(result));
  EXPECT_EQ(host_device, kTPUReplicatedHost);
}
779 
// For a single-core cluster outside any replicate op, the host device is
// resolved from the module's "tf.devices" list.
// NOTE(review): the expected host is on job:localhost although the only CPU
// listed is on job:worker. Presumably the host is derived from the
// TPU_SYSTEM device's job/replica/task; confirm against
// GetHostDeviceOutsideComputation.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                    {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
                     "/job:localhost/replica:0/task:0/device:TPU:0",
                     "/job:worker/replica:0/task:0/device:CPU:0"})));

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  // All names above are well-formed, so device collection must succeed. A
  // parsing regression should be reported here, not as a confusing host
  // lookup failure below.
  ASSERT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &runtime_devices)));
  std::string host_device;
  EXPECT_TRUE(mlir::succeeded(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
  EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
}
808 
// IsTPUDevice accepts a full device name whose device type is TPU, and rejects
// both a CPU device name and a string that is not a device name at all.
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  // Negative cases.
  EXPECT_FALSE(IsTPUDevice("INVALID_DEVICE"));
  EXPECT_FALSE(IsTPUDevice("/job:localhost/replica:0/task:0/device:CPU:0"));
  // Positive case.
  EXPECT_TRUE(IsTPUDevice("/job:localhost/replica:0/task:0/device:TPU:0"));
}
814 
815 }  // anonymous namespace
816 }  // namespace tensorflow
817