1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
17
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>
20
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/MLIRContext.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30
31 namespace tensorflow {
32 namespace {
33
34 using Device = DeviceNameUtils::ParsedName;
35
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)36 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
37 llvm::SmallVectorImpl<Device>* parsed_devices) {
38 parsed_devices->reserve(device_names.size());
39 for (const auto& device_name : device_names) {
40 Device parsed_name;
41 if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
42 return false;
43
44 parsed_devices->push_back(parsed_name);
45 }
46 return true;
47 }
48
49 using DeviceNames = llvm::SmallVector<std::string, 8>;
50
51 struct ParameterizedDeviceSetTest
52 : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
53
// Checks that GetTPUCompilationAndExecutionDevices fails on the parameterized
// device set and reports exactly the parameterized error message.
TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
  const DeviceNames& device_names = std::get<0>(GetParam());
  const std::string& expected_error = std::get<1>(GetParam());

  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  // Both the topology and the device assignment are left empty here; only the
  // device set varies between cases.
  std::string empty_topology;
  std::vector<int64_t> empty_device_assignment;

  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      empty_topology, empty_device_assignment);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(), expected_error);
}
66
67 INSTANTIATE_TEST_SUITE_P(
68 BadDeviceSet, ParameterizedDeviceSetTest,
69 ::testing::Values(
70 std::make_tuple<DeviceNames, std::string>(
71 {"/job:localhost/replica:0/task:0/device:CPU:0"},
72 "no TPU_SYSTEM devices found"),
73 std::make_tuple<DeviceNames, std::string>(
74 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
75 "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
76 "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
77 "'worker'"),
78 std::make_tuple<DeviceNames, std::string>(
79 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
80 "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
81 "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
82 std::make_tuple<DeviceNames, std::string>(
83 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
84 "/job:localhost/replica:0/task:0/device:TPU:0",
85 "/job:localhost/replica:0/task:0/device:TPU:1",
86 "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
87 "/job:localhost/replica:0/task:1/device:TPU:0"},
88 "expected the number of TPU devices per host to be 2, got 1")));
89
90 struct ParameterizedMetadataTest
91 : ::testing::TestWithParam<std::tuple<int, int, std::string,
92 std::vector<int64_t>, std::string>> {
93 };
94
// Checks that GetTPUCompilationAndExecutionDevices rejects the parameterized
// metadata (replica/core counts, topology, device assignment) for a fixed
// device set of two tasks with one TPU device each, and that the reported
// error message matches the parameterized one exactly.
TEST_P(ParameterizedMetadataTest, BadMetadata) {
  const std::vector<std::string> device_names = {
      "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
      "/job:worker/replica:0/task:0/device:TPU:0",
      "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
      "/job:worker/replica:0/task:1/device:TPU:0"};
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  // Unpack the parameter tuple once so every argument below has a name.
  const auto& param = GetParam();
  const int num_replicas = std::get<0>(param);
  const int num_cores_per_replica = std::get<1>(param);
  const std::string& topology_attr = std::get<2>(param);
  const std::vector<int64_t>& device_assignment_attr = std::get<3>(param);
  const std::string& expected_error = std::get<4>(param);

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, num_replicas, num_cores_per_replica, topology_attr,
      device_assignment_attr);
  ASSERT_FALSE(status_or.ok());
  EXPECT_EQ(status_or.status().error_message(), expected_error);
}
113
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)114 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
115 tpu::TopologyProto topology_proto;
116 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
117 return topology_proto.SerializeAsString();
118 }
119
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)120 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
121 int num_tasks,
122 int num_tpu_devices_per_task) {
123 tpu::TopologyProto topology_proto;
124 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
125 topology_proto.set_num_tasks(num_tasks);
126 topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
127 return topology_proto.SerializeAsString();
128 }
129
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)130 std::string TopologyWithDeviceCoordinates(
131 llvm::ArrayRef<int> device_coordinates) {
132 tpu::TopologyProto topology_proto;
133 topology_proto.add_mesh_shape(2);
134 topology_proto.add_mesh_shape(1);
135 topology_proto.add_mesh_shape(1);
136 topology_proto.add_mesh_shape(1);
137 topology_proto.set_num_tasks(2);
138 topology_proto.set_num_tpu_devices_per_task(1);
139 for (int device_coordinate : device_coordinates)
140 topology_proto.add_device_coordinates(device_coordinate);
141 return topology_proto.SerializeAsString();
142 }
143
144 INSTANTIATE_TEST_SUITE_P(
145 BadFullMeshMetadata, ParameterizedMetadataTest,
146 ::testing::Values(
147 std::make_tuple(
148 2, 1, "", std::vector<int64_t>{0},
149 "'device_assignment' must not be set when 'topology' is not set"),
150 std::make_tuple(8, 1, "", std::vector<int64_t>(),
151 "'num_replicas' must be equal to 1 or 2, got 8"),
152 std::make_tuple(2, 2, "", std::vector<int64_t>(),
153 "'num_cores_per_replica' must be equal to 1, got 2")));
154
155 INSTANTIATE_TEST_SUITE_P(
156 BadGeneralTopologyMetadata, ParameterizedMetadataTest,
157 ::testing::Values(
158 std::make_tuple(
159 2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
160 "failed to parse 'topology' attribute to TopologyProto"),
161 std::make_tuple(4, 2, TopologyWithMeshShape({0}),
162 std::vector<int64_t>(),
163 "'topology' 'mesh_shape' must be rank 4, got rank 1"),
164 std::make_tuple(
165 2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
166 "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
167 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
168 std::vector<int64_t>(),
169 "number of tasks from available TPU devices must be "
170 "'num_tasks' in 'topology' (1), got 2"),
171 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
172 std::vector<int64_t>(),
173 "number of TPU devices available per task must be "
174 "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
175 std::make_tuple(
176 2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
177 "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
178 "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
179 std::make_tuple(
180 2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
181 std::vector<int64_t>(),
182 "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
183 "of mesh shape (2, 1, 1, 1)"),
184 std::make_tuple(
185 2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
186 std::vector<int64_t>(),
187 "device coordinate (2, 0, 0, 0) in 'topology' is outside "
188 "of mesh shape (2, 1, 1, 1)"),
189 std::make_tuple(
190 2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
191 std::vector<int64_t>(),
192 "device coordinate (0, -1, 0, 0) in 'topology' is outside "
193 "of mesh shape (2, 1, 1, 1)"),
194 std::make_tuple(
195 2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
196 std::vector<int64_t>(),
197 "device coordinate (0, 1, 0, 0) in 'topology' is outside "
198 "of mesh shape (2, 1, 1, 1)"),
199 std::make_tuple(
200 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
201 std::vector<int64_t>(),
202 "device coordinate (0, 0, 0, -1) in 'topology' is outside "
203 "of mesh shape (2, 1, 1, 1)"),
204 std::make_tuple(
205 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
206 std::vector<int64_t>(),
207 "device coordinate (0, 0, 0, 1) in 'topology' is outside "
208 "of mesh shape (2, 1, 1, 1)"),
209 std::make_tuple(
210 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
211 std::vector<int64_t>(),
212 "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
213
214 INSTANTIATE_TEST_SUITE_P(
215 BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
216 ::testing::Values(
217 std::make_tuple(2, 1,
218 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
219 std::vector<int64_t>(),
220 "length of 'device_assignment' must be 'num_replicas' "
221 "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
222 std::make_tuple(
223 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
224 std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
225 "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
226 "is outside of mesh shape (2, 1, 1, 1)"),
227 std::make_tuple(
228 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
229 std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
230 "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
231 "outside of mesh shape (2, 1, 1, 1)"),
232 std::make_tuple(
233 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
234 std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
235 "device coordinate (0, -1, 0, 0) in 'device_assignment' "
236 "is outside of mesh shape (2, 1, 1, 1)"),
237 std::make_tuple(
238 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
239 std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
240 "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
241 "outside of mesh shape (2, 1, 1, 1)"),
242 std::make_tuple(
243 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
244 std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
245 "device coordinate (0, 0, 0, -1) in 'device_assignment' "
246 "is outside of mesh shape (2, 1, 1, 1)"),
247 std::make_tuple(
248 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
249 std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
250 "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
251 "outside of mesh shape (2, 1, 1, 1)"),
252 std::make_tuple(2, 1,
253 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
254 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
255 "'device_assignment' has duplicate device coordinate "
256 "(0, 0, 0, 0)")));
257
// Builds a list of device names for tests.
//
// The list always starts with "/job:localhost/replica:0/task:0/device:CPU:0".
// Then, for each task in [0, num_tasks), it appends under
// "/job:worker/replica:0/task:<task>/" one CPU:0 device, one TPU_SYSTEM:0
// device and `num_devices_per_task` TPU devices numbered from 0.
//
// Non-positive `num_tasks` yields only the leading localhost CPU entry;
// non-positive `num_devices_per_task` yields tasks without TPU entries.
std::vector<std::string> MakeDeviceSet(int num_tasks,
                                       int num_devices_per_task) {
  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};

  if (num_tasks > 0) {
    // Every task contributes CPU + TPU_SYSTEM + its TPU devices. Counts are
    // clamped before converting to an unsigned size so a negative argument
    // cannot turn into an enormous reservation request.
    const std::size_t tpus_per_task =
        num_devices_per_task > 0
            ? static_cast<std::size_t>(num_devices_per_task)
            : 0;
    devices.reserve(1 + static_cast<std::size_t>(num_tasks) *
                            (2 + tpus_per_task));
  }

  for (int task = 0; task < num_tasks; ++task) {
    const std::string task_prefix =
        "/job:worker/replica:0/task:" + std::to_string(task) + "/device:";
    devices.push_back(task_prefix + "CPU:0");
    devices.push_back(task_prefix + "TPU_SYSTEM:0");
    for (int device = 0; device < num_devices_per_task; ++device) {
      devices.push_back(task_prefix + "TPU:" + std::to_string(device));
    }
  }

  return devices;
}
281
// Checks the error reported when the device assignment names a coordinate
// that lies inside the mesh shape but is not listed among the topology's
// device coordinates.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  // Mesh shape (2, 1, 1, 1) with a single task holding a single TPU device at
  // coordinate (0, 0, 0, 0).
  tpu::TopologyProto topology;
  for (const int dim_size : {2, 1, 1, 1}) {
    topology.add_mesh_shape(dim_size);
  }
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(1);
  for (const int coordinate : {0, 0, 0, 0}) {
    topology.add_device_coordinates(coordinate);
  }
  const std::string serialized_topology = topology.SerializeAsString();

  // The assignment requests coordinate (1, 0, 0, 0), which the topology above
  // does not list.
  const std::vector<int64_t> requested_coordinates{1, 0, 0, 0};

  const std::vector<std::string> names =
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(names, &parsed_devices));

  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      serialized_topology, requested_coordinates);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate (1, "
            "0, 0, 0)");
}
315
// With an empty topology and an empty device assignment, eight replicas of
// one core each over two tasks of four TPU devices are expected to map, in
// replica order, onto task 0's TPU:0..3 followed by task 1's TPU:0..3, each
// hosted by the CPU:0 of the same task. No XLA device assignment is expected.
TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
  llvm::SmallVector<Device, 8> devices;
  std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
  std::string topology_attr;
  std::vector<int64_t> device_assignment_attr;

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1, topology_attr,
      device_assignment_attr);

  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  // Expected (TPU device, host device) for each replica's only core, indexed
  // by replica.
  struct ExpectedCore {
    const char* device;
    const char* host;
  };
  const ExpectedCore kExpectedCores[8] = {
      {"/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:1",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:2",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:3",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:0",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:1",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:2",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:3",
       "/job:worker/replica:0/task:1/device:CPU:0"},
  };

  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 8);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 1);

  for (int replica = 0; replica < 8; ++replica) {
    // Identify the failing replica in any failure message from this loop.
    SCOPED_TRACE("replica " + std::to_string(replica));
    EXPECT_EQ(tpu_devices[replica][0].device, kExpectedCores[replica].device);
    EXPECT_EQ(tpu_devices[replica][0].host, kExpectedCores[replica].host);
  }

  EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
}
373
// Uses an explicit topology with mesh shape (2, 2, 1, 2), two tasks of four
// TPU devices each, and an explicit device assignment for four replicas of
// two cores. Verifies the resulting TPU/host device per (replica, core) and
// the resulting XLA device assignment ids.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
  tpu::TopologyProto topology_proto;
  {
    for (const int dim_size : {2, 2, 1, 2}) {
      topology_proto.add_mesh_shape(dim_size);
    }
    topology_proto.set_num_tasks(2);
    topology_proto.set_num_tpu_devices_per_task(4);

    // Flattened device coordinates, four values per device. Judging by the
    // expectations below, row i belongs to task i / 4, TPU:(i % 4).
    const int kDeviceCoordinates[] = {
        0, 0, 0, 0,  //
        0, 1, 0, 0,  //
        1, 1, 0, 0,  //
        1, 0, 0, 0,  //
        1, 0, 0, 1,  //
        1, 1, 0, 1,  //
        0, 1, 0, 1,  //
        0, 0, 0, 1,  //
    };
    for (const int coordinate : kDeviceCoordinates) {
      topology_proto.add_device_coordinates(coordinate);
    }
  }

  std::string topology_attr = topology_proto.SerializeAsString();
  // Flattened requested coordinates: four values per core, two cores per
  // replica, four replicas.
  std::vector<int64_t> device_assignment_attr{0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
                                              0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,
                                              0, 1, 1, 1, 0, 0, 1, 1, 0, 1};

  llvm::SmallVector<Device, 8> devices;
  std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2, topology_attr,
      device_assignment_attr);

  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  // Expected (TPU device, host device), indexed by [replica][core].
  struct ExpectedCore {
    const char* device;
    const char* host;
  };
  const ExpectedCore kExpectedCores[4][2] = {
      {{"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:3",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:2",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:3",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:0",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:2",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:1",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
  };

  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 4);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 2);

  for (int replica = 0; replica < 4; ++replica) {
    for (int core = 0; core < 2; ++core) {
      // Identify the failing entry in any failure message from this loop.
      SCOPED_TRACE("replica " + std::to_string(replica) + ", core " +
                   std::to_string(core));
      EXPECT_EQ(tpu_devices[replica][core].device,
                kExpectedCores[replica][core].device);
      EXPECT_EQ(tpu_devices[replica][core].host,
                kExpectedCores[replica][core].host);
    }
  }

  auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.hasValue());
  EXPECT_EQ(xla_device_assignment->replica_count(), 4);
  EXPECT_EQ(xla_device_assignment->computation_count(), 2);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);

  // Expected replica device ids, indexed by [computation][replica].
  const int kExpectedDeviceIds[2][4] = {{0, 4, 2, 6}, {1, 5, 3, 7}};
  for (int computation = 0; computation < 2; ++computation) {
    SCOPED_TRACE("computation " + std::to_string(computation));
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    ASSERT_EQ(computation_device.replica_device_ids_size(), 4);
    for (int replica = 0; replica < 4; ++replica) {
      SCOPED_TRACE("replica " + std::to_string(replica));
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedDeviceIds[computation][replica]);
    }
  }
}
495
// Uses an explicit topology with mesh shape (1, 2, 1, 3), three tasks of two
// TPU devices each, and an explicit device assignment for two replicas of
// three cores. Verifies the resulting TPU/host device per (replica, core) and
// the resulting XLA device assignment ids.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
  tpu::TopologyProto topology_proto;
  {
    for (const int dim_size : {1, 2, 1, 3}) {
      topology_proto.add_mesh_shape(dim_size);
    }
    topology_proto.set_num_tasks(3);
    topology_proto.set_num_tpu_devices_per_task(2);

    // Flattened device coordinates, four values per device. Judging by the
    // expectations below, row i belongs to task i / 2, TPU:(i % 2).
    const int kDeviceCoordinates[] = {
        0, 0, 0, 0,  //
        0, 1, 0, 0,  //
        0, 1, 0, 1,  //
        0, 0, 0, 1,  //
        0, 0, 0, 2,  //
        0, 1, 0, 2,  //
    };
    for (const int coordinate : kDeviceCoordinates) {
      topology_proto.add_device_coordinates(coordinate);
    }
  }

  std::string topology_attr = topology_proto.SerializeAsString();
  // Flattened requested coordinates: four values per core, three cores per
  // replica, two replicas.
  std::vector<int64_t> device_assignment_attr{
      0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};

  llvm::SmallVector<Device, 8> devices;
  std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3, topology_attr,
      device_assignment_attr);

  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  // Expected (TPU device, host device), indexed by [replica][core].
  struct ExpectedCore {
    const char* device;
    const char* host;
  };
  const ExpectedCore kExpectedCores[2][3] = {
      {{"/job:worker/replica:0/task:1/device:TPU:1",
        "/job:worker/replica:0/task:1/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:0",
        "/job:worker/replica:0/task:1/device:CPU:0"},
       {"/job:worker/replica:0/task:2/device:TPU:0",
        "/job:worker/replica:0/task:2/device:CPU:0"}},
      {{"/job:worker/replica:0/task:2/device:TPU:1",
        "/job:worker/replica:0/task:2/device:CPU:0"},
       {"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"}},
  };

  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 2);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 3);

  for (int replica = 0; replica < 2; ++replica) {
    for (int core = 0; core < 3; ++core) {
      // Identify the failing entry in any failure message from this loop.
      SCOPED_TRACE("replica " + std::to_string(replica) + ", core " +
                   std::to_string(core));
      EXPECT_EQ(tpu_devices[replica][core].device,
                kExpectedCores[replica][core].device);
      EXPECT_EQ(tpu_devices[replica][core].host,
                kExpectedCores[replica][core].host);
    }
  }

  auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.hasValue());
  EXPECT_EQ(xla_device_assignment->replica_count(), 2);
  EXPECT_EQ(xla_device_assignment->computation_count(), 3);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);

  // Expected replica device ids, indexed by [computation][replica].
  const int kExpectedDeviceIds[3][2] = {{1, 5}, {4, 0}, {2, 3}};
  for (int computation = 0; computation < 3; ++computation) {
    SCOPED_TRACE("computation " + std::to_string(computation));
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    ASSERT_EQ(computation_device.replica_device_ids_size(), 2);
    for (int replica = 0; replica < 2; ++replica) {
      SCOPED_TRACE("replica " + std::to_string(replica));
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedDeviceIds[computation][replica]);
    }
  }
}
602
// GetDeviceCoordinates on an i64 array attribute {1, 2, 3} is expected to
// succeed and yield those three values in order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  // TF_ASSERT_OK prints the status on failure, matching the other tests here.
  TF_ASSERT_OK(status_or_device_coordinates.status());
  auto device_coordinates = status_or_device_coordinates.ConsumeValueOrDie();
  // Guard the indexing below against a result of the wrong length.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
