/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_

#include <string>

#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
using stream_executor::port::StatusOr;

inline constexpr absl::string_view kTPUReplicatedHost = "TPU_REPLICATED_HOST";
inline constexpr absl::string_view kNumCoresPerReplicaAttr =
    "num_cores_per_replica";
inline constexpr absl::string_view kTopologyAttr = "topology";
inline constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment";

// A TPU device for execution alongside its associated host CPU device.
struct TPUDeviceAndHost {
  TPUDeviceAndHost() {}
  TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
      : device(device), host(host) {}

  std::string device;
  std::string host;
};

// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
// their associated host CPU devices (for outside compilation). They are
// ordered by `num_replicas` followed by `num_cores_per_replica`.
using TPUDevicesAndHosts =
    llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;

// TPU compilation device, execution and associated host devices, and
// optionally execution device IDs. Execution device IDs are populated if
// `topology` and `device_assignment` are provided.
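//
// A minimal construction sketch (hypothetical device names, assuming a single
// replica with a single core per replica):
//
//   TPUDevicesAndHosts tpu_devices(
//       /*Size=*/1,
//       llvm::SmallVector<TPUDeviceAndHost, 8>(
//           /*Size=*/1,
//           TPUDeviceAndHost("/job:worker/replica:0/task:0/device:TPU:0",
//                            "/job:worker/replica:0/task:0/device:CPU:0")));
//   TPUDeviceAssignment assignment("/job:worker/replica:0/task:0/device:CPU:0",
//                                  std::move(tpu_devices));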
struct TPUDeviceAssignment {
  TPUDeviceAssignment(llvm::StringRef compilation_device,
                      TPUDevicesAndHosts&& tpu_devices)
      : compilation_device(compilation_device),
        tpu_devices(std::move(tpu_devices)) {}

  TPUDeviceAssignment(llvm::StringRef compilation_device,
                      TPUDevicesAndHosts&& tpu_devices,
                      xla::DeviceAssignmentProto&& xla_device_assignment)
      : compilation_device(compilation_device),
        tpu_devices(std::move(tpu_devices)),
        xla_device_assignment(std::move(xla_device_assignment)) {}

  std::string compilation_device;
  TPUDevicesAndHosts tpu_devices;
  llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
};

// Extracts device coordinates from a device assignment attribute on an op.
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr);

// Finds the TPU compilation device and execution devices from `devices` for a
// TPU computation subgraph. The compilation device is determined by looking up
// all TPU_SYSTEM:0 devices and choosing the CPU device associated with the
// first TPU_SYSTEM device, sorted lexicographically by replica and task.
// Execution devices are determined by looking up all TPU devices associated
// with each TPU_SYSTEM:0 device found, alongside the associated
// `topology_attr` and `device_assignment_attr`. If `topology_attr` is not an
// empty string (parsable to TopologyProto), `device_assignment_attr` must also
// not be empty. When `topology_attr` and `device_assignment_attr` are not
// empty, a general device assignment based on those two attributes is used.
// Otherwise, when `topology_attr` and `device_assignment_attr` are empty, a
// full mesh device assignment is used instead. A failure is returned if a
// device assignment is not possible (e.g. invalid devices or invalid
// parameters).
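//
// A minimal call sketch (hypothetical `devices`; the attribute values mirror
// the full mesh example below):
//
//   StatusOr<TPUDeviceAssignment> assignment =
//       GetTPUCompilationAndExecutionDevices(
//           devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1,
//           /*topology_attr=*/"", /*device_assignment_attr=*/{});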
//
// For example, for `devices`:
//   {
//     /job:localhost/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:CPU:0,
//     /job:worker/replica:0/task:0/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:0/device:TPU:0,
//     /job:worker/replica:0/task:0/device:TPU:1,
//     /job:worker/replica:0/task:0/device:TPU:2,
//     /job:worker/replica:0/task:0/device:TPU:3,
//     /job:worker/replica:0/task:1/device:CPU:0,
//     /job:worker/replica:0/task:1/device:TPU_SYSTEM:0,
//     /job:worker/replica:0/task:1/device:TPU:0,
//     /job:worker/replica:0/task:1/device:TPU:1,
//     /job:worker/replica:0/task:1/device:TPU:2,
//     /job:worker/replica:0/task:1/device:TPU:3
//   }
//
// With the following parameters (full mesh device assignment):
//   `num_replicas` = 8
//   `num_cores_per_replica` = 1
//   `topology_attr` = ""
//   `device_assignment_attr` = {}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       /job:worker/replica:0/task:0/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:0/device:TPU:3
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:0
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:1
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:2
//     },
//     {
//       /job:worker/replica:0/task:1/device:TPU:3
//     }
//   }
//
// and `xla_device_assignment` will not be set.
//
// With the following parameters (general device assignment):
//   `num_replicas` = 4
//   `num_cores_per_replica` = 2
//   `topology_attr` (in proto debug string format) =
//     {
//       mesh_shape: 2
//       mesh_shape: 2
//       mesh_shape: 2
//       num_tasks: 2
//       num_tpu_devices_per_task: 4
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 1
//       device_coordinates: 1
//       device_coordinates: 0
//       device_coordinates: 0
//       device_coordinates: 1
//     }
//   `device_assignment_attr` =
//     {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1}
//
// The `compilation_device` will be:
//   /job:worker/replica:0/task:0/device:CPU:0
//
// `execution_devices` will be:
//   {
//     {
//       "/job:worker/replica:0/task:0/device:TPU:0",
//       "/job:worker/replica:0/task:1/device:TPU:3"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:1",
//       "/job:worker/replica:0/task:1/device:TPU:2"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:3",
//       "/job:worker/replica:0/task:1/device:TPU:0"
//     },
//     {
//       "/job:worker/replica:0/task:0/device:TPU:2",
//       "/job:worker/replica:0/task:1/device:TPU:1"
//     }
//   }
//
// and `xla_device_assignment` will be:
//   {
//     replica_count: 4
//     computation_count: 2
//     computation_devices {
//       replica_device_ids: 0
//       replica_device_ids: 4
//       replica_device_ids: 2
//       replica_device_ids: 6
//     }
//     computation_devices {
//       replica_device_ids: 1
//       replica_device_ids: 5
//       replica_device_ids: 3
//       replica_device_ids: 7
//     }
//   }
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
    llvm::ArrayRef<DeviceNameUtils::ParsedName> devices, int num_replicas,
    int num_cores_per_replica, llvm::StringRef topology_attr,
    llvm::ArrayRef<int64_t> device_assignment_attr);

// Returns the device alias of the virtual device used for assigning ops to
// execute on a specified logical core.
std::string GetDeviceAliasForLogicalCore(int core_index);

// Returns true if the cluster contains model parallelism, based on its
// `num_cores_per_replica` attribute. Otherwise returns false.
bool HasModelParallelism(mlir::tf_device::ClusterOp cluster);

// Parses TPU compilation and execution devices from a TPU cluster and returns
// the host device for the head and tail computations. If the TPU computation
// is replicated, kTPUReplicatedHost is returned instead.
mlir::LogicalResult GetHostDeviceOutsideComputation(
    mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
    std::string* host_device);

// Checks if a device string is a TPU device.
bool IsTPUDevice(llvm::StringRef device);

// Checks if a device string is a TPU replicated core device.
bool IsTPUReplicatedCore(llvm::StringRef device);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_