• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
17 
18 #include <cstdint>
19 #include <tuple>
20 
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 using Device = DeviceNameUtils::ParsedName;
35 
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)36 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
37                               llvm::SmallVectorImpl<Device>* parsed_devices) {
38   parsed_devices->reserve(device_names.size());
39   for (const auto& device_name : device_names) {
40     Device parsed_name;
41     if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
42       return false;
43 
44     parsed_devices->push_back(parsed_name);
45   }
46   return true;
47 }
48 
49 using DeviceNames = llvm::SmallVector<std::string, 8>;
50 
51 struct ParameterizedDeviceSetTest
52     : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
53 
// Feeds a malformed device set to GetTPUCompilationAndExecutionDevices (one
// replica, one core per replica, no topology, no device assignment) and checks
// that it fails with exactly the error message carried by the test parameter.
TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
  const DeviceNames& device_names = std::get<0>(GetParam());
  const std::string& expected_error = std::get<1>(GetParam());

  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  const std::string empty_topology;
  const std::vector<int64_t> empty_device_assignment;
  const auto result = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, empty_topology,
      empty_device_assignment);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(), expected_error);
}
66 
67 INSTANTIATE_TEST_SUITE_P(
68     BadDeviceSet, ParameterizedDeviceSetTest,
69     ::testing::Values(
70         std::make_tuple<DeviceNames, std::string>(
71             {"/job:localhost/replica:0/task:0/device:CPU:0"},
72             "no TPU_SYSTEM devices found"),
73         std::make_tuple<DeviceNames, std::string>(
74             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
75              "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
76             "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
77             "'worker'"),
78         std::make_tuple<DeviceNames, std::string>(
79             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
80              "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
81             "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
82         std::make_tuple<DeviceNames, std::string>(
83             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
84              "/job:localhost/replica:0/task:0/device:TPU:0",
85              "/job:localhost/replica:0/task:0/device:TPU:1",
86              "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
87              "/job:localhost/replica:0/task:1/device:TPU:0"},
88             "expected the number of TPU devices per host to be 2, got 1")));
89 
90 struct ParameterizedMetadataTest
91     : ::testing::TestWithParam<std::tuple<int, int, std::string,
92                                           std::vector<int64_t>, std::string>> {
93 };
94 
// Runs GetTPUCompilationAndExecutionDevices on a fixed two-task device set
// (one TPU_SYSTEM and one TPU device per task) with the replica count, cores
// per replica, topology and device assignment taken from the test parameter,
// and checks that it fails with exactly the expected error message.
TEST_P(ParameterizedMetadataTest, BadMetadata) {
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(
      {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:1/device:TPU:0"},
      &devices));

  const int num_replicas = std::get<0>(GetParam());
  const int num_cores_per_replica = std::get<1>(GetParam());
  const std::string& topology_attr = std::get<2>(GetParam());
  const std::vector<int64_t>& device_assignment_attr = std::get<3>(GetParam());
  const std::string& expected_error = std::get<4>(GetParam());

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, num_replicas, num_cores_per_replica, topology_attr,
      device_assignment_attr);
  ASSERT_FALSE(status_or.ok());
  EXPECT_EQ(status_or.status().error_message(), expected_error);
}
113 
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)114 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
115   tpu::TopologyProto topology_proto;
116   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
117   return topology_proto.SerializeAsString();
118 }
119 
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)120 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
121                                           int num_tasks,
122                                           int num_tpu_devices_per_task) {
123   tpu::TopologyProto topology_proto;
124   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
125   topology_proto.set_num_tasks(num_tasks);
126   topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
127   return topology_proto.SerializeAsString();
128 }
129 
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)130 std::string TopologyWithDeviceCoordinates(
131     llvm::ArrayRef<int> device_coordinates) {
132   tpu::TopologyProto topology_proto;
133   topology_proto.add_mesh_shape(2);
134   topology_proto.add_mesh_shape(1);
135   topology_proto.add_mesh_shape(1);
136   topology_proto.add_mesh_shape(1);
137   topology_proto.set_num_tasks(2);
138   topology_proto.set_num_tpu_devices_per_task(1);
139   for (int device_coordinate : device_coordinates)
140     topology_proto.add_device_coordinates(device_coordinate);
141   return topology_proto.SerializeAsString();
142 }
143 
144 INSTANTIATE_TEST_SUITE_P(
145     BadFullMeshMetadata, ParameterizedMetadataTest,
146     ::testing::Values(
147         std::make_tuple(
148             2, 1, "", std::vector<int64_t>{0},
149             "'device_assignment' must not be set when 'topology' is not set"),
150         std::make_tuple(8, 1, "", std::vector<int64_t>(),
151                         "'num_replicas' must be equal to 1 or 2, got 8"),
152         std::make_tuple(2, 2, "", std::vector<int64_t>(),
153                         "'num_cores_per_replica' must be equal to 1, got 2")));
154 
155 INSTANTIATE_TEST_SUITE_P(
156     BadGeneralTopologyMetadata, ParameterizedMetadataTest,
157     ::testing::Values(
158         std::make_tuple(
159             2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
160             "failed to parse 'topology' attribute to TopologyProto"),
161         std::make_tuple(4, 2, TopologyWithMeshShape({0}),
162                         std::vector<int64_t>(),
163                         "'topology' 'mesh_shape' must be rank 4, got rank 1"),
164         std::make_tuple(
165             2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
166             "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
167         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
168                         std::vector<int64_t>(),
169                         "number of tasks from available TPU devices must be "
170                         "'num_tasks' in 'topology' (1), got 2"),
171         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
172                         std::vector<int64_t>(),
173                         "number of TPU devices available per task must be "
174                         "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
175         std::make_tuple(
176             2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
177             "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
178             "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
179         std::make_tuple(
180             2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
181             std::vector<int64_t>(),
182             "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
183             "of mesh shape (2, 1, 1, 1)"),
184         std::make_tuple(
185             2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
186             std::vector<int64_t>(),
187             "device coordinate (2, 0, 0, 0) in 'topology' is outside "
188             "of mesh shape (2, 1, 1, 1)"),
189         std::make_tuple(
190             2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
191             std::vector<int64_t>(),
192             "device coordinate (0, -1, 0, 0) in 'topology' is outside "
193             "of mesh shape (2, 1, 1, 1)"),
194         std::make_tuple(
195             2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
196             std::vector<int64_t>(),
197             "device coordinate (0, 1, 0, 0) in 'topology' is outside "
198             "of mesh shape (2, 1, 1, 1)"),
199         std::make_tuple(
200             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
201             std::vector<int64_t>(),
202             "device coordinate (0, 0, 0, -1) in 'topology' is outside "
203             "of mesh shape (2, 1, 1, 1)"),
204         std::make_tuple(
205             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
206             std::vector<int64_t>(),
207             "device coordinate (0, 0, 0, 1) in 'topology' is outside "
208             "of mesh shape (2, 1, 1, 1)"),
209         std::make_tuple(
210             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
211             std::vector<int64_t>(),
212             "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
213 
214 INSTANTIATE_TEST_SUITE_P(
215     BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
216     ::testing::Values(
217         std::make_tuple(2, 1,
218                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
219                         std::vector<int64_t>(),
220                         "length of 'device_assignment' must be 'num_replicas' "
221                         "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
222         std::make_tuple(
223             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
224             std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
225             "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
226             "is outside of mesh shape (2, 1, 1, 1)"),
227         std::make_tuple(
228             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
229             std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
230             "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
231             "outside of mesh shape (2, 1, 1, 1)"),
232         std::make_tuple(
233             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
234             std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
235             "device coordinate (0, -1, 0, 0) in 'device_assignment' "
236             "is outside of mesh shape (2, 1, 1, 1)"),
237         std::make_tuple(
238             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
239             std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
240             "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
241             "outside of mesh shape (2, 1, 1, 1)"),
242         std::make_tuple(
243             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
244             std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
245             "device coordinate (0, 0, 0, -1) in 'device_assignment' "
246             "is outside of mesh shape (2, 1, 1, 1)"),
247         std::make_tuple(
248             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
249             std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
250             "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
251             "outside of mesh shape (2, 1, 1, 1)"),
252         std::make_tuple(2, 1,
253                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
254                         std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
255                         "'device_assignment' has duplicate device coordinate "
256                         "(0, 0, 0, 0)")));
257 
MakeDeviceSet(int num_tasks,int num_devices_per_task)258 std::vector<std::string> MakeDeviceSet(int num_tasks,
259                                        int num_devices_per_task) {
260   std::vector<std::string> devices{
261       "/job:localhost/replica:0/task:0/device:CPU:0"};
262   devices.reserve(num_tasks * num_devices_per_task + num_tasks + 1);
263 
264   for (int task = 0; task < num_tasks; ++task) {
265     devices.push_back(
266         llvm::formatv("/job:worker/replica:0/task:{0}/device:CPU:0", task)
267             .str());
268     devices.push_back(
269         llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU_SYSTEM:0",
270                       task)
271             .str());
272     for (int device = 0; device < num_devices_per_task; ++device)
273       devices.push_back(
274           llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU:{1}", task,
275                         device)
276               .str());
277   }
278 
279   return devices;
280 }
281 
// The topology has mesh shape (2, 1, 1, 1) but lists only one device, at
// coordinate (0, 0, 0, 0). A device assignment that refers to (1, 0, 0, 0)
// must be rejected because no TPU device is registered at that coordinate.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  tpu::TopologyProto topology;
  for (const int dim : {2, 1, 1, 1}) topology.add_mesh_shape(dim);
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(1);
  for (const int coordinate : {0, 0, 0, 0})
    topology.add_device_coordinates(coordinate);

  const std::string topology_attr = topology.SerializeAsString();
  const std::vector<int64_t> device_assignment_attr{1, 0, 0, 0};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1);
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  const auto result = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, topology_attr,
      device_assignment_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate (1, "
            "0, 0, 0)");
}
315 
TEST(TPURewriteDeviceUtilTest,ValidFullMeshDeviceAssignment)316 TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
317   llvm::SmallVector<Device, 8> devices;
318   std::vector<std::string> device_names =
319       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
320   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
321   std::string topology_attr;
322   std::vector<int64_t> device_assignment_attr;
323 
324   auto status_or = GetTPUCompilationAndExecutionDevices(
325       devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1, topology_attr,
326       device_assignment_attr);
327 
328   TF_ASSERT_OK(status_or.status());
329 
330   const auto& tpu_device_assignment = status_or.ValueOrDie();
331   EXPECT_EQ(tpu_device_assignment.compilation_device,
332             "/job:worker/replica:0/task:0/device:CPU:0");
333   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
334   ASSERT_EQ(tpu_devices.size(), 8);
335   for (const auto& replica_tpu_devices : tpu_devices)
336     ASSERT_EQ(replica_tpu_devices.size(), 1);
337 
338   EXPECT_EQ(tpu_devices[0][0].device,
339             "/job:worker/replica:0/task:0/device:TPU:0");
340   EXPECT_EQ(tpu_devices[0][0].host,
341             "/job:worker/replica:0/task:0/device:CPU:0");
342   EXPECT_EQ(tpu_devices[1][0].device,
343             "/job:worker/replica:0/task:0/device:TPU:1");
344   EXPECT_EQ(tpu_devices[1][0].host,
345             "/job:worker/replica:0/task:0/device:CPU:0");
346   EXPECT_EQ(tpu_devices[2][0].device,
347             "/job:worker/replica:0/task:0/device:TPU:2");
348   EXPECT_EQ(tpu_devices[2][0].host,
349             "/job:worker/replica:0/task:0/device:CPU:0");
350   EXPECT_EQ(tpu_devices[3][0].device,
351             "/job:worker/replica:0/task:0/device:TPU:3");
352   EXPECT_EQ(tpu_devices[3][0].host,
353             "/job:worker/replica:0/task:0/device:CPU:0");
354   EXPECT_EQ(tpu_devices[4][0].device,
355             "/job:worker/replica:0/task:1/device:TPU:0");
356   EXPECT_EQ(tpu_devices[4][0].host,
357             "/job:worker/replica:0/task:1/device:CPU:0");
358   EXPECT_EQ(tpu_devices[5][0].device,
359             "/job:worker/replica:0/task:1/device:TPU:1");
360   EXPECT_EQ(tpu_devices[5][0].host,
361             "/job:worker/replica:0/task:1/device:CPU:0");
362   EXPECT_EQ(tpu_devices[6][0].device,
363             "/job:worker/replica:0/task:1/device:TPU:2");
364   EXPECT_EQ(tpu_devices[6][0].host,
365             "/job:worker/replica:0/task:1/device:CPU:0");
366   EXPECT_EQ(tpu_devices[7][0].device,
367             "/job:worker/replica:0/task:1/device:TPU:3");
368   EXPECT_EQ(tpu_devices[7][0].host,
369             "/job:worker/replica:0/task:1/device:CPU:0");
370 
371   EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
372 }
373 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh2x2x2)374 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
375   tpu::TopologyProto topology_proto;
376   {
377     topology_proto.add_mesh_shape(2);
378     topology_proto.add_mesh_shape(2);
379     topology_proto.add_mesh_shape(1);
380     topology_proto.add_mesh_shape(2);
381     topology_proto.set_num_tasks(2);
382     topology_proto.set_num_tpu_devices_per_task(4);
383     topology_proto.add_device_coordinates(0);
384     topology_proto.add_device_coordinates(0);
385     topology_proto.add_device_coordinates(0);
386     topology_proto.add_device_coordinates(0);
387     topology_proto.add_device_coordinates(0);
388     topology_proto.add_device_coordinates(1);
389     topology_proto.add_device_coordinates(0);
390     topology_proto.add_device_coordinates(0);
391     topology_proto.add_device_coordinates(1);
392     topology_proto.add_device_coordinates(1);
393     topology_proto.add_device_coordinates(0);
394     topology_proto.add_device_coordinates(0);
395     topology_proto.add_device_coordinates(1);
396     topology_proto.add_device_coordinates(0);
397     topology_proto.add_device_coordinates(0);
398     topology_proto.add_device_coordinates(0);
399     topology_proto.add_device_coordinates(1);
400     topology_proto.add_device_coordinates(0);
401     topology_proto.add_device_coordinates(0);
402     topology_proto.add_device_coordinates(1);
403     topology_proto.add_device_coordinates(1);
404     topology_proto.add_device_coordinates(1);
405     topology_proto.add_device_coordinates(0);
406     topology_proto.add_device_coordinates(1);
407     topology_proto.add_device_coordinates(0);
408     topology_proto.add_device_coordinates(1);
409     topology_proto.add_device_coordinates(0);
410     topology_proto.add_device_coordinates(1);
411     topology_proto.add_device_coordinates(0);
412     topology_proto.add_device_coordinates(0);
413     topology_proto.add_device_coordinates(0);
414     topology_proto.add_device_coordinates(1);
415   }
416 
417   std::string topology_attr = topology_proto.SerializeAsString();
418   std::vector<int64_t> device_assignment_attr{0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
419                                               0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,
420                                               0, 1, 1, 1, 0, 0, 1, 1, 0, 1};
421 
422   llvm::SmallVector<Device, 8> devices;
423   std::vector<std::string> device_names =
424       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
425   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
426 
427   auto status_or = GetTPUCompilationAndExecutionDevices(
428       devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2, topology_attr,
429       device_assignment_attr);
430 
431   TF_ASSERT_OK(status_or.status());
432 
433   const auto& tpu_device_assignment = status_or.ValueOrDie();
434   EXPECT_EQ(tpu_device_assignment.compilation_device,
435             "/job:worker/replica:0/task:0/device:CPU:0");
436   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
437   ASSERT_EQ(tpu_devices.size(), 4);
438   for (const auto& replica_tpu_devices : tpu_devices)
439     ASSERT_EQ(replica_tpu_devices.size(), 2);
440 
441   EXPECT_EQ(tpu_devices[0][0].device,
442             "/job:worker/replica:0/task:0/device:TPU:0");
443   EXPECT_EQ(tpu_devices[0][0].host,
444             "/job:worker/replica:0/task:0/device:CPU:0");
445   EXPECT_EQ(tpu_devices[0][1].device,
446             "/job:worker/replica:0/task:1/device:TPU:3");
447   EXPECT_EQ(tpu_devices[0][1].host,
448             "/job:worker/replica:0/task:1/device:CPU:0");
449   EXPECT_EQ(tpu_devices[1][0].device,
450             "/job:worker/replica:0/task:0/device:TPU:1");
451   EXPECT_EQ(tpu_devices[1][0].host,
452             "/job:worker/replica:0/task:0/device:CPU:0");
453   EXPECT_EQ(tpu_devices[1][1].device,
454             "/job:worker/replica:0/task:1/device:TPU:2");
455   EXPECT_EQ(tpu_devices[1][1].host,
456             "/job:worker/replica:0/task:1/device:CPU:0");
457   EXPECT_EQ(tpu_devices[2][0].device,
458             "/job:worker/replica:0/task:0/device:TPU:3");
459   EXPECT_EQ(tpu_devices[2][0].host,
460             "/job:worker/replica:0/task:0/device:CPU:0");
461   EXPECT_EQ(tpu_devices[2][1].device,
462             "/job:worker/replica:0/task:1/device:TPU:0");
463   EXPECT_EQ(tpu_devices[2][1].host,
464             "/job:worker/replica:0/task:1/device:CPU:0");
465   EXPECT_EQ(tpu_devices[3][0].device,
466             "/job:worker/replica:0/task:0/device:TPU:2");
467   EXPECT_EQ(tpu_devices[3][0].host,
468             "/job:worker/replica:0/task:0/device:CPU:0");
469   EXPECT_EQ(tpu_devices[3][1].device,
470             "/job:worker/replica:0/task:1/device:TPU:1");
471   EXPECT_EQ(tpu_devices[3][1].host,
472             "/job:worker/replica:0/task:1/device:CPU:0");
473 
474   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
475   ASSERT_TRUE(xla_device_assignment.hasValue());
476   EXPECT_EQ(xla_device_assignment->replica_count(), 4);
477   EXPECT_EQ(xla_device_assignment->computation_count(), 2);
478   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);
479   const auto& computation_device_0 =
480       xla_device_assignment->computation_devices(0);
481   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 4);
482   const auto& computation_device_1 =
483       xla_device_assignment->computation_devices(1);
484   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 4);
485 
486   EXPECT_EQ(computation_device_0.replica_device_ids(0), 0);
487   EXPECT_EQ(computation_device_0.replica_device_ids(1), 4);
488   EXPECT_EQ(computation_device_0.replica_device_ids(2), 2);
489   EXPECT_EQ(computation_device_0.replica_device_ids(3), 6);
490   EXPECT_EQ(computation_device_1.replica_device_ids(0), 1);
491   EXPECT_EQ(computation_device_1.replica_device_ids(1), 5);
492   EXPECT_EQ(computation_device_1.replica_device_ids(2), 3);
493   EXPECT_EQ(computation_device_1.replica_device_ids(3), 7);
494 }
495 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh1x2x1x3)496 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
497   tpu::TopologyProto topology_proto;
498   {
499     topology_proto.add_mesh_shape(1);
500     topology_proto.add_mesh_shape(2);
501     topology_proto.add_mesh_shape(1);
502     topology_proto.add_mesh_shape(3);
503     topology_proto.set_num_tasks(3);
504     topology_proto.set_num_tpu_devices_per_task(2);
505     topology_proto.add_device_coordinates(0);
506     topology_proto.add_device_coordinates(0);
507     topology_proto.add_device_coordinates(0);
508     topology_proto.add_device_coordinates(0);
509     topology_proto.add_device_coordinates(0);
510     topology_proto.add_device_coordinates(1);
511     topology_proto.add_device_coordinates(0);
512     topology_proto.add_device_coordinates(0);
513     topology_proto.add_device_coordinates(0);
514     topology_proto.add_device_coordinates(1);
515     topology_proto.add_device_coordinates(0);
516     topology_proto.add_device_coordinates(1);
517     topology_proto.add_device_coordinates(0);
518     topology_proto.add_device_coordinates(0);
519     topology_proto.add_device_coordinates(0);
520     topology_proto.add_device_coordinates(1);
521     topology_proto.add_device_coordinates(0);
522     topology_proto.add_device_coordinates(0);
523     topology_proto.add_device_coordinates(0);
524     topology_proto.add_device_coordinates(2);
525     topology_proto.add_device_coordinates(0);
526     topology_proto.add_device_coordinates(1);
527     topology_proto.add_device_coordinates(0);
528     topology_proto.add_device_coordinates(2);
529   }
530 
531   std::string topology_attr = topology_proto.SerializeAsString();
532   std::vector<int64_t> device_assignment_attr{
533       0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};
534 
535   llvm::SmallVector<Device, 8> devices;
536   std::vector<std::string> device_names =
537       MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
538   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
539 
540   auto status_or = GetTPUCompilationAndExecutionDevices(
541       devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3, topology_attr,
542       device_assignment_attr);
543 
544   TF_ASSERT_OK(status_or.status());
545 
546   auto& tpu_device_assignment = status_or.ValueOrDie();
547   EXPECT_EQ(tpu_device_assignment.compilation_device,
548             "/job:worker/replica:0/task:0/device:CPU:0");
549 
550   auto& tpu_devices = tpu_device_assignment.tpu_devices;
551   ASSERT_EQ(tpu_devices.size(), 2);
552   for (const auto& replica_tpu_devices : tpu_devices)
553     ASSERT_EQ(replica_tpu_devices.size(), 3);
554 
555   EXPECT_EQ(tpu_devices[0][0].device,
556             "/job:worker/replica:0/task:1/device:TPU:1");
557   EXPECT_EQ(tpu_devices[0][0].host,
558             "/job:worker/replica:0/task:1/device:CPU:0");
559   EXPECT_EQ(tpu_devices[0][1].device,
560             "/job:worker/replica:0/task:1/device:TPU:0");
561   EXPECT_EQ(tpu_devices[0][1].host,
562             "/job:worker/replica:0/task:1/device:CPU:0");
563   EXPECT_EQ(tpu_devices[0][2].device,
564             "/job:worker/replica:0/task:2/device:TPU:0");
565   EXPECT_EQ(tpu_devices[0][2].host,
566             "/job:worker/replica:0/task:2/device:CPU:0");
567   EXPECT_EQ(tpu_devices[1][0].device,
568             "/job:worker/replica:0/task:2/device:TPU:1");
569   EXPECT_EQ(tpu_devices[1][0].host,
570             "/job:worker/replica:0/task:2/device:CPU:0");
571   EXPECT_EQ(tpu_devices[1][1].device,
572             "/job:worker/replica:0/task:0/device:TPU:0");
573   EXPECT_EQ(tpu_devices[1][1].host,
574             "/job:worker/replica:0/task:0/device:CPU:0");
575   EXPECT_EQ(tpu_devices[1][2].device,
576             "/job:worker/replica:0/task:0/device:TPU:1");
577   EXPECT_EQ(tpu_devices[1][2].host,
578             "/job:worker/replica:0/task:0/device:CPU:0");
579 
580   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
581   ASSERT_TRUE(xla_device_assignment.hasValue());
582   EXPECT_EQ(xla_device_assignment->replica_count(), 2);
583   EXPECT_EQ(xla_device_assignment->computation_count(), 3);
584   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);
585   const auto& computation_device_0 =
586       xla_device_assignment->computation_devices(0);
587   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 2);
588   const auto& computation_device_1 =
589       xla_device_assignment->computation_devices(1);
590   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 2);
591   const auto& computation_device_2 =
592       xla_device_assignment->computation_devices(2);
593   ASSERT_EQ(computation_device_2.replica_device_ids_size(), 2);
594 
595   EXPECT_EQ(computation_device_0.replica_device_ids(0), 1);
596   EXPECT_EQ(computation_device_0.replica_device_ids(1), 5);
597   EXPECT_EQ(computation_device_1.replica_device_ids(0), 4);
598   EXPECT_EQ(computation_device_1.replica_device_ids(1), 0);
599   EXPECT_EQ(computation_device_2.replica_device_ids(0), 2);
600   EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
601 }
602 
// GetDeviceCoordinates on an I64 array attribute {1, 2, 3} must succeed and
// return those three values in order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  // TF_ASSERT_OK reports the status message on failure, matching the other
  // success-path tests in this file.
  TF_ASSERT_OK(status_or_device_coordinates.status());
  auto device_coordinates = status_or_device_coordinates.ConsumeValueOrDie();
  // Guard the indexed reads below against a result of the wrong length.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