615
// GetDeviceCoordinates on an f32 array attribute is expected to fail with an
// error naming the first offending index.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder attr_builder(&context);
  auto float_array_attr = attr_builder.getF32ArrayAttr({1.0, 2.0, 3.0});

  auto coordinates_or = GetDeviceCoordinates(float_array_attr);

  ASSERT_FALSE(coordinates_or.ok());
  EXPECT_EQ(coordinates_or.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
626
// GetHostDeviceOutsideComputation is expected to fail for a cluster op that
// has none of the TPU attributes set.
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // Cluster with no result types and no attributes.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op = op_builder.create<mlir::tf_device::ClusterOp>(
      unknown_loc, no_result_types);

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(result));
}
642
// GetHostDeviceOutsideComputation is expected to fail for a cluster op whose
// num_cores_per_replica attribute is 5 (model parallelism, per the test
// name), with an empty topology and an empty device assignment.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op = op_builder.create<mlir::tf_device::ClusterOp>(
      unknown_loc, no_result_types);

  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 5));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(result));
}
663
// GetHostDeviceOutsideComputation is expected to fail for a cluster op that
// sets num_cores_per_replica and the device assignment but no topology
// attribute.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op = op_builder.create<mlir::tf_device::ClusterOp>(
      unknown_loc, no_result_types);

  // kTopologyAttr is intentionally not set.
  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(result));
}
683
// GetHostDeviceOutsideComputation is expected to fail for a cluster op that
// sets num_cores_per_replica and the topology but no device assignment
// attribute.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op = op_builder.create<mlir::tf_device::ClusterOp>(
      unknown_loc, no_result_types);

  // kDeviceAssignmentAttr is intentionally not set.
  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(result));
}
703
// GetHostDeviceOutsideComputation is expected to fail for a cluster op whose
// device assignment attribute is an array of strings.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module = mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op = op_builder.create<mlir::tf_device::ClusterOp>(
      unknown_loc, no_result_types);

  const auto i64_type = op_builder.getIntegerType(64);
  cluster_op->setAttr(kNumCoresPerReplicaAttr,
                      op_builder.getIntegerAttr(i64_type, 1));
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  // A string array stands in for the device assignment.
  cluster_op->setAttr(kDeviceAssignmentAttr,
                      op_builder.getStrArrayAttr(
                          llvm::ArrayRef<llvm::StringRef>(
                              {"bad_device_assigment"})));

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto result =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(result));
}
726
// Checks that GetHostDeviceOutsideComputation reports failure when the
// module's "tf.devices" attribute lists only the string "bad_device_name",
// even though the cluster itself carries all three TPU attributes.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  // Attach the device list to the module.
  mlir::ModuleOp module = module_ref.get();
  llvm::SmallVector<llvm::StringRef, 1> device_names = {"bad_device_name"};
  module->setAttr("tf.devices", builder.getStrArrayAttr(device_names));

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  // The result of collecting devices is intentionally discarded.
  // NOTE(review): presumably parsing fails for this name and `devices` stays
  // unpopulated -- confirm against GetDevicesFromOp.
  mlir::TF::RuntimeDevices devices;
  (void)GetDevicesFromOp(module, &devices);

  std::string host;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
