# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

16"""Helper library for handling infeed between hosts and TPUs.
17"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import itertools
24
25import numpy as np
26from six.moves import xrange  # pylint: disable=redefined-builtin
27
28from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.ops import array_ops
33from tensorflow.python.tpu import tpu
34from tensorflow.python.tpu import tpu_sharding
35from tensorflow.python.tpu.ops import tpu_ops
36
37from tensorflow.python.util import nest
38
39
def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For each dimension that cannot be evenly partitioned, XLA assumes
      # tensors are partitioned in a greedy manner: slices of size
      # ceil(size / dim) are taken first, and the final slice gets the
      # remainder. E.g. for a 2-D tensor with shape (5, 14) and dims (2, 4),
      # since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] is split into blocks of shape
      # [[(3, 4), (3, 4), (3, 4), (3, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, dim, axis=axis) for x in output]
    output = nest.flatten(output)
  return output
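
# Illustrative sketch (commented example, not exercised by this module):
# partitioning a (5, 14) tensor with dims=[2, 4] should yield eight host-side
# blocks, in row-major order of the partition grid.
#
#   x = array_ops.zeros([5, 14])
#   blocks = partition_or_replicate_on_host(x, dims=[2, 4])
#   # Expected block shapes, in order:
#   #   [3, 4], [3, 4], [3, 4], [3, 2], [2, 4], [2, 4], [2, 4], [2, 2]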


def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  Args:
    tensor: The dequeued tensor on the TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
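
# Illustrative sketch of the tile assignment built above: for dims=[2, 4] the
# eight partitions map to logical cores 0..7 in row-major order.
#
#   np.arange(np.prod([2, 4])).reshape([2, 4])
#   # => [[0, 1, 2, 3],
#   #     [4, 5, 6, 7]]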


def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags the appropriate XLA sharding attributes onto the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on the TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with the appropriate xla_sharding attributes.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
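
# Illustrative sketch of how dequeues and dims line up (hypothetical 4-D
# `images` and 1-D `labels` tensors; the two structures must match shallowly):
#
#   tag_sharding_attribute_for_dequeued_tensors(
#       [images, labels], [[1, 2, 1, 1], None])
#   # images is tiled across two cores along its second dimension; labels is
#   # replicated to every core.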


class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue; must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()
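
  # Illustrative sketch of constructing and configuring a queue (hypothetical
  # dtypes and shapes, two tuple elements sharded two ways along dimension 0):
  #
  #   queue = InfeedQueue(
  #       tuple_types=[dtypes.float32, dtypes.int32],
  #       tuple_shapes=[[128, 224, 224, 3], [128]])
  #   queue.set_shard_dimensions([0, 0])
  #   queue.set_number_of_shards(2)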

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update a frozen InfeedQueue configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update a frozen InfeedQueue configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Each policy records the shard dimension of its own tuple element.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements.
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
                       % (str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list supplies the
        types and shapes of the tuple for the corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple element, get the unsharded shape using its policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])
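
  # Illustrative sketch: with two shards, each a one-element tuple of shape
  # (64, 10), and shard dimension 0, the inferred unsharded tuple shape should
  # be (128, 10). (Hypothetical tensors, for illustration only.)
  #
  #   queue = InfeedQueue(number_of_tuple_elements=1)
  #   queue.set_shard_dimensions([0])
  #   queue.set_configuration_from_sharded_input_tensors(
  #       [[array_ops.zeros([64, 10])], [array_ops.zeros([64, 10])]])
  #   # queue.tuple_shapes == [TensorShape([128, 10])]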

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu.core(tpu_device)):
        return tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
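
  # Illustrative sketch: inside a replicated TPU computation the dequeue is
  # typically generated without explicit placement (hypothetical two-element
  # queue carrying features and labels):
  #
  #   def tpu_step():
  #     features, labels = queue.generate_dequeue_op(tpu_device=None)
  #     return model_fn(features, labels)  # hypothetical model function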

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generates a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same" %
              (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue Op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
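
  # Illustrative sketch: enqueueing two shards of host-side tensors, mapping
  # shard i to TPU ordinal i (`sharded_inputs` is a hypothetical list of two
  # tensor lists placed on the host CPU):
  #
  #   enqueue_ops = queue.generate_enqueue_ops(
  #       sharded_inputs, tpu_ordinal_function=lambda index: index)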

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index / 8)

  def _default_ordinal_function(self, index):
    return index % 8
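
  # Illustrative sketch of the default mapping above, which assumes 8 TPU
  # devices per host: shard index 9 is placed on "/task:1/device:CPU:0" and
  # uses TPU ordinal 1.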

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are `None`, then `device_assignment` will be
        used to place infeeds on the first k TPU shards, where k is the number
        of shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]
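
  # Illustrative sketch: with number_of_shards == 2 and shard dimension 0, a
  # single hypothetical [8, 128] input is split on the host into two [4, 128]
  # shards, and one enqueue Op is returned per shard:
  #
  #   enqueue_ops = queue.split_inputs_and_generate_enqueue_ops(
  #       [array_ops.zeros([8, 128])])
  #   # len(enqueue_ops) == 2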


class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partitioning.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions to different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, per_host_sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    per_host_sharded_inputs is a list, one for each replica, of lists of
    Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    per_host_sharded_inputs[i][j] is partitioned by
    self._input_partition_dims[j].

    For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4], then
    per_host_sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      per_host_sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of replicas on the host. Each inner
        list indicates the types and shapes of the tuple for the corresponding
        replica.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of per_host_sharded_inputs are not compatible
        with the frozen configuration; or if the shapes of the elements of
        per_host_sharded_inputs don't form a consistent unsharded tuple; or if
        the elements of a tuple have different device constraints; or if the
        partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of per_host_sharded_inputs are not compatible
        with the frozen configuration; or if the types of the elements of
        per_host_sharded_inputs don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
    number_of_replicas_per_host = len(per_host_sharded_inputs)
    number_of_tuple_elements = len(per_host_sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    per_host_enqueue_ops = []

    for replica_index in range(number_of_replicas_per_host):
      flattened_inputs = per_host_sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(per_host_sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        replica_id = self._device_assignment.lookup_replicas(
            self._host_id, logical_core)[replica_index]
        ordinal = self._device_assignment.tpu_ordinal(
            replica=replica_id, logical_core=logical_core)
        infeed_inputs = []
        for it in inputs_parted_iters:
          input_for_device = next(it, None)
          if input_for_device is not None:
            infeed_inputs.append(input_for_device)

        if infeed_inputs:
          per_host_enqueue_ops.append(
              tpu_ops.infeed_enqueue_tuple(
                  inputs=infeed_inputs,
                  shapes=[x.shape for x in infeed_inputs],
                  name="enqueue/replica_{0}/input_{1}".format(
                      replica_index, logical_core),
                  device_ordinal=ordinal))
    return per_host_enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims or the
        num_cores_per_replica doesn't match the number of
        partitions (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of the input partition dims must equal "
          "num_cores_per_replica. (dims = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)
920