# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for cross_device_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import threading
23
24from tensorflow.python.distribute import values as value_lib
25from tensorflow.python.eager import backprop
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import collective_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import nccl_ops
35from tensorflow.python.ops import resource_variable_ops
36from tensorflow.python.platform import tf_logging as logging
37
38INSTANCE_KEY_START_NUMBER = 100
39
40
def aggregate_gradients_using_nccl(replica_grads):
  """Aggregate gradients using NCCL all-reduce."""
  agg_all_g_and_v = []
  for single_g_and_v in zip(*replica_grads):
    single_grads = [g for g, _ in single_g_and_v]
    agg_grads = nccl_ops.all_sum(single_grads)
    agg_all_g_and_v.append(
        [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

  agg_all_g_and_v = list(zip(*agg_all_g_and_v))

  return agg_all_g_and_v

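# A minimal illustrative sketch (not executed here; tensor names are
# hypothetical): `replica_grads` is a list over replicas, where each entry is
# that replica's list of (gradient, variable) pairs, e.g.
#
#   replica_grads = [
#       [(g0_v0, v0), (g0_v1, v1)],  # gradients computed on replica/GPU 0
#       [(g1_v0, v0), (g1_v1, v1)],  # gradients computed on replica/GPU 1
#   ]
#
# The result has the same nesting, with each gradient replaced by the NCCL
# all-sum across replicas, left on its originating device.
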

def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  Args:
    avail_devices: available GPU devices.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    The list of (aggregated_gradient, variable), where the gradient has been
      summed across all replicas and the variable is chosen from the first
      replica.
  """
  # This only works for DGX-1-type machine topologies.
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  agg_grads = list(zip(*agg_grads))

  return agg_grads

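# Worked illustration of the grouping above (assuming 8 GPUs, so
# group_size = 4): for gradient index i = 0, group 0 is devices [0, 1, 2, 3]
# rooted at device 0 and group 1 is devices [4, 5, 6, 7] rooted at device 4;
# each group is reduced onto its root, the two partial sums are combined on
# device 0, and the total is then copied back out through both roots. The
# roots rotate with i (i % num_devices) to spread the reduction load across
# devices.
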

def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the average gradient for a shared variable across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, the mean of the gradients is taken; otherwise their sum.
    check_inf_nan: if True, check the gradients for NaN and Inf values.

  Returns:
    The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
      gradient has been averaged across all replicas. The variable is chosen
      from the first replica. has_nan_or_inf indicates whether any of the
      gradients contains NaN or Inf values.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = math_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None

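# A minimal illustrative sketch (hypothetical names, not executed here): for a
# single variable `v` replicated on two GPUs with per-replica gradients `g0`
# and `g1`,
#
#   (summed, v), _ = aggregate_single_gradient_using_copy(
#       [(g0, v), (g1, v)], use_mean=False, check_inf_nan=False)
#
# returns the element-wise sum `g0 + g1`, materialized on whichever device is
# active in the enclosing `ops.device(...)` scope.
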

# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage two different keys for collectives:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that run on the same set of devices must use the same group
  key.

  *Instance key*: an integer key to identify a set of corresponding tensors on
  different devices in a device group that need to be all-reduced.

  This class is thread safe.
  """

  def __init__(self, group_key_start=1):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
    """
    self._group_key = group_key_start
    self._group_key_table = {}
    self._instance_key_table = {}
    self._lock = threading.Lock()

  def get_group_key(self, devices):
    """Returns a group key for the set of devices.

    Args:
      devices: a list of canonical device strings in a collective group.

    Returns:
      int key uniquely identifying the set of device names.
    """
    key_id = hash(tuple(sorted(devices)))
    with self._lock:
      if key_id not in self._group_key_table:
        new_key = self._group_key
        self._group_key += 1
        self._group_key_table[key_id] = new_key
        self._instance_key_table[new_key] = {}
        for device in devices:
          self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
      return self._group_key_table[key_id]

  def get_instance_key(self, group_key, device):
    """Returns a new instance key for use in defining a collective op.

    You should call this once for each collective op of a collective instance.

    Args:
      group_key: the group key returned by get_group_key(). You should not
        assign the group key yourself.
      device: a canonical device string. It should be the device this collective
        op is on.

    Returns:
      a new instance key.

    Raises:
      ValueError: when the group key is invalid or the device is not in the
      group.
    """
    with self._lock:
      group = self._instance_key_table.get(group_key, None)
      if group is None:
        raise ValueError('group {} not found'.format(group_key))
      if device not in group:
        raise ValueError('{} not in group {}'.format(device, group_key))
      v = group[device]
      group[device] += 1
      return v

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveKeys needs to support deep copy as well.
    copied = CollectiveKeys()
    copied._group_key = self._group_key
    copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
    copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
    return copied

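# A minimal usage sketch (hypothetical device strings, not executed here):
#
#   collective_keys = CollectiveKeys()
#   group_key = collective_keys.get_group_key(
#       ['/job:worker/replica:0/task:0/device:GPU:0',
#        '/job:worker/replica:0/task:0/device:GPU:1'])
#   # Each call hands out a fresh instance key for that device, starting at
#   # INSTANCE_KEY_START_NUMBER and incrementing by one.
#   instance_key = collective_keys.get_instance_key(
#       group_key, '/job:worker/replica:0/task:0/device:GPU:0')
#
# All devices that participate in the same collective must agree on both the
# group key and the instance key.
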

class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica."""

  _prefer_unique_instance_key = True
  _prefer_ordering_token = True

  def __init__(self,
               group_key,
               group_size,
               collective_keys,
               device):
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    if self._use_ordering_token():
      with ops.init_scope(), ops.device(device):
        self._ordering_token = resource_variable_ops.ResourceVariable(0.)
    else:
      self._ordering_token = None

  def _control_input(self, control_input):
    if control_input is not None and not self._use_ordering_token():
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def _use_unique_instance_key(self):
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_unique_instance_key

  def _use_ordering_token(self):
    # We rely on auto control dep to insert control edges between NCCL calls,
    # but auto control dep is not used in TF1 graph mode.
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_ordering_token

  def _next_instance_key(self):
    """Returns the next instance key."""
    if self._use_unique_instance_key():
      # Assigning instance keys at function building time has issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder as
      # the instance key and feed it at function call time. In this way we also
      # don't reuse instance keys, which allows for per-instance cancellation.
      graph = ops.get_default_graph()
      # Control flow ops don't work with capture_call_time_value, so we put the
      # capture in the function graph of that control flow op.
      while getattr(graph, 'is_control_flow_graph', False):
        graph = graph.outer_graph
      if not context.executing_eagerly() and graph.building_function:
        with graph.as_default():
          # Capture self._next_instance_key so that when building a function
          # that calls another tf.function, the instance key assignment is
          # further delayed until we actually call the function in eager. Note
          # that capture_call_time_value doesn't automatically propagate the
          # deferred capture to the outer function.
          return graph.capture_call_time_value(
              self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
      else:
        instance_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device('CPU:0'):
          return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
    else:
      return self._collective_keys.get_instance_key(self._group_key,
                                                    self._device)

  def _get_ordering_token(self, communication_hint):
    if self._use_ordering_token() and communication_hint == 'NCCL':
      return self._ordering_token.handle
    return None

  def can_order_nccl(self):
    """Whether this launcher can order NCCL operations."""
    return self._use_ordering_token()

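  # A hedged note on the two class-level switches above: setting, for example,
  #
  #   CollectiveReplicaLauncher._prefer_unique_instance_key = False
  #   CollectiveReplicaLauncher._prefer_ordering_token = False
  #
  # before constructing launchers falls back to instance keys pre-assigned by
  # CollectiveKeys and to explicit control-edge ordering of NCCL calls, which
  # can be useful when debugging collective ordering issues.
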
  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout,
          ordering_token=ordering_token)

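  # A minimal end-to-end sketch (hypothetical `devices` and
  # `per_device_tensors`; assumes a collective group has been configured on
  # each participant, and is not executed here):
  #
  #   collective_keys = CollectiveKeys()
  #   group_key = collective_keys.get_group_key(devices)
  #   launchers = [
  #       CollectiveReplicaLauncher(group_key, len(devices), collective_keys, d)
  #       for d in devices
  #   ]
  #   reduced = [l.all_reduce(t) for l, t in zip(launchers, per_device_tensors)]
  #
  # Every participant must launch the matching collective with the same group
  # and instance keys for the all-reduce to complete.
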
  def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
    """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device):
      return collective_ops.all_gather_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout,
          ordering_token=ordering_token)

  def batch_all_reduce(self,
                       input_tensor_packs,
                       communication_hint='AUTO',
                       timeout=0):
    """Batch all-reduce dense tensors.

    This takes a list of batches (packs) of tensors. Using multiple packs has
    the benefit that the all-reduce of one pack can start as soon as its inputs
    are ready, without waiting for all inputs to be ready.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      A flat list of reduced tensors.
    """
    outputs = []
    for pack in input_tensor_packs:
      if context.executing_eagerly():
        # We don't batch in eager as it sometimes makes the performance worse
        # due to the concat/split ops.
        for input_tensor in pack:
          outputs.append(
              self.all_reduce(input_tensor, None, communication_hint, timeout))
      else:
        # TODO(b/169168846): insert a parallel all_gather to verify packings
        # are the same on each replica.
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          if communication_hint == 'NCCL' and outputs:
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input,
              communication_hint, timeout)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))

    return outputs

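  # A minimal illustrative sketch (hypothetical names, not executed here): in
  # graph mode each pack is flattened and concatenated into a single
  # all-reduce, then split back into the original shapes, e.g.
  #
  #   packs = group_by_size(gradients, bytes_per_pack=32 * 1024 * 1024)
  #   reduced = launcher.batch_all_reduce(packs, communication_hint='NCCL')
  #
  # returns one reduced tensor per original gradient, in the original order.
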
  def all_gather(self,
                 input_tensor,
                 axis,
                 communication_hint='AUTO',
                 timeout=0):
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError('all_gather in eager mode is not supported')

    with ops.device(self._device), \
         ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
      # we use perm_pre=[3, 0, 1, 2] to transpose it to [1,2,2,5], which moves
      # dimension 3 to the front; afterwards we use perm_after=[1,2,3,0] to
      # move it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = self._all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = self._all_gather(
          padded_input_tensor, communication_hint, timeout=timeout)
      # 4. Unpad
      split_tensors = []
      for i in range(self._group_size):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      return array_ops.transpose(out_tensor_t, perm=perm_after)

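  # A worked shape example for the steps above (two replicas, hypothetical
  # values): gathering along axis=1 with per-replica shapes [2, 3, 4] and
  # [2, 5, 4], each input is transposed to put the gather axis first
  # ([3, 2, 4] and [5, 2, 4]), padded along that axis to the maximum size 5,
  # all-gathered into [10, 2, 4], unpadded back to the true per-replica sizes
  # and concatenated into [8, 2, 4], and finally transposed back to [2, 8, 4].
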
  def all_reduce_indexed_slices(self,
                                input_slices,
                                communication_hint='AUTO',
                                timeout=0):
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          'all_reduce_indexed_slices in eager mode is not supported')

    # Current CollectiveAllGather implementations require the input
    # IndexedSlices to have consistent length across the board, so we handle
    # the reduction of IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    with ops.device(self._device):

      def all_gather():
        """Use all_gather to aggregate `IndexedSlices`."""
        all_values = self._all_gather(
            input_slices.values, communication_hint, timeout=timeout)
        # Add control dependency to order the all-gather.
        control = [all_values] if communication_hint == 'NCCL' else []
        with ops.control_dependencies(control):
          all_indices = self._all_gather(
              input_slices.indices, communication_hint, timeout=timeout)
        return ops.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      def densify_and_all_reduce():
        """Use all_reduce to aggregate `IndexedSlices`."""
        densified = ops.convert_to_tensor(input_slices)
        reduced = self.all_reduce(
            densified, communication_hint=communication_hint, timeout=timeout)
        # We have to convert the dense gradient back to IndexedSlices because
        # all_reduce() and all_gather() must have the same return type, as
        # required by control_flow_ops.cond.
        return ops.IndexedSlices(
            values=reduced,
            indices=math_ops.range(array_ops.shape(reduced)[0]),
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = self._all_gather(
          length, communication_hint, timeout=timeout)
      return control_flow_ops.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)), all_gather,
          densify_and_all_reduce)

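# A minimal illustrative sketch for all_reduce_indexed_slices (hypothetical
# values, not executed here): if replica 0 contributes an IndexedSlices with 2
# rows and replica 1 contributes one with 3 rows, the gathered lengths differ,
# so both replicas densify and fall back to a dense all-reduce; if both
# contribute the same number of rows, the values and indices are simply
# all-gathered and returned as an IndexedSlices (possibly with duplicate
# indices).
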

def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, ops.IndexedSlices) for v in values):
    return backprop.aggregate_indexed_slices_gradients(values)
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divides a tensor or IndexedSlices by `n`."""
  if isinstance(value, ops.IndexedSlices):
    value = backprop.flatten_nested_indexed_slices(value)
    return ops.IndexedSlices(
        value.values / n, value.indices, value.dense_shape)
  else:
    return value / n


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copies a tensor or IndexedSlices to the given device."""
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      copied_shape = array_ops.identity(value.dense_shape)
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  """Returns True if `value` is (or entirely consists of) IndexedSlices."""
  if isinstance(value, ops.IndexedSlices):
    return True
  assert isinstance(value, value_lib.DistributedValues)
  return all(isinstance(v, ops.IndexedSlices) for v in value.values)

def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists: a list of dense values, a list of their indices in `values`,
      a list of sparse values, and a list of their indices in `values`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  length = 0
  for values_and_indices in values_and_indices_list:
    length += len(values_and_indices[0])

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        assert result[i] is None
        result[i] = v
  return result

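# A minimal worked example of how split_by_sparsity and stitch_values compose
# (hypothetical values, not executed here): given values = [d0, s1, d2] where
# s1 is an IndexedSlices,
#
#   dense, dense_idx, sparse, sparse_idx = split_by_sparsity(values)
#   # dense == [d0, d2], dense_idx == [0, 2], sparse == [s1], sparse_idx == [1]
#   stitch_values([(dense, dense_idx), (sparse, sparse_idx)])
#   # == [d0, s1, d2], i.e. the original order is restored.
#
# This lets dense and sparse values be reduced by different code paths and
# then merged back into their original positions.
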

def group_by_size(input_tensors, bytes_per_pack):
  """Groups `input_tensors` into chunks of `bytes_per_pack`.

  The method preserves the original order of `input_tensors`. The grouping is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only groups values with known shape.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs of Tensor. All values are grouped into one pack if
    `bytes_per_pack` is zero or any of the values has an unknown shape.
  """

  if bytes_per_pack == 0:
    return [input_tensors]
  packs = []
  last_pack_size = 0
  for value in input_tensors:
    num_elements = value.shape.num_elements()
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [input_tensors]
    size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. That is, we err on the side of
    # having a few large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs

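# A worked example of the packing heuristic (hypothetical sizes, not executed
# here): with four float32 tensors of 1, 3, 2 and 1 million elements (4, 12, 8
# and 4 million bytes) and bytes_per_pack = 10 * 1000**2, the loop produces
# packs [t0, t1] (16 million bytes, since a pack is only closed after it
# exceeds the threshold) and [t2, t3] (12 million bytes). Order is preserved,
# so the flat output of batch_all_reduce lines up with the original list.
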

def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat([
      paddings_axis,
      array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
  ],
                              axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor

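# A small worked example of _pad_util (hypothetical values, not executed
# here): padding a [3, 4] tensor to full_axis_dim = 5 builds
# paddings = [[0, 2], [0, 0]] and returns a [5, 4] tensor whose last two rows
# are zeros. This is what all_gather relies on to equalize the gather axis
# across replicas before the collective.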