1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ 18 19 #include <string> 20 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/Optional.h" 23 #include "llvm/ADT/SmallVector.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "mlir/IR/Attributes.h" // from @llvm-project 26 #include "mlir/Support/LogicalResult.h" // from @llvm-project 27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" 28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/util/device_name_utils.h" 32 #include "tensorflow/stream_executor/lib/statusor.h" 33 34 namespace tensorflow { 35 using stream_executor::port::StatusOr; 36 37 extern const char* const kTPUReplicatedHost; 38 extern const char* const kNumCoresPerReplicaAttr; 39 extern const char* const kTopologyAttr; 40 extern const char* const kDeviceAssignmentAttr; 41 42 // A TPU device for execution alongside its associated host CPU device. 43 struct TPUDeviceAndHost { TPUDeviceAndHostTPUDeviceAndHost44 TPUDeviceAndHost() {} TPUDeviceAndHostTPUDeviceAndHost45 TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) 46 : device(device), host(host) {} 47 48 std::string device; 49 std::string host; 50 }; 51 52 // TPU devices to be used for execution (e.g. devices for TPUExecute ops) and 53 // their associated host CPU devices (for outside compilation). They are ordered 54 // by `num_replicas` followed by `num_cores_per_replica`. 55 using TPUDevicesAndHosts = 56 llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>; 57 58 // TPU compilation device, execution and associated host devices, and optionally 59 // execution device IDs. Execution device IDs are populated if `topology` and 60 // `device_assignment` are provided. 61 struct TPUDeviceAssignment { TPUDeviceAssignmentTPUDeviceAssignment62 TPUDeviceAssignment(llvm::StringRef compilation_device, 63 TPUDevicesAndHosts&& tpu_devices) 64 : compilation_device(compilation_device), 65 tpu_devices(std::move(tpu_devices)) {} 66 TPUDeviceAssignmentTPUDeviceAssignment67 TPUDeviceAssignment(llvm::StringRef compilation_device, 68 TPUDevicesAndHosts&& tpu_devices, 69 xla::DeviceAssignmentProto&& xla_device_assignment) 70 : compilation_device(compilation_device), 71 tpu_devices(std::move(tpu_devices)), 72 xla_device_assignment(std::move(xla_device_assignment)) {} 73 74 std::string compilation_device; 75 TPUDevicesAndHosts tpu_devices; 76 llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment; 77 }; 78 79 // Extracts device coordinates from a device assignment attribute on an op. 80 StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates( 81 mlir::ArrayAttr device_assignment_attr); 82 83 // Finds the TPU compilation device and execution devices from `devices` for a 84 // TPU computation subgraph. Compilation device is determined from looking up 85 // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first 86 // TPU_SYSTEM device sorted lexicographically by replica and task. Execution 87 // devices are determined by looking up all TPU devices associated with each 88 // TPU_SYSTEM:0 device found, alongside associated `topology_attr` and 89 // `device_assignment_attr`. If `topology_attr` not an empty string (parsable to 90 // TopologyProto), `device_assignment_attr` must not be empty also. When 91 // `topology_attr` and `device_assignment_attr` are not empty, a general device 92 // assignment based on those two attributes are used. Otherwise when 93 // `topology_attr` and `device_assignment_attr` are empty, a full mesh device 94 // assignment is used instead. A failure will be returned if it is not possible 95 // (e.g. invalid devices or invalid parameters). 96 // 97 // 98 // For example, for `devices`: 99 // { 100 // /job:localhost/replica:0/task:0/device:CPU:0, 101 // /job:worker/replica:0/task:0/device:CPU:0, 102 // /job:worker/replica:0/task:0/device:TPU_SYSTEM:0, 103 // /job:worker/replica:0/task:0/device:TPU:0, 104 // /job:worker/replica:0/task:0/device:TPU:1, 105 // /job:worker/replica:0/task:0/device:TPU:2, 106 // /job:worker/replica:0/task:0/device:TPU:3, 107 // /job:worker/replica:0/task:1/device:CPU:0, 108 // /job:worker/replica:0/task:1/device:TPU_SYSTEM:0, 109 // /job:worker/replica:0/task:1/device:TPU:0, 110 // /job:worker/replica:0/task:1/device:TPU:1, 111 // /job:worker/replica:0/task:1/device:TPU:2, 112 // /job:worker/replica:0/task:1/device:TPU:3 113 // } 114 // 115 // 116 // With the following parameters (full mesh device assignment): 117 // `num_replicas` = 8 118 // `num_cores_per_replica` = 1 119 // `topology_attr` = "" 120 // `device_assignment_attr` = {} 121 // 122 // The `compilation_device` will be: 123 // /job:worker/replica:0/task:0/device:CPU:0 124 // 125 // `execution_devices` will be: 126 // { 127 // { 128 // /job:worker/replica:0/task:0/device:TPU:0 129 // }, 130 // { 131 // /job:worker/replica:0/task:0/device:TPU:1 132 // }, 133 // { 134 // /job:worker/replica:0/task:0/device:TPU:2 135 // }, 136 // { 137 // /job:worker/replica:0/task:0/device:TPU:3 138 // }, 139 // { 140 // /job:worker/replica:0/task:1/device:TPU:0 141 // }, 142 // { 143 // /job:worker/replica:0/task:1/device:TPU:1 144 // }, 145 // { 146 // /job:worker/replica:0/task:1/device:TPU:2 147 // }, 148 // { 149 // /job:worker/replica:0/task:1/device:TPU:3 150 // } 151 // } 152 // 153 // and `xla_device_assignment` will not be set. 154 // 155 // 156 // With the following parameters (general device assignment): 157 // `num_replicas` = 4 158 // `num_cores_per_replica` = 2 159 // `topology_attr` (in proto debug string format) = 160 // { 161 // mesh_shape: 2 162 // mesh_shape: 2 163 // mesh_shape: 2 164 // num_tasks: 2 165 // num_tpu_devices_per_task: 4 166 // device_coordinates: 0 167 // device_coordinates: 0 168 // device_coordinates: 0 169 // device_coordinates: 0 170 // device_coordinates: 1 171 // device_coordinates: 0 172 // device_coordinates: 1 173 // device_coordinates: 1 174 // device_coordinates: 0 175 // device_coordinates: 1 176 // device_coordinates: 0 177 // device_coordinates: 0 178 // device_coordinates: 1 179 // device_coordinates: 0 180 // device_coordinates: 1 181 // device_coordinates: 1 182 // device_coordinates: 1 183 // device_coordinates: 1 184 // device_coordinates: 0 185 // device_coordinates: 1 186 // device_coordinates: 1 187 // device_coordinates: 0 188 // device_coordinates: 0 189 // device_coordinates: 1 190 // } 191 // `device_assignment` = 192 // {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1} 193 // 194 // The `compilation_device` will be: 195 // /job:worker/replica:0/task:0/device:CPU:0 196 // 197 // `execution_devices` will be: 198 // { 199 // { 200 // "/job:worker/replica:0/task:0/device:TPU:0", 201 // "/job:worker/replica:0/task:1/device:TPU:3" 202 // }, 203 // { 204 // "/job:worker/replica:0/task:0/device:TPU:1", 205 // "/job:worker/replica:0/task:1/device:TPU:2" 206 // }, 207 // { 208 // "/job:worker/replica:0/task:0/device:TPU:3", 209 // "/job:worker/replica:0/task:1/device:TPU:0" 210 // }, 211 // { 212 // "/job:worker/replica:0/task:0/device:TPU:2", 213 // "/job:worker/replica:0/task:1/device:TPU:1" 214 // } 215 // } 216 // 217 // and `xla_device_assignment` will be: 218 // { 219 // replica_count: 4 220 // computation_count: 2 221 // computation_devices { 222 // replica_device_ids: 0 223 // replica_device_ids: 4 224 // replica_device_ids: 2 225 // replica_device_ids: 6 226 // } 227 // computation_devices { 228 // replica_device_ids: 1 229 // replica_device_ids: 5 230 // replica_device_ids: 3 231 // replica_device_ids: 7 232 // } 233 // } 234 StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices( 235 llvm::ArrayRef<DeviceNameUtils::ParsedName> devices, int num_replicas, 236 int num_cores_per_replica, llvm::StringRef topology_attr, 237 llvm::ArrayRef<int64_t> device_assignment_attr); 238 239 // Virtual device is used for evice assignment for executing ops on a specified 240 // logical core. 241 std::string GetDeviceAliasForLogicalCore(int core_index); 242 243 // Returns true if cluster contains model parallelism based on 244 // `num_cores_per_replica_attribute`. Otherwise returns false. 245 bool HasModelParallelism(mlir::tf_device::ClusterOp cluster); 246 247 // Parses TPU compilation and execution devices from a TPU cluster and returns 248 // the host device for the head and tail computations. If the TPU computation is 249 // replicated, kTPUReplicatedHost is returned instead. 250 mlir::LogicalResult GetHostDeviceOutsideComputation( 251 mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, 252 std::string* host_device); 253 254 // Checks if a device string is a TPU device. 255 bool IsTPUDevice(llvm::StringRef device); 256 257 // Checks if a device string is a TPU replicated core device. 258 bool IsTPUReplicatedCore(llvm::StringRef device); 259 260 } // namespace tensorflow 261 262 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ 263