# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Helper library for handling infeed between hosts and TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops

from tensorflow.python.util import nest


def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For each dimension that cannot be evenly partitioned, XLA assumes
      # tensors are partitioned in a greedy manner, using slices of size
      # ceil(size / dim) first. E.g. for a 2-D tensor with shape (5, 14) and
      # dims (2, 4): since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] is partitioned
      # into pieces of shape
      # [[(3, 4), (3, 4), (3, 4), (3, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, int(dim), axis=axis) for x in output]
    output = nest.flatten(output)
  return output


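# Illustrative sketch (not part of the original module): partitioning a
# host-side tensor of shape (4, 8) with dims=[2, 2] yields four slices of
# shape (2, 4), ordered row-major over the partition grid; dims=None simply
# replicates the tensor.
#
#   t = array_ops.ones([4, 8])
#   slices = partition_or_replicate_on_host(t, dims=[2, 2])
#   # len(slices) == 4, and each slice has shape (2, 4).
#   replicated = partition_or_replicate_on_host(t, dims=None)
#   # itertools.repeat(t): next(replicated) is t itself.
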
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  The sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_assignment=tile_assignment,
        assign_tuple_sharding=True)


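# Illustrative note (a sketch, not part of the original module): for
# dims=[2, 4] the tile assignment built above is np.arange(8).reshape(2, 4),
# i.e. logical cores laid out as
#
#   [[0, 1, 2, 3],
#    [4, 5, 6, 7]]
#
# so partition (i, j) of the dequeued tensor is assigned to logical core
# i * 4 + j.
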
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags the appropriate XLA sharding attributes onto the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attributes.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)


class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               number_of_partitions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through
        the queue; must be present unless it can be inferred from other
        arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      number_of_partitions: if > 1, the infeed dequeue shape will contain the
        full shape that includes all partitions, and the corresponding XLA
        annotation will be added to the infeed dequeue op. In this case the
        infeed is still data parallel: it feeds the per-core batch size to
        each core while the XLA computation may be partitioned. Since XLA
        requires the infeed dequeue shape to be the per-replica shape,
        number_of_partitions is needed to calculate the per-replica
        unpartitioned shape.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_partitions is None:
      self._number_of_partitions = 1
    else:
      self._number_of_partitions = number_of_partitions
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()

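  # Illustrative usage sketch (not part of the original module). The tensors
  # `images` and `labels` below are hypothetical placeholders for host-side
  # input tensors, and `num_shards` is a hypothetical shard count:
  #
  #   queue = InfeedQueue(number_of_tuple_elements=2)
  #   queue.set_configuration_from_input_tensors([images, labels])
  #   queue.set_number_of_shards(num_shards)
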
  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Each sharding policy records the shard dimension of its tuple element.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
      policy.set_number_of_partitions(self._number_of_partitions)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements.
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
                       % (str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list contains the
        Tensors whose types and shapes determine the configuration of the
        corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple element, get the unsharded shape using that element's
    # sharding policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])

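  # Illustrative sketch (not part of the original module): with two shards
  # each carrying a single [2, 4] tensor and the default shard dimension 0,
  # the inferred unsharded tuple shape is [4, 4]:
  #
  #   shard0 = [array_ops.ones([2, 4])]
  #   shard1 = [array_ops.ones([2, 4])]
  #   queue = InfeedQueue(number_of_tuple_elements=1)
  #   queue.set_configuration_from_sharded_input_tensors([shard0, shard1])
  #   # queue.tuple_shapes is now [TensorShape([4, 4])].
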
  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu_name_util.core(tpu_device)):
        dequeue_op = tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      dequeue_op = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    if self._number_of_partitions <= 1:
      return dequeue_op
    partitions = [
        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)

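  # Illustrative usage sketch (not part of the original module): the dequeue
  # op is typically generated inside the replicated TPU computation, with one
  # tensor returned per tuple element. `queue` below is a hypothetical
  # InfeedQueue configured for two tuple elements, whose enqueue ops are
  # generated separately on the host side:
  #
  #   def tpu_step():
  #     images, labels = queue.generate_dequeue_op(tpu_device=0)
  #     # ... build the per-core computation from the dequeued tensors ...
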
  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generates a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same" %
              (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    shard i of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device on which the enqueue op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]

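  # Illustrative usage sketch (not part of the original module): feeding two
  # shards that live on the local host CPU, with each shard going to a
  # different TPU ordinal on that host. `shard0` and `shard1` are hypothetical
  # lists of host-side tensors, one per tuple element:
  #
  #   enqueue_ops = queue.generate_enqueue_ops(
  #       [shard0, shard1],
  #       tpu_ordinal_function=lambda shard_index: shard_index)
  #
  # All of the returned ops must be run together to enqueue one full element
  # of infeed.
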
  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]


class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partitioning.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue; must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions on different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

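  # Illustrative construction sketch (not part of the original module). The
  # `device_assignment` below is a hypothetical TPU DeviceAssignment for a
  # topology with num_cores_per_replica == 8, and each per-replica input is a
  # single tensor partitioned 2 ways on dim 0 and 4 ways on dim 1:
  #
  #   queue = _PartitionedInfeedQueue(
  #       number_of_tuple_elements=1,
  #       device_assignment=device_assignment,
  #       host_id=0,
  #       input_partition_dims=[[2, 4]])
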
  def generate_dequeue_op(self, tpu_device=0):
    """Generates the TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu_name_util.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4], then
    sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_replicas = len(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    enqueue_ops = []

    for replica_index in range(number_of_replicas):
      flattened_inputs = sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # self._host_id is guaranteed to contain logical core 0, even when
      # num_cores_per_replica > num_cores_per_host -- the caller makes sure
      # that this host_id will be receiving data (it calls input_fn).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for it in inputs_parted_iters:
            input_for_device = next(it, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims, or if
        num_cores_per_replica doesn't match the number of
        partitions (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of each input partition dim should equal to "
          "num_cores_per_replica. (dim = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

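  # Illustrative note (a sketch, not part of the original module): on a
  # hypothetical topology with num_cores_per_replica == 8, a tensor of shape
  # [8, 16] passes the check with dims=[2, 4] (2 * 4 == 8 and the rank
  # matches), while dims=[2, 2] fails because 2 * 2 != 8 and dims=[8] fails
  # because its length does not match the tensor's rank.
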
  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)
