# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental support for defining XLA shardings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))
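  # Illustrative usage sketch only; `tensor` and the 2x2 device assignment
  # below are assumptions, not values defined by this module:
  #
  #   import numpy as np
  #   assignment = np.arange(4).reshape(2, 2)  # Tile 2x2 over devices 0..3.
  #   sharded = Sharding.tile(assignment).apply_to_tensor(tensor)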

  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))

  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for
    the common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to split.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.

    Returns:
      The tensor with Sharding attribute.
    """
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor, sharding=proto.SerializeToString())
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of old_proto.tuple_shardings, so create
      # a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
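    # Record the chosen sharding by serializing the OpSharding proto into the
    # producing op's `_XlaSharding` attribute; the tf2xla bridge reads this
    # attribute when lowering the graph to XLA.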
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to add sharding annotation.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)


def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: The tensor to annotate with the copied sharding.
    use_sharding_op: Whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor

# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)
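# Illustrative sketch of the tile() helper; `logits` and the 2x2 assignment
# are assumptions, not names defined by this module:
#
#   import numpy as np
#   logits = tile(logits, np.arange(4).reshape(2, 2))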


def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False,
          input_shape=None):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension across.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)


def partial_tile(tensor, tile_assignment, use_sharding_op=False):
  """Returns a tensor that has partially tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than tensor, and the last dimension represents partially
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor, use_sharding_op=use_sharding_op)


def get_op_sharding(op):
  """Returns the sharding attribute of an op.

  Args:
    op: A TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns the sharding attribute of a Tensor.

  Args:
    tensor: A Tensor.

  Returns:
    The attribute representing XLA sharding on the tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: A serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
    into which it has been split. Returns None if the input indicates no tile
    assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None


def auto_to_manual_spmd_partition(tensor, manual_sharding):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by the SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor, manual_sharding=manual_sharding)
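# Illustrative round-trip sketch; `full`, `shard_result`, and `sharding_proto`
# are assumptions, not names defined by this module:
#
#   shard = auto_to_manual_spmd_partition(full, sharding_proto)
#   # ... manually partitioned ops consume `shard` and produce `shard_result`.
#   full = manual_to_auto_spmd_partition(
#       shard_result, sharding_proto, full_shape=full.shape)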


def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: The shape of the tensor before partitioning.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor, manual_sharding=manual_sharding, full_shape=full_shape)


def mesh_split_sharding(device_mesh, tensor_split_dims_mapping):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  permutation = [d for d in tensor_split_dims_mapping if d >= 0]
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
  ]
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)


def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
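# Illustrative mesh_split sketch; the 2x4 mesh and `activations` below are
# assumptions, not values defined by this module:
#
#   import numpy as np
#   mesh = np.arange(8).reshape(2, 4)
#   # Split dimension 0 of a rank-2 tensor 4 ways along mesh axis 1, leave
#   # dimension 1 unsplit, and replicate each shard across mesh axis 0.
#   activations = mesh_split(activations, mesh, [1, -1])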