1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15 16"""Helper library for handling infeed between hosts and TPUs. 17""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import itertools 24 25import numpy as np 26from six.moves import xrange # pylint: disable=redefined-builtin 27 28from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.ops import array_ops 33from tensorflow.python.tpu import tpu 34from tensorflow.python.tpu import tpu_sharding 35from tensorflow.python.tpu.ops import tpu_ops 36 37from tensorflow.python.util import nest 38 39 40def partition_or_replicate_on_host(tensor, dims): 41 """Partitions or replicates the input tensor. 42 43 The ops inside this function are placed on the host side. 44 45 Args: 46 tensor: The input tensor which will be partioned or replicated. 47 dims: A list of integer describes how to partition the input tensor. 48 49 Returns: 50 An iterator of `Tensor`s or a list of partioned tensors. 51 """ 52 if dims is None: 53 return itertools.repeat(tensor) 54 dims = np.array(dims) 55 output = [tensor] 56 shape_list = np.array(tensor.shape.as_list()) 57 quotients, remainders = np.divmod(shape_list, dims) 58 for axis, (quotient, remainder, dim, original_size) in enumerate( 59 zip(quotients, remainders, dims, shape_list)): 60 if dim <= 1: 61 continue 62 if remainder > 0: 63 # For each dimension, when it cannot be evenly partitioned, XLA assumes 64 # tensors are partitioned in a greedy manner by using 65 # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims 66 # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => 67 # [[(3, 4), (3, 4), (2, 4), (2, 2)], 68 # [(2, 4), (2, 4), (2, 4), (2, 2)]] 69 ceil_ratio = quotient + 1 70 num_full_slots, left_over = np.divmod(original_size, ceil_ratio) 71 num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] 72 if len(num_or_size_splits) < dim: 73 num_or_size_splits += [0] * (dim - len(num_or_size_splits)) 74 new_output = [] 75 for x in output: 76 new_output.append( 77 array_ops.split( 78 x, num_or_size_splits=num_or_size_splits, axis=axis)) 79 output = new_output 80 else: 81 output = [array_ops.split(x, dim, axis=axis) for x in output] 82 output = nest.flatten(output) 83 return output 84 85 86def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims): 87 """Tags appropriate XLA sharding attribute to the dequeued tensor. 88 89 Args: 90 tensor: The dequeued tensor on TPU. 91 dims: A list of integer describes how the tensor is partitioned. 92 93 Returns: 94 The same tensor with the xla_sharding attribute. 95 """ 96 if dims is None: 97 return xla_sharding.replicate(tensor) 98 elif np.prod(dims) == 1: 99 return xla_sharding.assign_device(tensor, 0) 100 else: 101 tile_assignment = np.arange(np.prod(dims)).reshape(dims) 102 return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment) 103 104 105def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims): 106 """Tags appropriate XLA sharding attribute to the dequeued tensors. 107 108 Args: 109 dequeues: A list of dequeued tensors on TPU. 110 dims: A list of integer describes how the tensor is partitioned. 111 112 Returns: 113 The same dequeues with appropriate xla_sharding attribute. 114 """ 115 nest.assert_shallow_structure(dequeues, dims) 116 return nest.map_structure_up_to( 117 dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims) 118 119 120class InfeedQueue(object): 121 """A helper object to build a device infeed queue. 122 123 The InfeedQueue builds the host-side and device-side Ops to enqueue and 124 dequeue elements, respectively, and ensures that their types and 125 shapes match. 126 """ 127 128 def __init__(self, 129 number_of_tuple_elements=None, 130 tuple_types=None, 131 tuple_shapes=None, 132 shard_dimensions=None, 133 name=None): 134 """Creates a new InfeedQueue with the given configuration. 135 136 The configuration need not be fully specified at creation since it 137 can be modified subsequently by methods that set the values 138 explicitly or infer them from the shapes of inputs. 139 140 Args: 141 number_of_tuple_elements: the number of Tensors fed atomically through the 142 queue, must be present unless it can be inferred from other arguments. 143 tuple_types: if not None, a list of types of the elements of the queue. 144 tuple_shapes: if not None, a list of shapes of the elements of the queue. 145 shard_dimensions: if not None, a list of dimensions on which the 146 elements of the queue should be sharded during automatic 147 parallelization. 148 name: the name of the queue. 149 150 Raises: 151 ValueError: if number_of_tuple_elements <= 0; or 152 number_of_tuple_arguments, tuple_types, tuple_shapes, and 153 shard_dimensions are all None; or the length of tuple_types, 154 tuple_shapes, or shard_dimensions is not equal to 155 number_of_tuple_elements; or any element of shard_dimensions 156 can't be converted to a Dimension. 157 TypeError: if any element of tuple_types or tuple_shapes can't 158 be converted to a dtype or TensorShape, respectively. 159 """ 160 self._frozen = False 161 self._generated_enqueue_ops = False 162 self._generated_dequeue_op = False 163 self._name = "InfeedQueue" if name is None else name 164 if number_of_tuple_elements is None: 165 if tuple_types is not None: 166 number_of_tuple_elements = len(tuple_types) 167 elif tuple_shapes is not None: 168 number_of_tuple_elements = len(tuple_shapes) 169 elif shard_dimensions is not None: 170 number_of_tuple_elements = len(shard_dimensions) 171 else: 172 raise ValueError( 173 "number of tuple elements cannot be inferred from InfeedQueue " 174 "constructor") 175 if number_of_tuple_elements <= 0: 176 raise ValueError("number_of_tuple_elements %d must be > 0" % 177 number_of_tuple_elements) 178 # Make an empty sharding policy for each tuple element. 179 self._sharding_policies = [ 180 tpu_sharding.ShardingPolicy() 181 for _ in xrange(number_of_tuple_elements) 182 ] 183 if tuple_types is not None: 184 self.set_tuple_types(tuple_types) 185 else: 186 self._tuple_types = None 187 if tuple_shapes is not None: 188 self.set_tuple_shapes(tuple_shapes) 189 else: 190 self._tuple_shapes = None 191 if shard_dimensions is not None: 192 self.set_shard_dimensions(shard_dimensions) 193 self._validate() 194 195 def _validate(self): 196 """Checks that the configuration is self-consistent. 197 198 Raises: 199 ValueError: if the shapes and sharding policies don't match. 200 """ 201 if self.tuple_shapes is not None: 202 for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): 203 # Raise an error if the policy is incompatible with the shape. 204 _ = policy.get_sharded_shape(shape) 205 206 @property 207 def number_of_tuple_elements(self): 208 """Returns the number of InfeedQueue tuple elements.""" 209 return len(self._sharding_policies) 210 211 @property 212 def tuple_types(self): 213 """Returns the types of the InfeedQueue tuple elements.""" 214 return self._tuple_types 215 216 def set_tuple_types(self, tuple_types): 217 """Sets the type of each element of the queue. 218 219 tuple_types must be a list of length 220 self.number_of_tuple_elements, and each element must be 221 convertible to a dtype. 222 223 Args: 224 tuple_types: the types of each queue element. 225 226 Raises: 227 ValueError: if tuple_types is not of length 228 self.number_of_tuple_elements. 229 TypeError: if an element of tuple_types cannot be converted to a 230 dtype. 231 """ 232 if len(tuple_types) != self.number_of_tuple_elements: 233 raise ValueError("tuple_types is %s, but must be a list of length %d" % 234 (str(tuple_types), self.number_of_tuple_elements)) 235 if self._frozen: 236 for (frozen, updated) in zip(self._tuple_types, tuple_types): 237 if frozen != updated: 238 raise ValueError( 239 "Trying to update InfeedQueue with frozen configuration with an " 240 "incompatible type. Frozen types are %s, updated types are %s" % ( 241 str(self._tuple_types), str(tuple_types))) 242 else: 243 try: 244 self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] 245 except (TypeError) as e: 246 raise TypeError( 247 "tuple_types is %s, but must be a list of elements each " 248 "convertible to dtype: got error %s" % (str(tuple_types), str(e))) 249 250 @property 251 def tuple_shapes(self): 252 """Returns the shapes of the InfeedQueue tuple elements.""" 253 return self._tuple_shapes 254 255 def set_tuple_shapes(self, tuple_shapes): 256 """Sets the shape of each element of the queue. 257 258 tuple_shapes must be a list of length 259 self.number_of_tuple_elements, and each element must be 260 convertible to a TensorShape. 261 262 Args: 263 tuple_shapes: the shapes of each queue element. 264 265 Raises: 266 ValueError: if tuple_shapes is not of length 267 self.number_of_tuple_elements. 268 TypeError: if an element of tuple_shapes cannot be converted to 269 a TensorShape. 270 """ 271 if len(tuple_shapes) != self.number_of_tuple_elements: 272 raise ValueError("tuple_shapes is %s, but must be a list of length %d" % 273 (str(tuple_shapes), self.number_of_tuple_elements)) 274 try: 275 tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] 276 except (ValueError, TypeError) as e: 277 raise TypeError( 278 "tuple_shapes is %s, but must be a list of elements each " 279 "convertible to TensorShape: got error %s" % (str(tuple_shapes), 280 str(e))) 281 if self._frozen: 282 for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): 283 if frozen != updated: 284 raise ValueError( 285 "Trying to update InfeedQueue with frozen configuration with an " 286 "incompatible shape. Frozen shapes are %s, updated shapes are %s" 287 % (str(self._tuple_shapes), str(tuple_shapes))) 288 else: 289 self._tuple_shapes = tuple_shapes 290 self._validate() 291 292 @property 293 def sharding_policies(self): 294 """Returns the sharding policies of the InfeedQueue tuple elements.""" 295 return self._sharding_policies 296 297 @property 298 def shard_dimensions(self): 299 """Gets the shard dimension of each tuple element. 300 301 Returns: 302 A list of length number_of_tuple_elements, where each list entry 303 is the shard dimension of that tuple element or None if the 304 shard dimension has not been set. 305 """ 306 # The number of shards is always the same for all the policies. 307 return [policy.shard_dimension for policy in self._sharding_policies] 308 309 def set_shard_dimensions(self, shard_dimensions): 310 """Sets the shard_dimension of each element of the queue. 311 312 shard_dimensions must be a list of length 313 self.number_of_tuple_elements, and each element must be 314 convertible to a Dimension compatible with self.tuple_shapes. 315 316 Args: 317 shard_dimensions: the dimensions of each queue element. 318 319 Raises: 320 ValueError: if shard_dimensions is not of length 321 self.number_of_tuple_elements; or an element of 322 shard_dimensions cannot be converted to a Dimension; or an 323 element of shard_dimensions is a Dimension that is out of 324 range for the corresponding tuple element shape. 325 """ 326 if len(shard_dimensions) != self.number_of_tuple_elements: 327 raise ValueError("shard_dimensions is %s, but must be a list of length %d" 328 % (str(shard_dimensions), 329 self.number_of_tuple_elements)) 330 for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): 331 policy.set_shard_dimension(dimension) 332 self._validate() 333 334 @property 335 def number_of_shards(self): 336 """Gets the number of shards to use for the InfeedQueue. 337 338 Returns: 339 Number of shards or None if the number of shards has not been set. 340 """ 341 # The number of shards is always the same for all the policies. 342 return self._sharding_policies[0].number_of_shards 343 344 def set_number_of_shards(self, number_of_shards): 345 """Sets the number of shards to use for the InfeedQueue. 346 347 Args: 348 number_of_shards: number of ways to shard the InfeedQueue. 349 350 Raises: 351 ValueError: if number_of_shards is not > 0; or the policies have 352 been frozen and number_of_shards was already set to something 353 else. 354 """ 355 for policy in self._sharding_policies: 356 policy.set_number_of_shards(number_of_shards) 357 self._validate() 358 359 def set_configuration_from_input_tensors(self, input_tensors): 360 """Sets the shapes and types of the queue tuple elements. 361 362 input_tensors is a list of Tensors whose types and shapes are used 363 to set the queue configuration. 364 365 Args: 366 input_tensors: list of Tensors of the same types and shapes as 367 the desired queue Tuple. 368 369 Raises: 370 ValueError: if input_tensors is not a list of length 371 self.number_of_tuple_elements 372 """ 373 if len(input_tensors) != self.number_of_tuple_elements: 374 raise ValueError("input_tensors is %s, but should be a list of %d Tensors" 375 % (str(input_tensors), self.number_of_tuple_elements)) 376 self.set_tuple_shapes([t.shape for t in input_tensors]) 377 self.set_tuple_types([t.dtype for t in input_tensors]) 378 379 def set_configuration_from_sharded_input_tensors(self, input_tensors): 380 """Sets the shapes and types of the queue tuple elements. 381 382 input_tensors is a list of lists of Tensors whose types and shapes are used 383 to set the queue configuration. The length of the outer list is the number 384 of shards required, and each inner list is the tuple of Tensors to use to 385 determine the types and shapes of the corresponding shard. This method 386 depends on the shard dimension, and calling it freezes the shard policy. 387 388 Args: 389 input_tensors: list of lists of Tensors. The outer list length corresponds 390 to the desired number of shards, and each inner list is the size 391 and shape of the desired configuration of the corresponding shard. 392 393 Raises: 394 ValueError: if any inner list is not a list of length 395 self.number_of_tuple_elements; or the inner lists do not combine to 396 form a consistent unsharded shape. 397 TypeError: if the types of the Tensors in the inner lists do not match. 398 """ 399 if not self._frozen: 400 # Unset the tuple shapes in case the configuration becomes 401 # transiently inconsistent. 402 self._tuple_shapes = None 403 number_of_shards = len(input_tensors) 404 self.set_number_of_shards(number_of_shards) 405 for t in input_tensors: 406 if len(t) != self.number_of_tuple_elements: 407 raise ValueError( 408 "input_tensors is %s but must be a list of lists, where each inner" 409 " list has length number_of_tuple_elements=%d" % ( 410 str(input_tensors), self.number_of_tuple_elements)) 411 # Transpose the inputs to make a list of shard shapes for each tuple 412 # element. 413 sharded_shapes = [[t[i].shape for t in input_tensors] 414 for i in xrange(self.number_of_tuple_elements)] 415 # For each tuple, get the unsharded shape using that tuple's policy. 416 unsharded_shapes = [ 417 policy.get_unsharded_shape(s) 418 for (policy, s) in zip(self._sharding_policies, sharded_shapes) 419 ] 420 self.set_tuple_shapes(unsharded_shapes) 421 for i in xrange(1, self.number_of_shards): 422 for (t1, t2) in zip(input_tensors[0], input_tensors[i]): 423 if t1.dtype != t2.dtype: 424 raise TypeError( 425 "types of the tuple elements of input_tensors %s are not " 426 "consistent" % str(input_tensors)) 427 self.set_tuple_types([t.dtype for t in input_tensors[0]]) 428 429 def freeze(self): 430 """Freezes the InfeedQueue so it can no longer be modified. 431 432 The configuration is implicitly frozen before any host-side or 433 device-side Ops are generated. The configuration cannot be frozen 434 until the types and shapes of the tuple elements have been set. 435 436 Raises: 437 ValueError: if the types or shapes of the tuple elements have not been 438 set. 439 """ 440 self._frozen = True 441 if self._tuple_types is None: 442 raise ValueError( 443 "Can't freeze an InfeedQueue without setting all tuple types.") 444 if self._tuple_shapes is None: 445 raise ValueError( 446 "Can't freeze an InfeedQueue without setting all tuple shapes.") 447 for shape in self._tuple_shapes: 448 if shape.dims is None: 449 raise ValueError( 450 "Can't freeze an InfeedQueue without setting all tuple shapes.") 451 for policy in self._sharding_policies: 452 policy.freeze() 453 self._validate() 454 455 def generate_dequeue_op(self, tpu_device=0): 456 """Generates the device-side Op to dequeue a tuple from the queue. 457 458 Implicitly freezes the queue configuration if it is not already 459 frozen, which will raise errors if the shapes and types have not 460 been fully specified. 461 462 Args: 463 tpu_device: The TPU device ordinal where the infeed instruction should be 464 placed. If None, no explicit placement will be performed, and it is up 465 to the user to call this API from within a proper TPU device scope. 466 The XLA code will fail if the TPU dequeue instruction is not bound to 467 any device. 468 469 Returns: 470 A list of Outputs corresponding to a shard of infeed dequeued 471 into XLA, suitable for use within a replicated block. 472 473 Raises: 474 ValueError: if the types or shapes of the tuple elements have not been 475 set; or if a dequeue op has already been generated. 476 """ 477 self.freeze() 478 if self._generated_dequeue_op: 479 raise ValueError("Can't generate two dequeue Ops from the same queue") 480 self._generated_dequeue_op = True 481 full_name = "%s/dequeue" % self._name 482 sharded_shapes = [ 483 policy.get_sharded_shape(shape) 484 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 485 ] 486 if tpu_device is not None: 487 with ops.device(tpu.core(tpu_device)): 488 return tpu_ops.infeed_dequeue_tuple( 489 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 490 else: 491 return tpu_ops.infeed_dequeue_tuple( 492 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 493 494 def _generate_enqueue_op(self, 495 inputs, 496 name_prefix, 497 index, 498 device=None, 499 tpu_ordinal=-1): 500 """Generate a host-side Op to enqueue a tuple to the queue. 501 502 If device is None the inputs are all required to have the same 503 device specification, and the enqueue Op is colocated with 504 inputs[0]. Otherwise the enqueue Op is placed on 'device'. 505 506 Args: 507 inputs: a list of Tensors with the types and shapes of the tuple elements. 508 name_prefix: the base name for the Op. 509 index: the shard index, used to uniquify the Op name. 510 device: device to place the Op on, or None if it should be 511 colocated with the inputs. 512 tpu_ordinal: ordinal of the TPU device on the host to use for 513 infeed if device is a CPU device. Should be set to -1 if device 514 is a TPU device. 515 516 Returns: 517 An Op corresponding to a shard of infeed enqueued at the host, 518 suitable for use within a replicated block. 519 520 Raises: 521 ValueError: if device is None and inputs do not all have the 522 same device specification. 523 """ 524 full_name = "%s/%d" % (name_prefix, index) 525 shapes = [t.shape for t in inputs] 526 if device is None: 527 devices = [t.device for t in inputs] 528 for i in xrange(1, self.number_of_tuple_elements): 529 if devices[0] != devices[i]: 530 raise ValueError( 531 "input devices for shard %d are %s, but should all be the same" % 532 (index, str(devices))) 533 with ops.colocate_with(inputs[0]): 534 return tpu_ops.infeed_enqueue_tuple( 535 inputs=inputs, 536 shapes=shapes, 537 name=full_name, 538 device_ordinal=tpu_ordinal) 539 else: 540 with ops.device(device): 541 return tpu_ops.infeed_enqueue_tuple( 542 inputs=inputs, 543 shapes=shapes, 544 name=full_name, 545 device_ordinal=tpu_ordinal) 546 547 def generate_enqueue_ops(self, 548 sharded_inputs, 549 tpu_ordinal_function=None, 550 placement_function=None): 551 """Generates the host-side Ops to enqueue the shards of a tuple. 552 553 sharded_inputs is a list, one for each shard, of lists of 554 Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed 555 shard 0 if the queue. Returns the host-side Ops that must be run to 556 enqueue the sharded tuple. The Op for shard i is colocated with the inputs 557 for shard i. 558 559 Implicitly freezes the queue configuration if it is not already 560 frozen. If the configuration has already been frozen, and is not 561 compatible with the types and shapes of sharded_inputs, an error 562 will be raised. 563 564 Args: 565 sharded_inputs: a list of lists of Tensors. The length of the outer list 566 determines the number of shards. Each inner list indicates the types 567 and shapes of the tuples in the corresponding shard. 568 tpu_ordinal_function: if not None, a function that takes the 569 shard index as input and returns the ordinal of the TPU device 570 the shard's infeed should be placed on. tpu_ordinal_function must be 571 set if the inputs are placed on CPU devices. 572 placement_function: if not None, a function that takes the shard index as 573 input and returns the host device where the enqueue op should be placed 574 on. 575 576 Returns: 577 A list of host-side Ops, one for each shard, that when executed together 578 will enqueue a full-size element of infeed. 579 580 Raises: 581 ValueError: if the queue configuration has previously been frozen and the 582 shapes of the elements of sharded_inputs are not compatible with the 583 frozen configuration; or if the shapes of the elements of sharded_inputs 584 don't form a consistent unsharded tuple; or if the elements of a tuple 585 have different device constraints. 586 TypeError: if the queue configuration has previously been frozen and the 587 types of the elements of sharded_inputs are not compatible with the 588 frozen configuration; or if the types of the elements of sharded_inputs 589 don't form a consistent unsharded tuple. 590 """ 591 self.set_configuration_from_sharded_input_tensors(sharded_inputs) 592 self.freeze() 593 if self._generated_enqueue_ops: 594 raise ValueError("Can't generate two enqueue Ops from the same queue") 595 self._generated_enqueue_ops = True 596 if tpu_ordinal_function is None: 597 tpu_ordinal_function = lambda index: -1 598 name_prefix = "%s/enqueue" % self._name 599 return [ 600 self._generate_enqueue_op( 601 shard, 602 name_prefix, 603 index, 604 tpu_ordinal=tpu_ordinal_function(index), 605 device=placement_function(index) if placement_function else None) 606 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 607 ] 608 609 # TODO(misard) Generalize this to the case of systems that don't 610 # have 8 devices per host, and figure out what to do with 611 # model-parallelism. 612 def _default_placement_function(self, index): 613 return "/task:%d/device:CPU:0" % (index / 8) 614 615 def _default_ordinal_function(self, index): 616 return index % 8 617 618 # TODO(b/36470756) remove this from tutorials once we have a better story 619 # for automatic placement of input pipelines. 620 def split_inputs_and_generate_enqueue_ops(self, 621 inputs, 622 device_assignment=None, 623 placement_function=None, 624 tpu_ordinal_function=None): 625 """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. 626 627 Generates the host-side Ops to enqueue a tuple. 628 629 This method performs poorly because it takes an entire input on a single 630 host, splits it, and distributes it to all of the cores. It is present only 631 to simplify tutorial examples. 632 633 inputs is a list of Tensors to use to feed the queue. Each input is split 634 into self.number_of_shards shards. Returns an Op for each shard to enqueue 635 the shard. The Op for shard i is placed on device placement_function(i). 636 637 Implicitly freezes the queue configuration if it is not already 638 frozen. If the configuration has already been frozen, and is not 639 compatible with the types and shapes of inputs, an error 640 will be raised. 641 642 Args: 643 inputs: a list of Tensors which indicates the types and shapes of the 644 queue tuple. 645 device_assignment: if not `None`, a TPU `DeviceAssignment`. If 646 device_assignment is not `None`, but `placement_function` and 647 `ordinal_function` are None, then `device_assignment` will be used to 648 place infeeds on the first k TPU shards, where k is the number of shards 649 in the queue. If all three are `None`, then default placement and 650 ordinal functions are used. 651 placement_function: if not None, a function that takes the shard 652 index as input and returns a device string indicating which 653 device the shard's infeed should be placed on. If placement_function 654 and tpu_ordinal_function are None, inputs are sharded round-robin 655 across the devices in the system. 656 tpu_ordinal_function: if not None, a function that takes the 657 shard index as input and returns the ordinal of the TPU device 658 the shard's infeed should be placed on. If placement_function 659 and tpu_ordinal_function are None, inputs are sharded round-robin 660 across the devices in the system. 661 662 Returns: 663 A list of host-side Ops, one for each shard, that when executed together 664 will enqueue a full-size element of infeed. 665 666 Raises: 667 ValueError: if the queue configuration has previously been frozen and the 668 shapes of the elements of inputs are not compatible with the frozen 669 configuration. 670 TypeError: if the queue configuration has previously been frozen and the 671 types of the elements of inputs are not compatible with the frozen 672 configuration. 673 """ 674 if device_assignment is None: 675 if placement_function is None: 676 placement_function = self._default_placement_function 677 if tpu_ordinal_function is None: 678 tpu_ordinal_function = self._default_ordinal_function 679 else: 680 681 def _placement_function_from_map(index): 682 return device_assignment.host_device(replica=index) 683 684 def _ordinal_function_from_map(index): 685 return device_assignment.tpu_ordinal(replica=index) 686 687 if placement_function is None: 688 placement_function = _placement_function_from_map 689 if tpu_ordinal_function is None: 690 tpu_ordinal_function = _ordinal_function_from_map 691 self.set_configuration_from_input_tensors(inputs) 692 self.freeze() 693 if self._generated_enqueue_ops: 694 raise ValueError("Can't generate two enqueue Ops from the same queue") 695 self._generated_enqueue_ops = True 696 split_name_prefix = "%s/split" % self._name 697 if self.number_of_shards == 1: 698 transposed_sharded_inputs = [[inp] for inp in inputs] 699 else: 700 701 def split_fn(inp, num_shards, axis, name): 702 with ops.colocate_with(inp): 703 return array_ops.split(inp, num_shards, axis=axis, name=name) 704 705 transposed_sharded_inputs = [ 706 split_fn( 707 inp, 708 self.number_of_shards, 709 axis=policy.shard_dimension, 710 name="%s/%d" % (split_name_prefix, index)) 711 for (inp, policy, index) in zip(inputs, self._sharding_policies, 712 xrange(self.number_of_tuple_elements)) 713 ] 714 sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] 715 for i in xrange(self.number_of_shards)] 716 name_prefix = "%s/enqueue" % self._name 717 return [ 718 self._generate_enqueue_op( 719 shard, 720 name_prefix, 721 index, 722 device=placement_function(index), 723 tpu_ordinal=tpu_ordinal_function(index)) 724 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 725 ] 726 727 728class _PartitionedInfeedQueue(InfeedQueue): 729 """A helper object to build a device infeed queue with input partition. 730 731 Args: 732 number_of_tuple_elements: the number of Tensors fed atomically through the 733 queue, must be present unless it can be inferred from other arguments. 734 device_assignment: A TPU `DeviceAssignment` which is used to place all the 735 partitions to different TPU infeed queues. 736 host_id: The id of the host machine. 737 input_partition_dims: A nested list/tuple of integers. Each inner 738 list/tuple describes how to partition the corresponding input tensor. 739 tuple_types: If not None, a list of types of the elements of the queue. 740 tuple_shapes: If not None, a list of shapes of the elements of the queue. 741 name: The name of the queue. 742 """ 743 744 def __init__(self, 745 number_of_tuple_elements, 746 device_assignment, 747 host_id, 748 input_partition_dims=None, 749 tuple_types=None, 750 tuple_shapes=None, 751 name=None): 752 super(_PartitionedInfeedQueue, self).__init__( 753 number_of_tuple_elements=number_of_tuple_elements, 754 tuple_types=tuple_types, 755 tuple_shapes=None, 756 shard_dimensions=None, 757 name="PartitionedInfeedQueue" if name is None else name) 758 self._input_partition_dims = input_partition_dims 759 self._host_id = host_id 760 self._device_assignment = device_assignment 761 762 def generate_dequeue_op(self, tpu_device=0): 763 """Generate TPU dequeue ops. 764 765 Args: 766 tpu_device: The TPU device ordinal where the infeed instruction should be 767 placed. 768 769 Returns: 770 A list of Outputs corresponding to a partition of infeed dequeued 771 into XLA, suitable for use within a replicated block. 772 773 Raises: 774 ValueError: if the types or shapes of the tuple elements have not been 775 set; or if a dequeue op has already been generated. 776 """ 777 self.freeze() 778 if self._generated_dequeue_op: 779 raise ValueError("Can't generate two dequeue Ops from the same queue") 780 self._generated_dequeue_op = True 781 full_name = "%s/dequeue" % self._name 782 sharded_shapes = [ 783 policy.get_sharded_shape(shape) 784 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 785 ] 786 with ops.device(tpu.core(tpu_device)): 787 values = tpu_ops.infeed_dequeue_tuple( 788 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 789 return tag_sharding_attribute_for_dequeued_tensors( 790 values, self._input_partition_dims) 791 792 def generate_enqueue_ops(self, per_host_sharded_inputs): 793 """Generates the host-side Ops to enqueue the partitioned inputs. 794 795 per_host_sharded_inputs is a list, one for each replica, of lists of 796 Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed 797 replica i. 798 sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. 799 800 For example, if sharded_inputs[i][j] is a 2-D Tensor: 801 [[A, B, C, D], 802 [E ,F, G, H]] 803 self._input_partition_dims[j] is [2, 4]. 804 805 sharded_inputs[i][j] will be partitioned and flattened into: 806 [A, B, C, D, E, F, G, H] and fed into the logical core ids: 807 [0, 1, 2, 3, 4, 5, 6, 7] respectively. 808 809 Args: 810 per_host_sharded_inputs: a list of lists of Tensors. The length of the 811 outer list determines the number of shards. Each inner list indicates 812 the types and shapes of the tuples in the corresponding shard. 813 814 Returns: 815 A list of host-side Ops, one for each shard, that when executed together 816 will enqueue a full-size element of infeed. 817 818 Raises: 819 ValueError: if the queue configuration has previously been frozen and the 820 shapes of the elements of sharded_inputs are not compatible with the 821 frozen configuration; or if the shapes of the elements of sharded_inputs 822 don't form a consistent unsharded tuple; or if the elements of a tuple 823 have different device constraints; or if the partition dims are invalid. 824 TypeError: if the queue configuration has previously been frozen and the 825 types of the elements of sharded_inputs are not compatible with the 826 frozen configuration; or if the types of the elements of sharded_inputs 827 don't form a consistent unsharded tuple. 828 """ 829 self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) 830 number_of_replicas_per_host = len(per_host_sharded_inputs) 831 number_of_tuple_elements = len(per_host_sharded_inputs[0]) 832 833 assert len(self._input_partition_dims) == number_of_tuple_elements 834 per_host_enqueue_ops = [] 835 836 for replica_index in range(number_of_replicas_per_host): 837 flattened_inputs = per_host_sharded_inputs[replica_index] 838 inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, 839 self._input_partition_dims) 840 inputs_parted_iters = [ 841 iter(self._check_dims_and_partition_or_replicate_on_host(x, dims)) 842 for x, dims in zip(per_host_sharded_inputs[replica_index], 843 inputs_part_dims_flat) 844 ] 845 846 for logical_core in xrange(self._device_assignment.num_cores_per_replica): 847 # Places different partitions to different logic cores. 848 replica_id = self._device_assignment.lookup_replicas( 849 self._host_id, logical_core)[replica_index] 850 ordinal = self._device_assignment.tpu_ordinal( 851 replica=replica_id, logical_core=logical_core) 852 infeed_inputs = [] 853 for it in inputs_parted_iters: 854 input_for_device = next(it, None) 855 if input_for_device is not None: 856 infeed_inputs.append(input_for_device) 857 858 if infeed_inputs: 859 per_host_enqueue_ops.append( 860 tpu_ops.infeed_enqueue_tuple( 861 inputs=infeed_inputs, 862 shapes=[x.shape for x in infeed_inputs], 863 name="enqueue/replica_{0}/input_{1}".format( 864 replica_index, logical_core), 865 device_ordinal=ordinal)) 866 return per_host_enqueue_ops 867 868 def _check_input_partition_dims(self, tensor, dims): 869 """Checks that input partition dims are valid for the `Tensor`. 870 871 Args: 872 tensor: Input tensor for partitioning. 873 dims: A list of integer describes how to partition the input tensor. 874 875 Raises: 876 ValueError: If the tensor can't be partitioned by dims or the 877 num_cores_per_replica doesn't match the number of 878 partitions(dims.prod()). 879 """ 880 # No partitioning specified, so don't perform further checks. 881 if dims is None: 882 return 883 884 dims = np.array(dims) 885 886 if (dims < 1).any(): 887 raise ValueError("All input partition dims must be >= 1.") 888 889 # No partitioning, so don't perform further checks. 890 if dims.prod() == 1: 891 return 892 893 if dims.prod() != self._device_assignment.num_cores_per_replica: 894 raise ValueError( 895 "The product of each input parition dim should equal to " 896 "num_cores_per_replica. (dim = {}, num_cores_per_replica " 897 "= {})".format(dims, self._device_assignment.num_cores_per_replica)) 898 if dims.shape[0] != tensor.shape.ndims: 899 raise ValueError( 900 "Input partition dims must have the same number of dimensions " 901 "as the `Tensor` to be partitioned. (tensor shape = {}, input " 902 "partition dims = {}).".format(tensor.shape.as_list(), dims)) 903 904 tensor.shape.assert_is_fully_defined() 905 906 def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims): 907 """Checks dims and partitions or replicates the input tensor. 908 909 The ops inside this function are placed on the host side. 910 911 Args: 912 tensor: The input tensor which will be partioned or replicated. 913 dims: A list of integer describes how to partition the input tensor. 914 915 Returns: 916 An iterator of `Tensor`s or a list of partioned tensors. 917 """ 918 self._check_input_partition_dims(tensor, dims) 919 return partition_or_replicate_on_host(tensor, dims) 920