1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
30
31 namespace tensorflow {
32 namespace {
33
34 using Device = DeviceNameUtils::ParsedName;
35
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)36 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
37 llvm::SmallVectorImpl<Device>* parsed_devices) {
38 parsed_devices->reserve(device_names.size());
39 for (const auto& device_name : device_names) {
40 Device parsed_name;
41 if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
42 return false;
43
44 parsed_devices->push_back(parsed_name);
45 }
46 return true;
47 }
48
49 using DeviceNames = llvm::SmallVector<std::string, 8>;
50
51 struct ParameterizedDeviceSetTest
52 : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
53
// Feeds an invalid device set to GetTPUCompilationAndExecutionDevices, with an
// empty topology and device assignment, and expects the parameterized error.
TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
  const DeviceNames& device_names = std::get<0>(GetParam());
  const std::string& expected_error = std::get<1>(GetParam());

  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  std::string no_topology;
  std::vector<int64_t> no_device_assignment;
  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      no_topology, no_device_assignment);
  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(), expected_error);
}
66
67 INSTANTIATE_TEST_SUITE_P(
68 BadDeviceSet, ParameterizedDeviceSetTest,
69 ::testing::Values(
70 std::make_tuple<DeviceNames, std::string>(
71 {"/job:localhost/replica:0/task:0/device:CPU:0"},
72 "no TPU_SYSTEM devices found"),
73 std::make_tuple<DeviceNames, std::string>(
74 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
75 "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
76 "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
77 "'worker'"),
78 std::make_tuple<DeviceNames, std::string>(
79 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
80 "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
81 "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
82 std::make_tuple<DeviceNames, std::string>(
83 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
84 "/job:localhost/replica:0/task:0/device:TPU:0",
85 "/job:localhost/replica:0/task:0/device:TPU:1",
86 "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
87 "/job:localhost/replica:0/task:1/device:TPU:0"},
88 "expected the number of TPU devices per host to be 2, got 1")));
89
90 struct ParameterizedMetadataTest
91 : ::testing::TestWithParam<std::tuple<int, int, std::string,
92 std::vector<int64_t>, std::string>> {
93 };
94
// Runs GetTPUCompilationAndExecutionDevices over a fixed two-task cluster
// (one TPU per task) with the parameterized num_replicas,
// num_cores_per_replica, topology and device assignment. The call must fail
// with exactly the parameterized error message.
TEST_P(ParameterizedMetadataTest, BadMetadata) {
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(
      {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:1/device:TPU:0"},
      &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, std::get<0>(GetParam()), std::get<1>(GetParam()),
      std::get<2>(GetParam()), std::get<3>(GetParam()));
  ASSERT_FALSE(status_or.ok());
  EXPECT_EQ(status_or.status().error_message(), std::get<4>(GetParam()));
}
113
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)114 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
115 tpu::TopologyProto topology_proto;
116 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
117 return topology_proto.SerializeAsString();
118 }
119
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)120 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
121 int num_tasks,
122 int num_tpu_devices_per_task) {
123 tpu::TopologyProto topology_proto;
124 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
125 topology_proto.set_num_tasks(num_tasks);
126 topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
127 return topology_proto.SerializeAsString();
128 }
129
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)130 std::string TopologyWithDeviceCoordinates(
131 llvm::ArrayRef<int> device_coordinates) {
132 tpu::TopologyProto topology_proto;
133 topology_proto.add_mesh_shape(2);
134 topology_proto.add_mesh_shape(1);
135 topology_proto.add_mesh_shape(1);
136 topology_proto.add_mesh_shape(1);
137 topology_proto.set_num_tasks(2);
138 topology_proto.set_num_tpu_devices_per_task(1);
139 for (int device_coordinate : device_coordinates)
140 topology_proto.add_device_coordinates(device_coordinate);
141 return topology_proto.SerializeAsString();
142 }
143
144 INSTANTIATE_TEST_SUITE_P(
145 BadFullMeshMetadata, ParameterizedMetadataTest,
146 ::testing::Values(
147 std::make_tuple(
148 2, 1, "", std::vector<int64_t>{0},
149 "'device_assignment' must not be set when 'topology' is not set"),
150 std::make_tuple(8, 1, "", std::vector<int64_t>(),
151 "'num_replicas' must be equal to 1 or 2, got 8"),
152 std::make_tuple(2, 2, "", std::vector<int64_t>(),
153 "'num_cores_per_replica' must be equal to 1, got 2")));
154
155 INSTANTIATE_TEST_SUITE_P(
156 BadGeneralTopologyMetadata, ParameterizedMetadataTest,
157 ::testing::Values(
158 std::make_tuple(
159 2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
160 "failed to parse 'topology' attribute to TopologyProto"),
161 std::make_tuple(4, 2, TopologyWithMeshShape({0}),
162 std::vector<int64_t>(),
163 "'topology' 'mesh_shape' must be rank 4, got rank 1"),
164 std::make_tuple(
165 2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
166 "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
167 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
168 std::vector<int64_t>(),
169 "number of tasks from available TPU devices must be "
170 "'num_tasks' in 'topology' (1), got 2"),
171 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
172 std::vector<int64_t>(),
173 "number of TPU devices available per task must be "
174 "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
175 std::make_tuple(
176 2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
177 "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
178 "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
179 std::make_tuple(
180 2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
181 std::vector<int64_t>(),
182 "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
183 "of mesh shape (2, 1, 1, 1)"),
184 std::make_tuple(
185 2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
186 std::vector<int64_t>(),
187 "device coordinate (2, 0, 0, 0) in 'topology' is outside "
188 "of mesh shape (2, 1, 1, 1)"),
189 std::make_tuple(
190 2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
191 std::vector<int64_t>(),
192 "device coordinate (0, -1, 0, 0) in 'topology' is outside "
193 "of mesh shape (2, 1, 1, 1)"),
194 std::make_tuple(
195 2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
196 std::vector<int64_t>(),
197 "device coordinate (0, 1, 0, 0) in 'topology' is outside "
198 "of mesh shape (2, 1, 1, 1)"),
199 std::make_tuple(
200 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
201 std::vector<int64_t>(),
202 "device coordinate (0, 0, 0, -1) in 'topology' is outside "
203 "of mesh shape (2, 1, 1, 1)"),
204 std::make_tuple(
205 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
206 std::vector<int64_t>(),
207 "device coordinate (0, 0, 0, 1) in 'topology' is outside "
208 "of mesh shape (2, 1, 1, 1)"),
209 std::make_tuple(
210 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
211 std::vector<int64_t>(),
212 "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
213
214 INSTANTIATE_TEST_SUITE_P(
215 BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
216 ::testing::Values(
217 std::make_tuple(2, 1,
218 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
219 std::vector<int64_t>(),
220 "length of 'device_assignment' must be 'num_replicas' "
221 "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
222 std::make_tuple(
223 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
224 std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
225 "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
226 "is outside of mesh shape (2, 1, 1, 1)"),
227 std::make_tuple(
228 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
229 std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
230 "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
231 "outside of mesh shape (2, 1, 1, 1)"),
232 std::make_tuple(
233 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
234 std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
235 "device coordinate (0, -1, 0, 0) in 'device_assignment' "
236 "is outside of mesh shape (2, 1, 1, 1)"),
237 std::make_tuple(
238 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
239 std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
240 "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
241 "outside of mesh shape (2, 1, 1, 1)"),
242 std::make_tuple(
243 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
244 std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
245 "device coordinate (0, 0, 0, -1) in 'device_assignment' "
246 "is outside of mesh shape (2, 1, 1, 1)"),
247 std::make_tuple(
248 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
249 std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
250 "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
251 "outside of mesh shape (2, 1, 1, 1)"),
252 std::make_tuple(2, 1,
253 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
254 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
255 "'device_assignment' has duplicate device coordinate "
256 "(0, 0, 0, 0)")));
257
// Returns the fully-qualified device names of a fake cluster. The list starts
// with "/job:localhost/replica:0/task:0/device:CPU:0". For each worker task in
// order, it then holds the task's CPU:0, its TPU_SYSTEM:0 and
// `num_devices_per_task` TPU devices numbered from 0.
std::vector<std::string> MakeDeviceSet(int num_tasks,
                                       int num_devices_per_task) {
  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  // One localhost CPU, plus per task one CPU, one TPU_SYSTEM and
  // `num_devices_per_task` TPUs. The previous count omitted one device per
  // task, which forced a reallocation.
  devices.reserve(1 + num_tasks * (2 + num_devices_per_task));

  for (int task = 0; task < num_tasks; ++task) {
    const std::string task_prefix =
        "/job:worker/replica:0/task:" + std::to_string(task) + "/device:";
    devices.push_back(task_prefix + "CPU:0");
    devices.push_back(task_prefix + "TPU_SYSTEM:0");
    for (int device = 0; device < num_devices_per_task; ++device)
      devices.push_back(task_prefix + "TPU:" + std::to_string(device));
  }

  return devices;
}
281
// The topology below maps only mesh coordinate (0, 0, 0, 0) to a TPU, so a
// device assignment that asks for (1, 0, 0, 0) must be rejected.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  tpu::TopologyProto topology_proto;
  for (int dim : {2, 1, 1, 1}) topology_proto.add_mesh_shape(dim);
  topology_proto.set_num_tasks(1);
  topology_proto.set_num_tpu_devices_per_task(1);
  for (int coordinate : {0, 0, 0, 0})
    topology_proto.add_device_coordinates(coordinate);
  const std::string topology_attr = topology_proto.SerializeAsString();
  const std::vector<int64_t> device_assignment_attr = {1, 0, 0, 0};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1);
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  auto result = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, topology_attr,
      device_assignment_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate (1, "
            "0, 0, 0)");
}
315
// With no topology, eight replicas on two tasks of four TPUs each take task 0's
// TPUs 0-3 in order, then task 1's TPUs 0-3. No XLA device assignment is
// produced.
TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
  std::string topology_attr;
  std::vector<int64_t> device_assignment_attr;

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1, topology_attr,
      device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");
  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 8);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 1);

  for (int replica = 0; replica < 8; ++replica) {
    const int task = replica / 4;
    const int tpu = replica % 4;
    EXPECT_EQ(tpu_devices[replica][0].device,
              llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU:{1}",
                            task, tpu)
                  .str());
    EXPECT_EQ(
        tpu_devices[replica][0].host,
        llvm::formatv("/job:worker/replica:0/task:{0}/device:CPU:0", task)
            .str());
  }

  EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
}
373
// General device assignment on a (2, 2, 1, 2) mesh: four replicas of two cores
// each across two tasks with four TPUs per task.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
  // Row i is the mesh coordinate of the i-th TPU, listed task by task.
  const int kTopologyCoordinates[][4] = {
      {0, 0, 0, 0}, {0, 1, 0, 0}, {1, 1, 0, 0}, {1, 0, 0, 0},
      {1, 0, 0, 1}, {1, 1, 0, 1}, {0, 1, 0, 1}, {0, 0, 0, 1}};
  tpu::TopologyProto topology_proto;
  for (int dim : {2, 2, 1, 2}) topology_proto.add_mesh_shape(dim);
  topology_proto.set_num_tasks(2);
  topology_proto.set_num_tpu_devices_per_task(4);
  for (const auto& coordinate : kTopologyCoordinates)
    for (int value : coordinate) topology_proto.add_device_coordinates(value);
  const std::string topology_attr = topology_proto.SerializeAsString();

  // Coordinates listed replica by replica, two cores per replica.
  const std::vector<int64_t> device_assignment_attr{
      0, 0, 0, 0, 0, 0, 0, 1,   // replica 0
      0, 1, 0, 0, 0, 1, 0, 1,   // replica 1
      1, 0, 0, 0, 1, 0, 0, 1,   // replica 2
      1, 1, 0, 0, 1, 1, 0, 1};  // replica 3

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2, topology_attr,
      device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");
  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 4);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 2);

  // Expected (task, TPU ordinal) for each [replica][core].
  struct ExpectedCore {
    int task;
    int tpu;
  };
  const ExpectedCore kExpectedCores[4][2] = {{{0, 0}, {1, 3}},
                                             {{0, 1}, {1, 2}},
                                             {{0, 3}, {1, 0}},
                                             {{0, 2}, {1, 1}}};
  for (int replica = 0; replica < 4; ++replica) {
    for (int core = 0; core < 2; ++core) {
      const ExpectedCore& expected = kExpectedCores[replica][core];
      EXPECT_EQ(tpu_devices[replica][core].device,
                llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU:{1}",
                              expected.task, expected.tpu)
                    .str());
      EXPECT_EQ(tpu_devices[replica][core].host,
                llvm::formatv("/job:worker/replica:0/task:{0}/device:CPU:0",
                              expected.task)
                    .str());
    }
  }

  const auto& xla_device_assignment =
      tpu_device_assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.hasValue());
  EXPECT_EQ(xla_device_assignment->replica_count(), 4);
  EXPECT_EQ(xla_device_assignment->computation_count(), 2);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);
  for (int computation = 0; computation < 2; ++computation)
    ASSERT_EQ(xla_device_assignment->computation_devices(computation)
                  .replica_device_ids_size(),
              4);

  // Expected XLA device id for each [computation][replica].
  const int64_t kExpectedDeviceIds[2][4] = {{0, 4, 2, 6}, {1, 5, 3, 7}};
  for (int computation = 0; computation < 2; ++computation) {
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    for (int replica = 0; replica < 4; ++replica)
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedDeviceIds[computation][replica]);
  }
}
495
// General device assignment on a (1, 2, 1, 3) mesh: two replicas of three
// cores each across three tasks with two TPUs per task.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
  // Row i is the mesh coordinate of the i-th TPU, listed task by task.
  const int kTopologyCoordinates[][4] = {{0, 0, 0, 0}, {0, 1, 0, 0},
                                         {0, 1, 0, 1}, {0, 0, 0, 1},
                                         {0, 0, 0, 2}, {0, 1, 0, 2}};
  tpu::TopologyProto topology_proto;
  for (int dim : {1, 2, 1, 3}) topology_proto.add_mesh_shape(dim);
  topology_proto.set_num_tasks(3);
  topology_proto.set_num_tpu_devices_per_task(2);
  for (const auto& coordinate : kTopologyCoordinates)
    for (int value : coordinate) topology_proto.add_device_coordinates(value);
  const std::string topology_attr = topology_proto.SerializeAsString();

  // Coordinates listed replica by replica, three cores per replica.
  const std::vector<int64_t> device_assignment_attr{
      0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2,   // replica 0
      0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};  // replica 1

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3, topology_attr,
      device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& tpu_device_assignment = status_or.ValueOrDie();
  EXPECT_EQ(tpu_device_assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 2);
  for (const auto& replica_tpu_devices : tpu_devices)
    ASSERT_EQ(replica_tpu_devices.size(), 3);

  // Expected (task, TPU ordinal) for each [replica][core].
  struct ExpectedCore {
    int task;
    int tpu;
  };
  const ExpectedCore kExpectedCores[2][3] = {{{1, 1}, {1, 0}, {2, 0}},
                                             {{2, 1}, {0, 0}, {0, 1}}};
  for (int replica = 0; replica < 2; ++replica) {
    for (int core = 0; core < 3; ++core) {
      const ExpectedCore& expected = kExpectedCores[replica][core];
      EXPECT_EQ(tpu_devices[replica][core].device,
                llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU:{1}",
                              expected.task, expected.tpu)
                    .str());
      EXPECT_EQ(tpu_devices[replica][core].host,
                llvm::formatv("/job:worker/replica:0/task:{0}/device:CPU:0",
                              expected.task)
                    .str());
    }
  }

  const auto& xla_device_assignment =
      tpu_device_assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.hasValue());
  EXPECT_EQ(xla_device_assignment->replica_count(), 2);
  EXPECT_EQ(xla_device_assignment->computation_count(), 3);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);
  for (int computation = 0; computation < 3; ++computation)
    ASSERT_EQ(xla_device_assignment->computation_devices(computation)
                  .replica_device_ids_size(),
              2);

  // Expected XLA device id for each [computation][replica].
  const int64_t kExpectedDeviceIds[3][2] = {{1, 5}, {4, 0}, {2, 3}};
  for (int computation = 0; computation < 3; ++computation) {
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    for (int replica = 0; replica < 2; ++replica)
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedDeviceIds[computation][replica]);
  }
}
602
// GetDeviceCoordinates should turn an i64 array attribute into the matching
// list of integers, preserving order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  // TF_ASSERT_OK reports the failing status message, unlike ASSERT_TRUE(ok()).
  TF_ASSERT_OK(status_or_device_coordinates.status());
  auto device_coordinates = status_or_device_coordinates.ConsumeValueOrDie();
  // Check the size first so the element accesses below stay in bounds.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
