# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Helper library for handling infeed between hosts and TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops

from tensorflow.python.util import nest


def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For a dimension that cannot be evenly partitioned, XLA assumes the
      # tensor is partitioned greedily: each partition takes ceil(size / dim)
      # elements until the remainder is exhausted. E.g. for a 2D tensor with
      # shape (5, 14) and dims (2, 4), since 5 % 2 = 1 and 14 % 4 = 2, the
      # (5, 14) tensor is partitioned into blocks of shapes
      # [[(3, 4), (3, 4), (3, 4), (3, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      # (see the illustrative sketch after this function).
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, int(dim), axis=axis) for x in output]
    output = nest.flatten(output)
  return output
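

# Illustrative sketch, not part of the original library: reproduces, with
# plain Python, the per-axis split sizes that partition_or_replicate_on_host
# computes for the uneven (5, 14) example above. The helper name
# `_example_uneven_split_sizes` is hypothetical.
def _example_uneven_split_sizes(shape=(5, 14), dims=(2, 4)):
  """Returns the per-axis `num_or_size_splits` for an uneven partition."""
  sizes_per_axis = []
  for size, dim in zip(shape, dims):
    quotient, remainder = divmod(size, dim)
    if dim <= 1:
      # Axis is not partitioned.
      sizes_per_axis.append([size])
    elif remainder == 0:
      # Even partition: every piece has the same size.
      sizes_per_axis.append([quotient] * dim)
    else:
      # Greedy partition: ceil(size / dim) per piece, the remainder in the
      # last non-empty piece, padded with zero-sized pieces if needed.
      ceil_ratio = quotient + 1
      num_full, left_over = divmod(size, ceil_ratio)
      splits = [ceil_ratio] * num_full + [left_over]
      splits += [0] * (dim - len(splits))
      sizes_per_axis.append(splits)
  return sizes_per_axis
# For the defaults this returns [[3, 2], [4, 4, 4, 2]], i.e. row blocks of
# heights 3 and 2 and column blocks of widths 4, 4, 4 and 2.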


def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  The sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
  else:
    # Row-major tile assignment over the logical cores (see the illustrative
    # sketch after tag_sharding_attribute_for_dequeued_tensors below).
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_assignment=tile_assignment,
        assign_tuple_sharding=True)


def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with the appropriate xla_sharding attribute.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
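

# Illustrative sketch, not part of the original library: shows the row-major
# tile assignment that _tag_sharding_attribute_for_dequeued_tensor builds for
# a tensor partitioned with dims = (2, 4). The helper name is hypothetical.
def _example_tile_assignment(dims=(2, 4)):
  """Returns the logical-core grid used for the XLA tile sharding."""
  # Logical cores 0..prod(dims)-1 laid out in row-major order, e.g. for
  # dims (2, 4):
  #   [[0, 1, 2, 3],
  #    [4, 5, 6, 7]]
  return np.arange(np.prod(dims)).reshape(dims)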


class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through
        the queue, must be present unless it can be inferred from other
        arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Return each policy's shard dimension; they need not all be the same.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
                       % (str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length
        corresponds to the desired number of shards, and each inner list
        contains Tensors whose types and shapes determine the configuration of
        the corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple, get the unsharded shape using that tuple's policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu.core(tpu_device)):
        return tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generates a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple
        elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same" %
              (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of
        sharded_inputs don't form a consistent unsharded tuple; or if the
        elements of a tuple have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
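

# Illustrative sketch, not part of the original library: a minimal example of
# configuring an InfeedQueue from per-shard tensors. The helper name and the
# concrete shapes are assumptions made for this sketch; nothing here talks to
# a real TPU, only host-side configuration is exercised.
def _example_configure_infeed_queue():
  """Configures a two-element InfeedQueue from two shards of inputs."""
  queue = InfeedQueue(number_of_tuple_elements=2)
  # Two shards, each carrying a (features, labels) tuple of identical shapes.
  sharded_inputs = [
      [array_ops.zeros([64, 224, 224, 3], dtype=dtypes.float32),
       array_ops.zeros([64], dtype=dtypes.int32)]
      for _ in range(2)
  ]
  queue.set_configuration_from_sharded_input_tensors(sharded_inputs)
  # Assuming the default shard dimension of 0, the inferred unsharded tuple
  # shapes are (128, 224, 224, 3) and (128,).
  return queue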


class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions on different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4], then sharded_inputs[i][j]
    will be partitioned and flattened into
    [A, B, C, D, E, F, G, H] and fed to the logical core ids
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of
        sharded_inputs don't form a consistent unsharded tuple; or if the
        elements of a tuple have different device constraints; or if the
        partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_replicas = len(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    enqueue_ops = []

    for replica_index in range(number_of_replicas):
      flattened_inputs = sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # self._host_id is guaranteed to contain logical core 0, even when
      # num_cores_per_replica > num_cores_per_host: the caller makes sure this
      # host_id must be receiving data (i.e. input_fn is called for it).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for it in inputs_parted_iters:
            input_for_device = next(it, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims, or if
        num_cores_per_replica doesn't match the number of partitions
        (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of the input partition dims should equal "
          "num_cores_per_replica. (dims = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)
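

# Illustrative sketch, not part of the original library: demonstrates the
# row-major feeding order described in the generate_enqueue_ops docstring
# above (the A..H example). The helper name and the local constant_op import
# are assumptions made for this sketch only.
def _example_partition_order():
  """Partitions a (2, 4) tensor into 8 pieces, one per logical core."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top

  tensor = constant_op.constant([[1., 2., 3., 4.],
                                 [5., 6., 7., 8.]])
  # With dims=[2, 4] every piece has shape (1, 1); piece i holds the value
  # that would be enqueued for logical core i, i.e. 1., 2., ..., 8. in order.
  return partition_or_replicate_on_host(tensor, dims=[2, 4])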