615 
// GetDeviceCoordinates on an F32 array attribute must fail and report the
// first element (index 0) as not being an int.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto float_array_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});

  auto coordinates_or = GetDeviceCoordinates(float_array_attr);

  ASSERT_FALSE(coordinates_or.ok());
  EXPECT_EQ(coordinates_or.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
626 
// A tf_device.cluster whose cores-per-replica attribute is 1 must not be
// reported as using model parallelism.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismFalse) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  // The cluster op is created with no result types.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);

  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
644 
// Checks that HasModelParallelism returns true for a cluster whose
// kNumCoresPerReplicaAttr is set to 5.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismTrue) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // The cluster op is created without any result types.
  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 5));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  EXPECT_TRUE(HasModelParallelism(cluster_op));
}
662 
// Checks that HasModelParallelism returns false when the cluster carries the
// topology and device assignment attributes but no kNumCoresPerReplicaAttr.
TEST(TPURewriteDeviceUtilTest,
     TestHasModelParallelismFalseMissingCoresPerReplicaAttr) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // kNumCoresPerReplicaAttr is intentionally not set on the cluster.
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster_op));
}
679 
// Checks that GetHostDeviceOutsideComputation fails for a cluster op that has
// none of the TPU attributes set.
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // Bare cluster: no attributes are attached after creation.
  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  mlir::TF::RuntimeDevices empty_devices;
  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(empty_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::failed(lookup));
}
695 
// Checks that GetHostDeviceOutsideComputation fails when the cluster has
// kNumCoresPerReplicaAttr and kDeviceAssignmentAttr but no kTopologyAttr.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // kTopologyAttr is intentionally not set on the cluster.
  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices empty_devices;
  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(empty_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::failed(lookup));
}
715 
// Checks that GetHostDeviceOutsideComputation fails when the cluster has
// kNumCoresPerReplicaAttr and kTopologyAttr but no kDeviceAssignmentAttr.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // kDeviceAssignmentAttr is intentionally not set on the cluster.
  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));

  mlir::TF::RuntimeDevices empty_devices;
  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(empty_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::failed(lookup));
}
735 
// Checks that GetHostDeviceOutsideComputation fails when kDeviceAssignmentAttr
// is an array of strings rather than an array of integers.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));

  // A string array stands in for the malformed device assignment.
  auto string_assignment = op_builder.getStrArrayAttr(
      llvm::ArrayRef<llvm::StringRef>({"bad_device_assigment"}));
  cluster_op->setAttr(kDeviceAssignmentAttr, string_assignment);

  mlir::TF::RuntimeDevices empty_devices;
  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(empty_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::failed(lookup));
}
758 
// Checks that GetHostDeviceOutsideComputation fails when the module's
// "tf.devices" attribute lists only "bad_device_name".
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  auto device_list = op_builder.getStrArrayAttr(
      llvm::ArrayRef<llvm::StringRef>({"bad_device_name"}));
  (*module)->setAttr("tf.devices", device_list);

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices parsed_devices;
  // The outcome of collecting devices is explicitly discarded here.
  // NOTE(review): presumably it fails on the malformed name - confirm.
  (void)GetDevicesFromOp(*module, &parsed_devices);

  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(parsed_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::failed(lookup));
}
784 
// Checks that GetHostDeviceOutsideComputation succeeds and yields
// kTPUReplicatedHost for a cluster created inside a tf_device.replicate body.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // Replicate op with two replicas, an empty device map and no operands or
  // result types.
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      replica_devices;
  auto replicate_op = op_builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, replica_devices,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Subsequent ops are inserted at the start of the replicate body block.
  mlir::Block& replicate_block = replicate_op.body().front();
  op_builder.setInsertionPoint(&replicate_block, replicate_block.begin());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  mlir::TF::RuntimeDevices empty_devices;
  std::string host_name;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(empty_devices, cluster_op, &host_name);
  EXPECT_TRUE(mlir::succeeded(lookup));
  EXPECT_EQ(host_name, kTPUReplicatedHost);
}
811 
// Checks that, for a cluster outside any replicate op,
// GetHostDeviceOutsideComputation succeeds and reports
// "/job:localhost/replica:0/task:0/device:CPU:0" as the host device, given
// the runtime devices listed in the module's "tf.devices" attribute.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                    {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
                     "/job:localhost/replica:0/task:0/device:TPU:0",
                     "/job:worker/replica:0/task:0/device:CPU:0"})));

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  // The expectations below depend on the devices being collected, so a setup
  // failure must abort the test rather than be silently ignored.
  ASSERT_TRUE(
      mlir::succeeded(GetDevicesFromOp(*module_ref, &runtime_devices)));
  std::string host_device;
  EXPECT_TRUE(mlir::succeeded(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
  EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
}
840 
// Checks IsTPUDevice against a TPU device name, a CPU device name and a
// string that is not a device name at all.
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  // Inputs that must be rejected.
  EXPECT_FALSE(IsTPUDevice("INVALID_DEVICE"));
  EXPECT_FALSE(IsTPUDevice("/job:localhost/replica:0/task:0/device:CPU:0"));

  // The only input here whose device type is TPU.
  EXPECT_TRUE(IsTPUDevice("/job:localhost/replica:0/task:0/device:TPU:0"));
}
846 
847 }  // anonymous namespace
848 }  // namespace tensorflow
849