615
// GetDeviceCoordinates must reject an array attribute whose elements are not
// integers; here the first element (index 0) is already a float.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  ASSERT_FALSE(status_or_device_coordinates.ok());
  EXPECT_EQ(status_or_device_coordinates.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
626
// A cluster with num_cores_per_replica = 1 is expected to report no model
// parallelism.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismFalse) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster->setAttr(kNumCoresPerReplicaAttr, builder.getI64IntegerAttr(1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
644
// A cluster with num_cores_per_replica = 5 is expected to report model
// parallelism.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismTrue) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster->setAttr(kNumCoresPerReplicaAttr, builder.getI64IntegerAttr(5));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_TRUE(HasModelParallelism(cluster));
}
662
// Without a num_cores_per_replica attribute, the cluster is expected to report
// no model parallelism.
TEST(TPURewriteDeviceUtilTest,
     TestHasModelParallelismFalseMissingCoresPerReplicaAttr) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
679
// GetHostDeviceOutsideComputation is expected to fail for a cluster that
// carries none of the TPU metadata attributes.
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
695
// GetHostDeviceOutsideComputation is expected to fail when the cluster lacks
// the topology attribute.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster->setAttr(kNumCoresPerReplicaAttr, builder.getI64IntegerAttr(1));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
715
// GetHostDeviceOutsideComputation is expected to fail when the cluster lacks
// the device assignment attribute.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const mlir::Location loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster->setAttr(kNumCoresPerReplicaAttr, builder.getI64IntegerAttr(1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
735
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  // The cluster sets num_cores_per_replica, topology and device_assignment,
  // but device_assignment holds a string element. The lookup presumably
  // expects integer coordinates there (TODO confirm). Host-device lookup is
  // expected to fail.
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  // Any non-integer entry makes the assignment malformed. The literal is
  // purely a fixture value.
  cluster->setAttr(kDeviceAssignmentAttr,
                   builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                       {"bad_device_assignment"})));

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  EXPECT_TRUE(mlir::failed(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
}
758
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  // The module's "tf.devices" list holds a malformed device name. The cluster
  // carries all three TPU attributes, and host-device lookup is still expected
  // to fail.
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  auto bad_devices = builder.getStrArrayAttr(
      llvm::ArrayRef<llvm::StringRef>({"bad_device_name"}));
  (*module_ref)->setAttr("tf.devices", bad_devices);

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  // The status is deliberately discarded. With a malformed device name it is
  // presumably a failure, and this test only checks the downstream lookup.
  (void)GetDevicesFromOp(*module_ref, &runtime_devices);
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
784
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  // A cluster nested inside a tf_device.replicate (2 replicas, empty device
  // map) is expected to resolve its host to kTPUReplicatedHost. No TPU
  // attributes are set on the cluster, and no runtime devices are supplied.
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      replica_devices;
  auto replicate = builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, replica_devices,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Build the cluster at the very start of the replicate body.
  builder.setInsertionPointToStart(&replicate.body().front());
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::succeeded(result));
  EXPECT_EQ(host_device, kTPUReplicatedHost);
}
811
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  // Outside of a tf_device.replicate, the host device is looked up from the
  // module's "tf.devices" list and the cluster's attributes. The cluster has
  // one core per replica and an empty topology and device assignment.
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningModuleRef module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                    {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
                     "/job:localhost/replica:0/task:0/device:TPU:0",
                     "/job:worker/replica:0/task:0/device:CPU:0"})));

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  // Every name in the device list above is a full device name, so collecting
  // them must succeed. Asserting here reports a broken fixture directly,
  // instead of as a confusing host-device mismatch further down.
  ASSERT_TRUE(
      mlir::succeeded(GetDevicesFromOp(*module_ref, &runtime_devices)));
  std::string host_device;
  EXPECT_TRUE(mlir::succeeded(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
  // NOTE(review): the expected host is on job:localhost even though the only
  // listed CPU is on job:worker. The host is presumably derived from the TPU
  // device's job/replica/task. Confirm against the implementation.
  EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
}
840
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  // A TPU device name is recognized. A CPU device name and a string that is
  // not a device name at all are not.
  constexpr char kTpuDevice[] = "/job:localhost/replica:0/task:0/device:TPU:0";
  constexpr char kCpuDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
  constexpr char kMalformedDevice[] = "INVALID_DEVICE";

  EXPECT_TRUE(IsTPUDevice(kTpuDevice));
  EXPECT_FALSE(IsTPUDevice(kCpuDevice));
  EXPECT_FALSE(IsTPUDevice(kMalformedDevice));
}
846
847 } // anonymous namespace
848 } // namespace tensorflow
849