# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Helper library for handling infeed between hosts and TPUs.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops

from tensorflow.python.util import nest


def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For each dimension, when it cannot be evenly partitioned, XLA assumes
      # tensors are partitioned in a greedy manner by using
      # ceil_ratio(size/dim) first. E.g. a 2D tensor with shape (5, 14) and
      # dims (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
      # [[(3, 4), (3, 4), (3, 4), (3, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, int(dim), axis=axis) for x in output]
    output = nest.flatten(output)
  return output
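
# A minimal usage sketch for partition_or_replicate_on_host (not executed
# here), reusing the (5, 14) example from the comment above; `x` is a
# hypothetical host-side tensor and a TF1-style graph is assumed.
#
#   x = array_ops.zeros([5, 14])
#   parts = partition_or_replicate_on_host(x, dims=[2, 4])
#   # `parts` is a flat list of 8 tensors with shapes
#   # (3, 4), (3, 4), (3, 4), (3, 2), (2, 4), (2, 4), (2, 4), (2, 2).
#   replicas = partition_or_replicate_on_host(x, dims=None)
#   # With dims=None the same tensor is repeated via itertools.repeat,
#   # yielding one (unsplit) copy per consumer.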


def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  The sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_assignment=tile_assignment,
        assign_tuple_sharding=True)
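
# Illustration of the tile assignment built above, assuming dims=[2, 4]:
# np.arange(8).reshape([2, 4]) is
#   [[0, 1, 2, 3],
#    [4, 5, 6, 7]]
# so partition (i, j) of the dequeued tensor is assigned to logical core
# i * 4 + j.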


def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how the tensors are partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)


class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue, must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()
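
  # Configuration sketch (not executed here): a queue feeding a batch of
  # images and labels; the dtypes and shapes below are illustrative only.
  #
  #   queue = InfeedQueue(
  #       tuple_types=[dtypes.float32, dtypes.int32],
  #       tuple_shapes=[[128, 224, 224, 3], [128]])
  #   assert queue.number_of_tuple_elements == 2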

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()
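
  # Sketch: the same configuration can be built up incrementally before the
  # queue is frozen (dtypes and shapes are illustrative assumptions).
  #
  #   queue = InfeedQueue(number_of_tuple_elements=2)
  #   queue.set_tuple_types([dtypes.float32, dtypes.int32])
  #   queue.set_tuple_shapes([[1024, 256], [1024]])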

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Each policy tracks the shard dimension of its own tuple element.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()
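
  # Sketch: sharding the [1024, 256] and [1024] elements configured above
  # along dimension 0 across 2 shards (get_sharded_shape is provided by
  # tpu_sharding.ShardingPolicy; values are illustrative).
  #
  #   queue.set_number_of_shards(2)
  #   queue.set_shard_dimensions([0, 0])
  #   # Each policy now reports per-shard shapes of [512, 256] and [512].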

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
                       % (str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list contains the
        Tensors whose types and shapes determine the configuration of the
        corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple element, get the unsharded shape using that element's
    # sharding policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])
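
  # Sketch: with 2 shards each carrying half a batch, the unsharded tuple
  # shapes are reconstructed from the per-shard shapes (illustrative values;
  # shard dimension 0 assumed).
  #
  #   shard0 = [array_ops.zeros([64, 224, 224, 3]), array_ops.zeros([64])]
  #   shard1 = [array_ops.zeros([64, 224, 224, 3]), array_ops.zeros([64])]
  #   queue.set_configuration_from_sharded_input_tensors([shard0, shard1])
  #   # tuple_shapes becomes [128, 224, 224, 3] and [128];
  #   # number_of_shards becomes 2.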

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu.core(tpu_device)):
        return tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
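
  # Sketch of how the dequeue op is typically consumed: inside a replicated
  # TPU computation the dequeued tensors become the per-replica inputs
  # (`some_model_fn` and the surrounding tpu.replicate call are hypothetical).
  #
  #   def tpu_computation():
  #     images, labels = queue.generate_dequeue_op(tpu_device=0)
  #     return some_model_fn(images, labels)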

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generate a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same" %
              (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device on which the enqueue op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
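
  # Sketch: enqueueing the two shards from the host side (shard0/shard1 and
  # the identity ordinal mapping are illustrative assumptions).
  #
  #   enqueue_ops = queue.generate_enqueue_ops(
  #       [shard0, shard1],
  #       tpu_ordinal_function=lambda index: index)
  #   # Running all ops in `enqueue_ops` feeds one full tuple into infeed.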

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
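
  # Sketch of the tutorial-style path above: the queue splits full-size
  # inputs itself (`full_images` and `full_labels` are hypothetical tensors
  # with the unsharded shapes).
  #
  #   enqueue_ops = queue.split_inputs_and_generate_enqueue_ops(
  #       [full_images, full_labels])
  #   # Each input is split into number_of_shards pieces along its shard
  #   # dimension, and one enqueue op is emitted per shard using the default
  #   # placement and ordinal functions (8 cores per host assumed).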


class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions on different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generate TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4],

    then sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_replicas = len(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    enqueue_ops = []

    for replica_index in range(number_of_replicas):
      flattened_inputs = sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # The self._host_id is guaranteed to contain the logical core 0,
      # even when num_cores_per_replica > num_cores_per_host -- the function
      # caller makes sure that this host_id will be receiving data (i.e., it
      # calls input_fn).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for it in inputs_parted_iters:
            input_for_device = next(it, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops
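
  # Sketch of the docstring example above: one replica, one tuple element of
  # shape [2, 4], partitioned with input_partition_dims = [[2, 4]] so that
  # each of the 8 pieces lands on its own logical core (`partitioned_queue`
  # and the 8-core replica are assumptions for illustration).
  #
  #   x = array_ops.zeros([2, 4])
  #   enqueue_ops = partitioned_queue.generate_enqueue_ops([[x]])
  #   # One enqueue op per logical core of the replica; piece (i, j) is fed
  #   # to logical core i * 4 + j.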

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims or the
        num_cores_per_replica doesn't match the number of
        partitions (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of the input partition dims should equal "
          "num_cores_per_replica. (dim = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)