1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15 16"""Helper library for handling infeed between hosts and TPUs. 17""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import itertools 24 25import numpy as np 26from six.moves import xrange # pylint: disable=redefined-builtin 27 28from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.ops import array_ops 33from tensorflow.python.tpu import tpu_name_util 34from tensorflow.python.tpu import tpu_sharding 35from tensorflow.python.tpu.ops import tpu_ops 36 37from tensorflow.python.util import nest 38 39 40def partition_or_replicate_on_host(tensor, dims): 41 """Partitions or replicates the input tensor. 42 43 The ops inside this function are placed on the host side. 44 45 Args: 46 tensor: The input tensor which will be partitioned or replicated. 47 dims: A list of integer describes how to partition the input tensor. 48 49 Returns: 50 An iterator of `Tensor`s or a list of partitioned tensors. 51 """ 52 if dims is None: 53 return itertools.repeat(tensor) 54 dims = np.array(dims) 55 output = [tensor] 56 shape_list = np.array(tensor.shape.as_list()) 57 quotients, remainders = np.divmod(shape_list, dims) 58 for axis, (quotient, remainder, dim, original_size) in enumerate( 59 zip(quotients, remainders, dims, shape_list)): 60 if dim <= 1: 61 continue 62 if remainder > 0: 63 # For each dimension, when it cannot be evenly partitioned, XLA assumes 64 # tensors are partitioned in a greedy manner by using 65 # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims 66 # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => 67 # [[(3, 4), (3, 4), (2, 4), (2, 2)], 68 # [(2, 4), (2, 4), (2, 4), (2, 2)]] 69 ceil_ratio = quotient + 1 70 num_full_slots, left_over = np.divmod(original_size, ceil_ratio) 71 num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] 72 if len(num_or_size_splits) < dim: 73 num_or_size_splits += [0] * (dim - len(num_or_size_splits)) 74 new_output = [] 75 for x in output: 76 new_output.append( 77 array_ops.split( 78 x, num_or_size_splits=num_or_size_splits, axis=axis)) 79 output = new_output 80 else: 81 output = [array_ops.split(x, int(dim), axis=axis) for x in output] 82 output = nest.flatten(output) 83 return output 84 85 86def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims): 87 """Tags appropriate XLA sharding attribute to the dequeued tensor. 88 89 The sharding attribute of the dequeued tensor will be a tuple. 90 91 Args: 92 tensor: The dequeued tensor on TPU. 93 dims: A list of integer describes how the tensor is partitioned. 94 95 Returns: 96 The same tensor with the xla_sharding attribute. 97 """ 98 if dims is None: 99 return xla_sharding.replicate(tensor, assign_tuple_sharding=True) 100 elif np.prod(dims) == 1: 101 return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True) 102 else: 103 tile_assignment = np.arange(np.prod(dims)).reshape(dims) 104 return xla_sharding.tile( 105 tensor=tensor, 106 tile_assignment=tile_assignment, 107 assign_tuple_sharding=True) 108 109 110def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims): 111 """Tags appropriate XLA sharding attribute to the dequeued tensors. 112 113 Args: 114 dequeues: A list of dequeued tensors on TPU. 115 dims: A list of integer describes how the tensor is partitioned. 116 117 Returns: 118 The same dequeues with appropriate xla_sharding attribute. 119 """ 120 nest.assert_shallow_structure(dequeues, dims) 121 return nest.map_structure_up_to( 122 dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims) 123 124 125class InfeedQueue(object): 126 """A helper object to build a device infeed queue. 127 128 The InfeedQueue builds the host-side and device-side Ops to enqueue and 129 dequeue elements, respectively, and ensures that their types and 130 shapes match. 131 """ 132 133 def __init__(self, 134 number_of_tuple_elements=None, 135 tuple_types=None, 136 tuple_shapes=None, 137 shard_dimensions=None, 138 number_of_partitions=None, 139 name=None): 140 """Creates a new InfeedQueue with the given configuration. 141 142 The configuration need not be fully specified at creation since it 143 can be modified subsequently by methods that set the values 144 explicitly or infer them from the shapes of inputs. 145 146 Args: 147 number_of_tuple_elements: the number of Tensors fed atomically through the 148 queue, must be present unless it can be inferred from other arguments. 149 tuple_types: if not None, a list of types of the elements of the queue. 150 tuple_shapes: if not None, a list of shapes of the elements of the queue. 151 shard_dimensions: if not None, a list of dimensions on which the 152 elements of the queue should be sharded during automatic 153 parallelization. 154 number_of_partitions: if > 1, the infeed dequeue shape will contain 155 the full shape that includes all partitions and add corresponding XLA 156 annotation on the infeed dequeue op. In this case, the infeed is still 157 data parallel that feeds per-core batch size to each core while the XLA 158 computation may be partitioned. As XLA requires infeed dequeue shape to 159 be per-replica shape, thus we need number_of_partitions here to 160 calculate the per-replica unpartitioned shape. 161 name: the name of the queue. 162 163 Raises: 164 ValueError: if number_of_tuple_elements <= 0; or 165 number_of_tuple_arguments, tuple_types, tuple_shapes, and 166 shard_dimensions are all None; or the length of tuple_types, 167 tuple_shapes, or shard_dimensions is not equal to 168 number_of_tuple_elements; or any element of shard_dimensions 169 can't be converted to a Dimension. 170 TypeError: if any element of tuple_types or tuple_shapes can't 171 be converted to a dtype or TensorShape, respectively. 172 """ 173 self._frozen = False 174 self._generated_enqueue_ops = False 175 self._generated_dequeue_op = False 176 self._name = "InfeedQueue" if name is None else name 177 if number_of_partitions is None: 178 self._number_of_partitions = 1 179 else: 180 self._number_of_partitions = number_of_partitions 181 if number_of_tuple_elements is None: 182 if tuple_types is not None: 183 number_of_tuple_elements = len(tuple_types) 184 elif tuple_shapes is not None: 185 number_of_tuple_elements = len(tuple_shapes) 186 elif shard_dimensions is not None: 187 number_of_tuple_elements = len(shard_dimensions) 188 else: 189 raise ValueError( 190 "number of tuple elements cannot be inferred from InfeedQueue " 191 "constructor") 192 if number_of_tuple_elements <= 0: 193 raise ValueError("number_of_tuple_elements %d must be > 0" % 194 number_of_tuple_elements) 195 # Make an empty sharding policy for each tuple element. 196 self._sharding_policies = [ 197 tpu_sharding.ShardingPolicy() 198 for _ in xrange(number_of_tuple_elements) 199 ] 200 if tuple_types is not None: 201 self.set_tuple_types(tuple_types) 202 else: 203 self._tuple_types = None 204 if tuple_shapes is not None: 205 self.set_tuple_shapes(tuple_shapes) 206 else: 207 self._tuple_shapes = None 208 if shard_dimensions is not None: 209 self.set_shard_dimensions(shard_dimensions) 210 self._validate() 211 212 def _validate(self): 213 """Checks that the configuration is self-consistent. 214 215 Raises: 216 ValueError: if the shapes and sharding policies don't match. 217 """ 218 if self.tuple_shapes is not None: 219 for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): 220 # Raise an error if the policy is incompatible with the shape. 221 _ = policy.get_sharded_shape(shape) 222 223 @property 224 def number_of_tuple_elements(self): 225 """Returns the number of InfeedQueue tuple elements.""" 226 return len(self._sharding_policies) 227 228 @property 229 def tuple_types(self): 230 """Returns the types of the InfeedQueue tuple elements.""" 231 return self._tuple_types 232 233 def set_tuple_types(self, tuple_types): 234 """Sets the type of each element of the queue. 235 236 tuple_types must be a list of length 237 self.number_of_tuple_elements, and each element must be 238 convertible to a dtype. 239 240 Args: 241 tuple_types: the types of each queue element. 242 243 Raises: 244 ValueError: if tuple_types is not of length 245 self.number_of_tuple_elements. 246 TypeError: if an element of tuple_types cannot be converted to a 247 dtype. 248 """ 249 if len(tuple_types) != self.number_of_tuple_elements: 250 raise ValueError("tuple_types is %s, but must be a list of length %d" % 251 (str(tuple_types), self.number_of_tuple_elements)) 252 if self._frozen: 253 for (frozen, updated) in zip(self._tuple_types, tuple_types): 254 if frozen != updated: 255 raise ValueError( 256 "Trying to update InfeedQueue with frozen configuration with an " 257 "incompatible type. Frozen types are %s, updated types are %s" % ( 258 str(self._tuple_types), str(tuple_types))) 259 else: 260 try: 261 self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] 262 except (TypeError) as e: 263 raise TypeError( 264 "tuple_types is %s, but must be a list of elements each " 265 "convertible to dtype: got error %s" % (str(tuple_types), str(e))) 266 267 @property 268 def tuple_shapes(self): 269 """Returns the shapes of the InfeedQueue tuple elements.""" 270 return self._tuple_shapes 271 272 def set_tuple_shapes(self, tuple_shapes): 273 """Sets the shape of each element of the queue. 274 275 tuple_shapes must be a list of length 276 self.number_of_tuple_elements, and each element must be 277 convertible to a TensorShape. 278 279 Args: 280 tuple_shapes: the shapes of each queue element. 281 282 Raises: 283 ValueError: if tuple_shapes is not of length 284 self.number_of_tuple_elements. 285 TypeError: if an element of tuple_shapes cannot be converted to 286 a TensorShape. 287 """ 288 if len(tuple_shapes) != self.number_of_tuple_elements: 289 raise ValueError("tuple_shapes is %s, but must be a list of length %d" % 290 (str(tuple_shapes), self.number_of_tuple_elements)) 291 try: 292 tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] 293 except (ValueError, TypeError) as e: 294 raise TypeError( 295 "tuple_shapes is %s, but must be a list of elements each " 296 "convertible to TensorShape: got error %s" % (str(tuple_shapes), 297 str(e))) 298 if self._frozen: 299 for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): 300 if frozen != updated: 301 raise ValueError( 302 "Trying to update InfeedQueue with frozen configuration with an " 303 "incompatible shape. Frozen shapes are %s, updated shapes are %s" 304 % (str(self._tuple_shapes), str(tuple_shapes))) 305 else: 306 self._tuple_shapes = tuple_shapes 307 self._validate() 308 309 @property 310 def sharding_policies(self): 311 """Returns the sharding policies of the InfeedQueue tuple elements.""" 312 return self._sharding_policies 313 314 @property 315 def shard_dimensions(self): 316 """Gets the shard dimension of each tuple element. 317 318 Returns: 319 A list of length number_of_tuple_elements, where each list entry 320 is the shard dimension of that tuple element or None if the 321 shard dimension has not been set. 322 """ 323 # The number of shards is always the same for all the policies. 324 return [policy.shard_dimension for policy in self._sharding_policies] 325 326 def set_shard_dimensions(self, shard_dimensions): 327 """Sets the shard_dimension of each element of the queue. 328 329 shard_dimensions must be a list of length 330 self.number_of_tuple_elements, and each element must be 331 convertible to a Dimension compatible with self.tuple_shapes. 332 333 Args: 334 shard_dimensions: the dimensions of each queue element. 335 336 Raises: 337 ValueError: if shard_dimensions is not of length 338 self.number_of_tuple_elements; or an element of 339 shard_dimensions cannot be converted to a Dimension; or an 340 element of shard_dimensions is a Dimension that is out of 341 range for the corresponding tuple element shape. 342 """ 343 if len(shard_dimensions) != self.number_of_tuple_elements: 344 raise ValueError("shard_dimensions is %s, but must be a list of length %d" 345 % (str(shard_dimensions), 346 self.number_of_tuple_elements)) 347 for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): 348 policy.set_shard_dimension(dimension) 349 self._validate() 350 351 @property 352 def number_of_shards(self): 353 """Gets the number of shards to use for the InfeedQueue. 354 355 Returns: 356 Number of shards or None if the number of shards has not been set. 357 """ 358 # The number of shards is always the same for all the policies. 359 return self._sharding_policies[0].number_of_shards 360 361 def set_number_of_shards(self, number_of_shards): 362 """Sets the number of shards to use for the InfeedQueue. 363 364 Args: 365 number_of_shards: number of ways to shard the InfeedQueue. 366 367 Raises: 368 ValueError: if number_of_shards is not > 0; or the policies have 369 been frozen and number_of_shards was already set to something 370 else. 371 """ 372 for policy in self._sharding_policies: 373 policy.set_number_of_shards(number_of_shards) 374 policy.set_number_of_partitions(self._number_of_partitions) 375 self._validate() 376 377 def set_configuration_from_input_tensors(self, input_tensors): 378 """Sets the shapes and types of the queue tuple elements. 379 380 input_tensors is a list of Tensors whose types and shapes are used 381 to set the queue configuration. 382 383 Args: 384 input_tensors: list of Tensors of the same types and shapes as 385 the desired queue Tuple. 386 387 Raises: 388 ValueError: if input_tensors is not a list of length 389 self.number_of_tuple_elements 390 """ 391 if len(input_tensors) != self.number_of_tuple_elements: 392 raise ValueError("input_tensors is %s, but should be a list of %d Tensors" 393 % (str(input_tensors), self.number_of_tuple_elements)) 394 self.set_tuple_shapes([t.shape for t in input_tensors]) 395 self.set_tuple_types([t.dtype for t in input_tensors]) 396 397 def set_configuration_from_sharded_input_tensors(self, input_tensors): 398 """Sets the shapes and types of the queue tuple elements. 399 400 input_tensors is a list of lists of Tensors whose types and shapes are used 401 to set the queue configuration. The length of the outer list is the number 402 of shards required, and each inner list is the tuple of Tensors to use to 403 determine the types and shapes of the corresponding shard. This method 404 depends on the shard dimension, and calling it freezes the shard policy. 405 406 Args: 407 input_tensors: list of lists of Tensors. The outer list length corresponds 408 to the desired number of shards, and each inner list is the size 409 and shape of the desired configuration of the corresponding shard. 410 411 Raises: 412 ValueError: if any inner list is not a list of length 413 self.number_of_tuple_elements; or the inner lists do not combine to 414 form a consistent unsharded shape. 415 TypeError: if the types of the Tensors in the inner lists do not match. 416 """ 417 if not self._frozen: 418 # Unset the tuple shapes in case the configuration becomes 419 # transiently inconsistent. 420 self._tuple_shapes = None 421 number_of_shards = len(input_tensors) 422 self.set_number_of_shards(number_of_shards) 423 for t in input_tensors: 424 if len(t) != self.number_of_tuple_elements: 425 raise ValueError( 426 "input_tensors is %s but must be a list of lists, where each inner" 427 " list has length number_of_tuple_elements=%d" % ( 428 str(input_tensors), self.number_of_tuple_elements)) 429 # Transpose the inputs to make a list of shard shapes for each tuple 430 # element. 431 sharded_shapes = [[t[i].shape for t in input_tensors] 432 for i in xrange(self.number_of_tuple_elements)] 433 # For each tuple, get the unsharded shape using that tuple's policy. 434 unsharded_shapes = [ 435 policy.get_unsharded_shape(s) 436 for (policy, s) in zip(self._sharding_policies, sharded_shapes) 437 ] 438 self.set_tuple_shapes(unsharded_shapes) 439 for i in xrange(1, self.number_of_shards): 440 for (t1, t2) in zip(input_tensors[0], input_tensors[i]): 441 if t1.dtype != t2.dtype: 442 raise TypeError( 443 "types of the tuple elements of input_tensors %s are not " 444 "consistent" % str(input_tensors)) 445 self.set_tuple_types([t.dtype for t in input_tensors[0]]) 446 447 def freeze(self): 448 """Freezes the InfeedQueue so it can no longer be modified. 449 450 The configuration is implicitly frozen before any host-side or 451 device-side Ops are generated. The configuration cannot be frozen 452 until the types and shapes of the tuple elements have been set. 453 454 Raises: 455 ValueError: if the types or shapes of the tuple elements have not been 456 set. 457 """ 458 self._frozen = True 459 if self._tuple_types is None: 460 raise ValueError( 461 "Can't freeze an InfeedQueue without setting all tuple types.") 462 if self._tuple_shapes is None: 463 raise ValueError( 464 "Can't freeze an InfeedQueue without setting all tuple shapes.") 465 for shape in self._tuple_shapes: 466 if shape.dims is None: 467 raise ValueError( 468 "Can't freeze an InfeedQueue without setting all tuple shapes.") 469 for policy in self._sharding_policies: 470 policy.freeze() 471 self._validate() 472 473 def generate_dequeue_op(self, tpu_device=0): 474 """Generates the device-side Op to dequeue a tuple from the queue. 475 476 Implicitly freezes the queue configuration if it is not already 477 frozen, which will raise errors if the shapes and types have not 478 been fully specified. 479 480 Args: 481 tpu_device: The TPU device ordinal where the infeed instruction should be 482 placed. If None, no explicit placement will be performed, and it is up 483 to the user to call this API from within a proper TPU device scope. 484 The XLA code will fail if the TPU dequeue instruction is not bound to 485 any device. 486 487 Returns: 488 A list of Outputs corresponding to a shard of infeed dequeued 489 into XLA, suitable for use within a replicated block. 490 491 Raises: 492 ValueError: if the types or shapes of the tuple elements have not been 493 set; or if a dequeue op has already been generated. 494 """ 495 self.freeze() 496 if self._generated_dequeue_op: 497 raise ValueError("Can't generate two dequeue Ops from the same queue") 498 self._generated_dequeue_op = True 499 full_name = "%s/dequeue" % self._name 500 sharded_shapes = [ 501 policy.get_unpartitioned_shape(policy.get_sharded_shape(shape)) 502 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 503 ] 504 if tpu_device is not None: 505 with ops.device(tpu_name_util.core(tpu_device)): 506 dequeue_op = tpu_ops.infeed_dequeue_tuple( 507 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 508 else: 509 dequeue_op = tpu_ops.infeed_dequeue_tuple( 510 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 511 if self._number_of_partitions <= 1: 512 return dequeue_op 513 partitions = [ 514 policy.get_unpartitioned_shape([1] * shape.ndims).as_list() 515 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 516 ] 517 return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions) 518 519 def _generate_enqueue_op(self, 520 inputs, 521 name_prefix, 522 index, 523 device=None, 524 tpu_ordinal=-1): 525 """Generate a host-side Op to enqueue a tuple to the queue. 526 527 If device is None the inputs are all required to have the same 528 device specification, and the enqueue Op is colocated with 529 inputs[0]. Otherwise the enqueue Op is placed on 'device'. 530 531 Args: 532 inputs: a list of Tensors with the types and shapes of the tuple elements. 533 name_prefix: the base name for the Op. 534 index: the shard index, used to uniquify the Op name. 535 device: device to place the Op on, or None if it should be 536 colocated with the inputs. 537 tpu_ordinal: ordinal of the TPU device on the host to use for 538 infeed if device is a CPU device. Should be set to -1 if device 539 is a TPU device. 540 541 Returns: 542 An Op corresponding to a shard of infeed enqueued at the host, 543 suitable for use within a replicated block. 544 545 Raises: 546 ValueError: if device is None and inputs do not all have the 547 same device specification. 548 """ 549 full_name = "%s/%d" % (name_prefix, index) 550 shapes = [t.shape for t in inputs] 551 if device is None: 552 devices = [t.device for t in inputs] 553 for i in xrange(1, self.number_of_tuple_elements): 554 if devices[0] != devices[i]: 555 raise ValueError( 556 "input devices for shard %d are %s, but should all be the same" % 557 (index, str(devices))) 558 with ops.colocate_with(inputs[0]): 559 return tpu_ops.infeed_enqueue_tuple( 560 inputs=inputs, 561 shapes=shapes, 562 name=full_name, 563 device_ordinal=tpu_ordinal) 564 else: 565 with ops.device(device): 566 return tpu_ops.infeed_enqueue_tuple( 567 inputs=inputs, 568 shapes=shapes, 569 name=full_name, 570 device_ordinal=tpu_ordinal) 571 572 def generate_enqueue_ops(self, 573 sharded_inputs, 574 tpu_ordinal_function=None, 575 placement_function=None): 576 """Generates the host-side Ops to enqueue the shards of a tuple. 577 578 sharded_inputs is a list, one for each shard, of lists of 579 Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed 580 shard i of the queue. Returns the host-side Ops that must be run to 581 enqueue the sharded tuple. The Op for shard i is colocated with the inputs 582 for shard i. 583 584 Implicitly freezes the queue configuration if it is not already 585 frozen. If the configuration has already been frozen, and is not 586 compatible with the types and shapes of sharded_inputs, an error 587 will be raised. 588 589 Args: 590 sharded_inputs: a list of lists of Tensors. The length of the outer list 591 determines the number of shards. Each inner list indicates the types 592 and shapes of the tuples in the corresponding shard. 593 tpu_ordinal_function: if not None, a function that takes the 594 shard index as input and returns the ordinal of the TPU device 595 the shard's infeed should be placed on. tpu_ordinal_function must be 596 set if the inputs are placed on CPU devices. 597 placement_function: if not None, a function that takes the shard index as 598 input and returns the host device where the enqueue op should be placed 599 on. 600 601 Returns: 602 A list of host-side Ops, one for each shard, that when executed together 603 will enqueue a full-size element of infeed. 604 605 Raises: 606 ValueError: if the queue configuration has previously been frozen and the 607 shapes of the elements of sharded_inputs are not compatible with the 608 frozen configuration; or if the shapes of the elements of sharded_inputs 609 don't form a consistent unsharded tuple; or if the elements of a tuple 610 have different device constraints. 611 TypeError: if the queue configuration has previously been frozen and the 612 types of the elements of sharded_inputs are not compatible with the 613 frozen configuration; or if the types of the elements of sharded_inputs 614 don't form a consistent unsharded tuple. 615 """ 616 self.set_configuration_from_sharded_input_tensors(sharded_inputs) 617 self.freeze() 618 if self._generated_enqueue_ops: 619 raise ValueError("Can't generate two enqueue Ops from the same queue") 620 self._generated_enqueue_ops = True 621 if tpu_ordinal_function is None: 622 tpu_ordinal_function = lambda index: -1 623 name_prefix = "%s/enqueue" % self._name 624 return [ 625 self._generate_enqueue_op( 626 shard, 627 name_prefix, 628 index, 629 tpu_ordinal=tpu_ordinal_function(index), 630 device=placement_function(index) if placement_function else None) 631 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 632 ] 633 634 # TODO(misard) Generalize this to the case of systems that don't 635 # have 8 devices per host, and figure out what to do with 636 # model-parallelism. 637 def _default_placement_function(self, index): 638 return "/task:%d/device:CPU:0" % (index / 8) 639 640 def _default_ordinal_function(self, index): 641 return index % 8 642 643 # TODO(b/36470756) remove this from tutorials once we have a better story 644 # for automatic placement of input pipelines. 645 def split_inputs_and_generate_enqueue_ops(self, 646 inputs, 647 device_assignment=None, 648 placement_function=None, 649 tpu_ordinal_function=None): 650 """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. 651 652 Generates the host-side Ops to enqueue a tuple. 653 654 This method performs poorly because it takes an entire input on a single 655 host, splits it, and distributes it to all of the cores. It is present only 656 to simplify tutorial examples. 657 658 inputs is a list of Tensors to use to feed the queue. Each input is split 659 into self.number_of_shards shards. Returns an Op for each shard to enqueue 660 the shard. The Op for shard i is placed on device placement_function(i). 661 662 Implicitly freezes the queue configuration if it is not already 663 frozen. If the configuration has already been frozen, and is not 664 compatible with the types and shapes of inputs, an error 665 will be raised. 666 667 Args: 668 inputs: a list of Tensors which indicates the types and shapes of the 669 queue tuple. 670 device_assignment: if not `None`, a TPU `DeviceAssignment`. If 671 device_assignment is not `None`, but `placement_function` and 672 `ordinal_function` are None, then `device_assignment` will be used to 673 place infeeds on the first k TPU shards, where k is the number of shards 674 in the queue. If all three are `None`, then default placement and 675 ordinal functions are used. 676 placement_function: if not None, a function that takes the shard 677 index as input and returns a device string indicating which 678 device the shard's infeed should be placed on. If placement_function 679 and tpu_ordinal_function are None, inputs are sharded round-robin 680 across the devices in the system. 681 tpu_ordinal_function: if not None, a function that takes the 682 shard index as input and returns the ordinal of the TPU device 683 the shard's infeed should be placed on. If placement_function 684 and tpu_ordinal_function are None, inputs are sharded round-robin 685 across the devices in the system. 686 687 Returns: 688 A list of host-side Ops, one for each shard, that when executed together 689 will enqueue a full-size element of infeed. 690 691 Raises: 692 ValueError: if the queue configuration has previously been frozen and the 693 shapes of the elements of inputs are not compatible with the frozen 694 configuration. 695 TypeError: if the queue configuration has previously been frozen and the 696 types of the elements of inputs are not compatible with the frozen 697 configuration. 698 """ 699 if device_assignment is None: 700 if placement_function is None: 701 placement_function = self._default_placement_function 702 if tpu_ordinal_function is None: 703 tpu_ordinal_function = self._default_ordinal_function 704 else: 705 706 def _placement_function_from_map(index): 707 return device_assignment.host_device(replica=index) 708 709 def _ordinal_function_from_map(index): 710 return device_assignment.tpu_ordinal(replica=index) 711 712 if placement_function is None: 713 placement_function = _placement_function_from_map 714 if tpu_ordinal_function is None: 715 tpu_ordinal_function = _ordinal_function_from_map 716 self.set_configuration_from_input_tensors(inputs) 717 self.freeze() 718 if self._generated_enqueue_ops: 719 raise ValueError("Can't generate two enqueue Ops from the same queue") 720 self._generated_enqueue_ops = True 721 split_name_prefix = "%s/split" % self._name 722 if self.number_of_shards == 1: 723 transposed_sharded_inputs = [[inp] for inp in inputs] 724 else: 725 726 def split_fn(inp, num_shards, axis, name): 727 with ops.colocate_with(inp): 728 return array_ops.split(inp, num_shards, axis=axis, name=name) 729 730 transposed_sharded_inputs = [ 731 split_fn( 732 inp, 733 self.number_of_shards, 734 axis=policy.shard_dimension, 735 name="%s/%d" % (split_name_prefix, index)) 736 for (inp, policy, index) in zip(inputs, self._sharding_policies, 737 xrange(self.number_of_tuple_elements)) 738 ] 739 sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] 740 for i in xrange(self.number_of_shards)] 741 name_prefix = "%s/enqueue" % self._name 742 return [ 743 self._generate_enqueue_op( 744 shard, 745 name_prefix, 746 index, 747 device=placement_function(index), 748 tpu_ordinal=tpu_ordinal_function(index)) 749 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 750 ] 751 752 753class _PartitionedInfeedQueue(InfeedQueue): 754 """A helper object to build a device infeed queue with input partition. 755 756 Args: 757 number_of_tuple_elements: the number of Tensors fed atomically through the 758 queue, must be present unless it can be inferred from other arguments. 759 device_assignment: A TPU `DeviceAssignment` which is used to place all the 760 partitions to different TPU infeed queues. 761 host_id: The id of the host machine. 762 input_partition_dims: A nested list/tuple of integers. Each inner 763 list/tuple describes how to partition the corresponding input tensor. 764 tuple_types: If not None, a list of types of the elements of the queue. 765 tuple_shapes: If not None, a list of shapes of the elements of the queue. 766 name: The name of the queue. 767 """ 768 769 def __init__(self, 770 number_of_tuple_elements, 771 device_assignment, 772 host_id, 773 input_partition_dims=None, 774 tuple_types=None, 775 tuple_shapes=None, 776 name=None): 777 super(_PartitionedInfeedQueue, self).__init__( 778 number_of_tuple_elements=number_of_tuple_elements, 779 tuple_types=tuple_types, 780 tuple_shapes=None, 781 shard_dimensions=None, 782 name="PartitionedInfeedQueue" if name is None else name) 783 self._input_partition_dims = input_partition_dims 784 self._host_id = host_id 785 self._device_assignment = device_assignment 786 787 def generate_dequeue_op(self, tpu_device=0): 788 """Generate TPU dequeue ops. 789 790 Args: 791 tpu_device: The TPU device ordinal where the infeed instruction should be 792 placed. 793 794 Returns: 795 A list of Outputs corresponding to a partition of infeed dequeued 796 into XLA, suitable for use within a replicated block. 797 798 Raises: 799 ValueError: if the types or shapes of the tuple elements have not been 800 set; or if a dequeue op has already been generated. 801 """ 802 self.freeze() 803 if self._generated_dequeue_op: 804 raise ValueError("Can't generate two dequeue Ops from the same queue") 805 self._generated_dequeue_op = True 806 full_name = "%s/dequeue" % self._name 807 sharded_shapes = [ 808 policy.get_sharded_shape(shape) 809 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 810 ] 811 with ops.device(tpu_name_util.core(tpu_device)): 812 values = tpu_ops.infeed_dequeue_tuple( 813 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 814 return tag_sharding_attribute_for_dequeued_tensors( 815 values, self._input_partition_dims) 816 817 def generate_enqueue_ops(self, sharded_inputs): 818 """Generates the host-side Ops to enqueue the partitioned inputs. 819 820 sharded_inputs is a list, one for each replica, of lists of 821 Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed 822 replica i. 823 sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. 824 825 For example, if sharded_inputs[i][j] is a 2-D Tensor: 826 [[A, B, C, D], 827 [E ,F, G, H]] 828 self._input_partition_dims[j] is [2, 4]. 829 830 sharded_inputs[i][j] will be partitioned and flattened into: 831 [A, B, C, D, E, F, G, H] and fed into the logical core ids: 832 [0, 1, 2, 3, 4, 5, 6, 7] respectively. 833 834 Args: 835 sharded_inputs: a list of lists of Tensors. The length of the 836 outer list determines the number of shards. Each inner list indicates 837 the types and shapes of the tuples in the corresponding shard. 838 839 Returns: 840 A list of host-side Ops, one for each shard, that when executed together 841 will enqueue a full-size element of infeed. 842 843 Raises: 844 ValueError: if the queue configuration has previously been frozen and the 845 shapes of the elements of sharded_inputs are not compatible with the 846 frozen configuration; or if the shapes of the elements of sharded_inputs 847 don't form a consistent unsharded tuple; or if the elements of a tuple 848 have different device constraints; or if the partition dims are invalid. 849 TypeError: if the queue configuration has previously been frozen and the 850 types of the elements of sharded_inputs are not compatible with the 851 frozen configuration; or if the types of the elements of sharded_inputs 852 don't form a consistent unsharded tuple. 853 """ 854 self.set_configuration_from_sharded_input_tensors(sharded_inputs) 855 number_of_replicas = len(sharded_inputs) 856 number_of_tuple_elements = len(sharded_inputs[0]) 857 858 assert len(self._input_partition_dims) == number_of_tuple_elements 859 enqueue_ops = [] 860 861 for replica_index in range(number_of_replicas): 862 flattened_inputs = sharded_inputs[replica_index] 863 inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, 864 self._input_partition_dims) 865 inputs_parted_iters = [ 866 iter(self._check_dims_and_partition_or_replicate_on_host(x, dims)) 867 for x, dims in zip(sharded_inputs[replica_index], 868 inputs_part_dims_flat) 869 ] 870 871 # Find the replica_id of the host's logical core 0. 872 # The self._host_id is guaranteed to contain the logical core 0, 873 # even when num_cores_per_replica > num_cores_per_host -- the function 874 # caller makes sure that this host_id will must be receiving data (calls 875 # input_fn). 876 replica_id = self._device_assignment.lookup_replicas( 877 task_id=self._host_id, logical_core=0)[replica_index] 878 for logical_core in xrange(self._device_assignment.num_cores_per_replica): 879 # Places different partitions to different logic cores. 880 # Since there can be multiple hosts per replica, we need to find 881 # the actual host (device) of this logical core. 882 device = self._device_assignment.host_device( 883 replica=replica_id, logical_core=logical_core) 884 885 with ops.device(device): 886 ordinal = self._device_assignment.tpu_ordinal( 887 replica=replica_id, logical_core=logical_core) 888 infeed_inputs = [] 889 for it in inputs_parted_iters: 890 input_for_device = next(it, None) 891 if input_for_device is not None: 892 infeed_inputs.append(input_for_device) 893 894 if infeed_inputs: 895 enqueue_ops.append( 896 tpu_ops.infeed_enqueue_tuple( 897 inputs=infeed_inputs, 898 shapes=[x.shape for x in infeed_inputs], 899 name="enqueue/replica_{0}/input_{1}".format( 900 replica_index, logical_core), 901 device_ordinal=ordinal)) 902 return enqueue_ops 903 904 def _check_input_partition_dims(self, tensor, dims): 905 """Checks that input partition dims are valid for the `Tensor`. 906 907 Args: 908 tensor: Input tensor for partitioning. 909 dims: A list of integer describes how to partition the input tensor. 910 911 Raises: 912 ValueError: If the tensor can't be partitioned by dims or the 913 num_cores_per_replica doesn't match the number of 914 partitions(dims.prod()). 915 """ 916 # No partitioning specified, so don't perform further checks. 917 if dims is None: 918 return 919 920 dims = np.array(dims) 921 922 if (dims < 1).any(): 923 raise ValueError("All input partition dims must be >= 1.") 924 925 # No partitioning, so don't perform further checks. 926 if dims.prod() == 1: 927 return 928 929 if dims.prod() != self._device_assignment.num_cores_per_replica: 930 raise ValueError( 931 "The product of each input partition dim should equal to " 932 "num_cores_per_replica. (dim = {}, num_cores_per_replica " 933 "= {})".format(dims, self._device_assignment.num_cores_per_replica)) 934 if dims.shape[0] != tensor.shape.ndims: 935 raise ValueError( 936 "Input partition dims must have the same number of dimensions " 937 "as the `Tensor` to be partitioned. (tensor shape = {}, input " 938 "partition dims = {}).".format(tensor.shape.as_list(), dims)) 939 940 tensor.shape.assert_is_fully_defined() 941 942 def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims): 943 """Checks dims and partitions or replicates the input tensor. 944 945 The ops inside this function are placed on the host side. 946 947 Args: 948 tensor: The input tensor which will be partitioned or replicated. 949 dims: A list of integer describes how to partition the input tensor. 950 951 Returns: 952 An iterator of `Tensor`s or a list of partitioned tensors. 953 """ 954 self._check_input_partition_dims(tensor, dims) 955 return partition_or_replicate_on_host(tensor, dims) 956