752
// Checks that, for a cluster nested at the start of a tf_device.replicate
// body, GetHostDeviceOutsideComputation succeeds and yields
// kTPUReplicatedHost. No attributes are set on the cluster and the runtime
// devices are default-constructed.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  // Replicate op with two replicas, no device map entries, and no inputs or
  // outputs.
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      no_device_map;
  auto replicate_op = builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, no_device_map,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Subsequent ops are created at the beginning of the replicate body.
  mlir::Block& replicate_body = replicate_op.body().front();
  builder.setInsertionPoint(&replicate_body, replicate_body.begin());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  mlir::TF::RuntimeDevices devices;
  std::string host;
  const mlir::LogicalResult lookup =
      GetHostDeviceOutsideComputation(devices, cluster_op, &host);
  EXPECT_TRUE(mlir::succeeded(lookup));
  EXPECT_EQ(host, kTPUReplicatedHost);
}
779
// Checks that, for a cluster that is not nested in a replicate op,
// GetHostDeviceOutsideComputation succeeds and yields the concrete host
// device "/job:localhost/replica:0/task:0/device:CPU:0".
//
// The module's "tf.devices" list contains a TPU_SYSTEM and a TPU device on
// job:localhost plus a CPU device on job:worker; note that the expected host
// is on job:localhost, not the job:worker CPU that is listed.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                    {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
                     "/job:localhost/replica:0/task:0/device:TPU:0",
                     "/job:worker/replica:0/task:0/device:CPU:0"})));

  // Single-core cluster with an empty topology and an empty device
  // assignment.
  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  // This is a success-path test, so the device list must be collected
  // successfully. Abort here on a setup failure instead of discarding the
  // result and surfacing it later as a confusing host-device mismatch.
  mlir::TF::RuntimeDevices runtime_devices;
  ASSERT_TRUE(
      mlir::succeeded(GetDevicesFromOp(*module_ref, &runtime_devices)));

  std::string host_device;
  EXPECT_TRUE(mlir::succeeded(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
  EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
}
808
// Checks IsTPUDevice on three device strings: a fully specified TPU device
// name is accepted, while a fully specified CPU device name and a string that
// is not in device-name form are both rejected.
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  const bool tpu_accepted =
      IsTPUDevice("/job:localhost/replica:0/task:0/device:TPU:0");
  EXPECT_TRUE(tpu_accepted);

  // Names that must not be classified as TPU devices, checked in order.
  for (const char* non_tpu_name :
       {"/job:localhost/replica:0/task:0/device:CPU:0", "INVALID_DEVICE"}) {
    EXPECT_FALSE(IsTPUDevice(non_tpu_name));
  }
}
814
815 } // anonymous namespace
816 } // namespace tensorflow
817