/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_

#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
using stream_executor::port::StatusOr;

extern const char* const kTPUReplicatedHost;
extern const char* const kNumCoresPerReplicaAttr;
extern const char* const kTopologyAttr;
extern const char* const kDeviceAssignmentAttr;

// A TPU device for execution alongside its associated host CPU device.
struct TPUDeviceAndHost {
  TPUDeviceAndHost() {}
  TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
      : device(device), host(host) {}

  std::string device;
  std::string host;
};

// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
// their associated host CPU devices (for outside compilation). They are
// ordered by `num_replicas` followed by `num_cores_per_replica`.
using TPUDevicesAndHosts =
    llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
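
// A usage sketch (assumed, not part of the original header): iterating in
// replica-major order, matching the documented `num_replicas` x
// `num_cores_per_replica` ordering. `Process` is a hypothetical callback.
//
//   void VisitTPUDevices(const TPUDevicesAndHosts& devices) {
//     for (const auto& replica : devices)
//       for (const TPUDeviceAndHost& device_and_host : replica)
//         Process(device_and_host.device, device_and_host.host);
//   }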
57 
58 // TPU compilation device, execution and associated host devices, and optionally
59 // execution device IDs. Execution device IDs are populated if `topology` and
60 // `device_assignment` are provided.
61 struct TPUDeviceAssignment {
TPUDeviceAssignmentTPUDeviceAssignment62   TPUDeviceAssignment(llvm::StringRef compilation_device,
63                       TPUDevicesAndHosts&& tpu_devices)
64       : compilation_device(compilation_device),
65         tpu_devices(std::move(tpu_devices)) {}
66 
TPUDeviceAssignmentTPUDeviceAssignment67   TPUDeviceAssignment(llvm::StringRef compilation_device,
68                       TPUDevicesAndHosts&& tpu_devices,
69                       xla::DeviceAssignmentProto&& xla_device_assignment)
70       : compilation_device(compilation_device),
71         tpu_devices(std::move(tpu_devices)),
72         xla_device_assignment(std::move(xla_device_assignment)) {}
73 
74   std::string compilation_device;
75   TPUDevicesAndHosts tpu_devices;
76   llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
77 };
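
// A minimal construction sketch (assumed usage; the device names are only
// examples). The two-argument constructor leaves `xla_device_assignment`
// unset:
//
//   TPUDevicesAndHosts devices_and_hosts(
//       {{TPUDeviceAndHost("/job:worker/replica:0/task:0/device:TPU:0",
//                          "/job:worker/replica:0/task:0/device:CPU:0")}});
//   TPUDeviceAssignment assignment(
//       "/job:worker/replica:0/task:0/device:CPU:0",
//       std::move(devices_and_hosts));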

// Extracts device coordinates from a device assignment attribute on an op.
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr);
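
// Usage sketch (assumed, not from the original header; `op` is a hypothetical
// operation carrying the attribute named by kDeviceAssignmentAttr):
//
//   auto coordinates = GetDeviceCoordinates(
//       op->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr));
//   if (coordinates.ok()) {
//     for (int64_t coordinate : coordinates.ValueOrDie()) { /* ... */ }
//   }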

// Finds the TPU compilation device and execution devices from `devices` for a
// TPU computation subgraph. The compilation device is determined by looking up
// all TPU_SYSTEM:0 devices and choosing the CPU device associated with the
// first TPU_SYSTEM device, sorted lexicographically by replica and task.
// Execution devices are determined by looking up all TPU devices associated
// with each TPU_SYSTEM:0 device found, alongside the associated
// `topology_attr` and `device_assignment_attr`. If `topology_attr` is not an
// empty string (it must be parsable to a TopologyProto),
// `device_assignment_attr` must not be empty either. When both attributes are
// non-empty, a general device assignment based on them is used; when both are
// empty, a full mesh device assignment is used instead. A failure is returned
// if a device assignment cannot be formed (e.g. invalid devices or invalid
// parameters).
//
//
// For example, for `devices`:
//   {
//     /job:localhost/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:0/device:TPU:0,
//     /job:worker/replica:0/task:0/device:TPU:1,
//     /job:worker/replica:0/task:0/device:TPU:2,
//     /job:worker/replica:0/task:0/device:TPU:3,
//     /job:worker/replica:0/task:1/device:CPU:0,
//     /job:worker/replica:0/task:1/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:1/device:TPU:0,
//     /job:worker/replica:0/task:1/device:TPU:1,
//     /job:worker/replica:0/task:1/device:TPU:2,
//     /job:worker/replica:0/task:1/device:TPU:3
//   }
//
//
// With the following parameters (full mesh device assignment):
//   `num_replicas` = 8
//   `num_cores_per_replica` = 1
//   `topology_attr` = ""
//   `device_assignment_attr` = {}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       /job:worker/replica:0/task:0/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:3
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:3
//     }
//   }
//
// and `xla_device_assignment` will not be set.
//
//
// With the following parameters (general device assignment):
//   `num_replicas` = 4
//   `num_cores_per_replica` = 2
//   `topology_attr` (in proto debug string format) =
//     {
//       mesh_shape: 2
//       mesh_shape: 2
//       mesh_shape: 2
//       num_tasks: 2
//       num_tpu_devices_per_task: 4
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//     }
//   `device_assignment_attr` =
//     {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       "/job:worker/replica:0/task:0/device:TPU:0",
//       "/job:worker/replica:0/task:1/device:TPU:3"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:1",
//       "/job:worker/replica:0/task:1/device:TPU:2"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:3",
//       "/job:worker/replica:0/task:1/device:TPU:0"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:2",
//       "/job:worker/replica:0/task:1/device:TPU:1"
//     }
//   }
//
// and `xla_device_assignment` will be:
//   {
//     replica_count: 4
//     computation_count: 2
//     computation_devices {
//       replica_device_ids: 0
//       replica_device_ids: 4
//       replica_device_ids: 2
//       replica_device_ids: 6
//     }
//     computation_devices {
//       replica_device_ids: 1
//       replica_device_ids: 5
//       replica_device_ids: 3
//       replica_device_ids: 7
//     }
//   }
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    llvm::ArrayRef<DeviceNameUtils::ParsedName> devices, int num_replicas,
    int num_cores_per_replica, llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr);
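
// Usage sketch (assumed; `devices` would typically come from the module's
// device list, e.g. via RuntimeDevices::device_names()). The parameters below
// mirror the full mesh example above:
//
//   StatusOr<TPUDeviceAssignment> status_or_assignment =
//       GetTPUCompilationAndExecutionDevices(
//           devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1,
//           /*topology_attr=*/"", /*device_assignment_attr=*/{});
//   if (status_or_assignment.ok()) {
//     const TPUDeviceAssignment& assignment = status_or_assignment.ValueOrDie();
//     // Use assignment.compilation_device and assignment.tpu_devices.
//   }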

// Returns a virtual device name to be used for device assignment, for
// executing ops on a specified logical core.
std::string GetDeviceAliasForLogicalCore(int core_index);
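
// Usage sketch (assumed): the returned alias can be attached to an op's
// `device` attribute so that a later pass maps it to the concrete TPU device
// for that logical core.
//
//   std::string core0_alias = GetDeviceAliasForLogicalCore(0);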

// Parses TPU compilation and execution devices from a TPU cluster and returns
// the host device for the head and tail computations. If the TPU computation
// is replicated, kTPUReplicatedHost is returned instead.
mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device);
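
// Usage sketch (assumed; `devices` and `cluster` already in scope):
//
//   std::string host_device;
//   if (mlir::failed(GetHostDeviceOutsideComputation(devices, cluster,
//                                                    &host_device)))
//     return mlir::failure();
//   // `host_device` now names the host CPU device, or kTPUReplicatedHost if
//   // the computation is replicated.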

// Checks if a device string is a TPU device.
bool IsTPUDevice(llvm::StringRef device);
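
// For example (sketch), a fully qualified TPU device name is expected to be
// recognized, while a CPU device name is not:
//
//   bool tpu = IsTPUDevice("/job:worker/replica:0/task:0/device:TPU:0");
//   bool cpu = IsTPUDevice("/job:worker/replica:0/task:0/device:CPU:0");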

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_