# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading

from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging

INSTANCE_KEY_START_NUMBER = 100


def aggregate_gradients_using_nccl(replica_grads):
  """Aggregate gradients using nccl allreduce."""
  agg_all_g_and_v = []
  for single_g_and_v in zip(*replica_grads):
    single_grads = [g for g, _ in single_g_and_v]
    agg_grads = nccl_ops.all_sum(single_grads)
    agg_all_g_and_v.append(
        [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

  agg_all_g_and_v = list(zip(*agg_all_g_and_v))

  return agg_all_g_and_v
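

# Note: a minimal usage sketch of `aggregate_gradients_using_nccl`; the
# gradient/variable names below are made up for illustration.
#
#   # Two replicas, each holding gradients for the same two variables v0, v1.
#   replica_grads = [
#       [(g0_v0, v0), (g0_v1, v1)],  # computed on replica 0
#       [(g1_v0, v0), (g1_v1, v1)],  # computed on replica 1
#   ]
#   agg = aggregate_gradients_using_nccl(replica_grads)
#   # agg has the same outer structure: agg[r][i] is (summed_gradient, var),
#   # where summed_gradient is the NCCL all-sum of the i-th gradient across
#   # replicas, materialized on replica r's device.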


def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  Args:
    avail_devices: available GPU devices.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    The list of (aggregated_gradient, variable), where the gradient has been
    summed across all replicas and the variable is chosen from the first
    replica.
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  agg_grads = list(zip(*agg_grads))

  return agg_grads


def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the average gradient for a shared variable across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, mean is taken, else sum of gradients is taken.
    check_inf_nan: check grads for nans and infs.

  Returns:
    The tuple ((average_gradient, variable), has_nan_or_inf) where the
    gradient has been averaged across all replicas. The variable is chosen
    from the first replica. has_nan_or_inf indicates whether the gradients
    contain any nan or inf values.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = array_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None
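

# Note: a minimal sketch of `aggregate_single_gradient_using_copy`; the
# tensors below are illustrative.
#
#   grad_and_vars = [(g0, v), (g1, v), (g2, v)]  # one entry per replica
#   (avg_grad, var), has_nan_or_inf = aggregate_single_gradient_using_copy(
#       grad_and_vars, use_mean=True, check_inf_nan=False)
#   # avg_grad == (g0 + g1 + g2) / 3, var is taken from the first replica and
#   # has_nan_or_inf is None because check_inf_nan was False.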
186 """ 187 188 def __init__(self, group_key_start=1): 189 """Initializes the object. 190 191 Args: 192 group_key_start: the starting integer of group key. 193 """ 194 self._group_key = group_key_start 195 self._group_key_table = {} 196 self._instance_key_table = {} 197 self._lock = threading.Lock() 198 199 def get_group_key(self, devices): 200 """Returns a group key for the set of devices. 201 202 Args: 203 devices: a list of canonical device strings in a collective group. 204 205 Returns: 206 int key uniquely identifying the set of device names. 207 """ 208 key_id = hash(tuple(sorted(devices))) 209 with self._lock: 210 if key_id not in self._group_key_table: 211 new_key = self._group_key 212 self._group_key += 1 213 self._group_key_table[key_id] = new_key 214 self._instance_key_table[new_key] = {} 215 for device in devices: 216 self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER 217 return self._group_key_table[key_id] 218 219 def get_instance_key(self, group_key, device): 220 """Returns a new instance key for use in defining a collective op. 221 222 You should call this once per each collective op of a collective instance. 223 224 Args: 225 group_key: the group key returned by get_group_key(). You should not 226 assign the group key yourself. 227 device: a canonical device string. It should be the device this collective 228 op is on. 229 230 Returns: 231 a new instance key. 232 233 Raises: 234 ValueError: when the group key is invalid or the device is not in the 235 group. 236 """ 237 with self._lock: 238 group = self._instance_key_table.get(group_key, None) 239 if group is None: 240 raise ValueError('group {} not found'.format(group_key)) 241 if device not in group: 242 raise ValueError('{} not in group {}'.format(device, group_key)) 243 v = group[device] 244 group[device] += 1 245 return v 246 247 def __deepcopy__(self, memo): 248 # distribute_coordinator deep-copies the strategy object, so 249 # CollectiveKeys needs to support deep copy as well. 250 copied = CollectiveKeys() 251 copied._group_key = self._group_key 252 copied._group_key_table = copy.deepcopy(self._group_key_table, memo) 253 copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo) 254 return copied 255 256 257class CollectiveReplicaLauncher(object): 258 """Launch collectives on one replica.""" 259 260 _prefer_unique_instance_key = True 261 _prefer_ordering_token = True 262 263 def __init__(self, 264 group_key, 265 group_size, 266 collective_keys, 267 device): 268 self._group_key = group_key 269 self._group_size = group_size 270 self._collective_keys = collective_keys 271 self._device = device 272 if self._use_ordering_token(): 273 with ops.init_scope(), ops.device(device): 274 self._ordering_token = resource_variable_ops.ResourceVariable(0.) 275 else: 276 self._ordering_token = None 277 278 def _control_input(self, control_input): 279 if control_input is not None and not self._use_ordering_token(): 280 return ops.control_dependencies([control_input]) 281 return ops.NullContextmanager() 282 283 def _use_unique_instance_key(self): 284 if not ops.executing_eagerly_outside_functions(): 285 return False 286 return CollectiveReplicaLauncher._prefer_unique_instance_key 287 288 def _use_ordering_token(self): 289 # We rely on auto control dep to insert control edges between NCCL calls, 290 # but for tf1 graph mode auto control dep is not used. 


class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica."""

  _prefer_unique_instance_key = True
  _prefer_ordering_token = True

  def __init__(self,
               group_key,
               group_size,
               collective_keys,
               device):
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    if self._use_ordering_token():
      with ops.init_scope(), ops.device(device):
        self._ordering_token = resource_variable_ops.ResourceVariable(0.)
    else:
      self._ordering_token = None

  def _control_input(self, control_input):
    if control_input is not None and not self._use_ordering_token():
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def _use_unique_instance_key(self):
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_unique_instance_key

  def _use_ordering_token(self):
    # We rely on auto control dependencies to insert control edges between
    # NCCL calls, but in tf1 graph mode auto control dependencies are not
    # used.
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_ordering_token

  def _next_instance_key(self):
    """Returns the next instance key."""
    if self._use_unique_instance_key():
      # Assigning instance keys at function building time has issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder
      # as the instance key and feed it at function call time. In this way we
      # also don't reuse instance keys, which allows for per-instance
      # cancellation.
      graph = ops.get_default_graph()
      # Control flow ops don't work with capture_call_time_value, so we put
      # the capture in the function graph of that control flow op.
      while getattr(graph, 'is_control_flow_graph', False):
        graph = graph.outer_graph
      if not context.executing_eagerly() and graph.building_function:
        with graph.as_default():
          # Capture self._next_instance_key so that when building a function
          # that calls another tf.function, the instance key assignment is
          # further delayed until we actually call the function in eager. Note
          # that capture_call_time_value doesn't automatically propagate the
          # deferred capture to the outer function.
          return graph.capture_call_time_value(
              self._next_instance_key,
              tensor_spec.TensorSpec([], dtypes.int32))
      else:
        instance_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device('CPU:0'):
          return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
    else:
      return self._collective_keys.get_instance_key(self._group_key,
                                                    self._device)

  def _get_ordering_token(self, communication_hint):
    if self._use_ordering_token() and communication_hint == 'NCCL':
      return self._ordering_token.handle
    return None

  def can_order_nccl(self):
    """Whether this launcher can order NCCL operations."""
    return self._use_ordering_token()

  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout,
          ordering_token=ordering_token)

  def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
    """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device):
      return collective_ops.all_gather_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout,
          ordering_token=ordering_token)

  def batch_all_reduce(self,
                       input_tensor_packs,
                       communication_hint='AUTO',
                       timeout=0):
    """Batch all-reduce dense tensors.

    This takes a list of batches of tensors. Using multiple batches has the
    benefit that the all-reduce doesn't need to wait for all inputs to be
    ready to start.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      A flat list of reduced tensors.
    """
    outputs = []
    for pack in input_tensor_packs:
      if context.executing_eagerly():
        # We don't batch in eager as it sometimes makes the performance worse
        # due to the concat/split ops.
        for input_tensor in pack:
          outputs.append(
              self.all_reduce(input_tensor, None, communication_hint, timeout))
      else:
        # TODO(b/169168846): insert a parallel all_gather to verify packings
        # are the same on each replica.
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          if communication_hint == 'NCCL' and outputs:
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input,
              communication_hint, timeout)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))

    return outputs
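
  # Note: a sketch of how `batch_all_reduce` packs tensors in graph mode; the
  # shapes below are illustrative. Each pack is flattened and concatenated
  # into a single all-reduce, then split back into the original shapes:
  #
  #   pack = [t0 with shape [2, 3], t1 with shape [4]]
  #   flat_tensors  -> [reshape(t0, [-1]) with 6 elements, t1 with 4 elements]
  #   concat        -> one 10-element tensor, all-reduced in a single op
  #   split([6, 4]) -> two flat tensors, reshaped back to [2, 3] and [4]
  #
  # In eager mode the tensors are all-reduced one by one instead, since the
  # extra concat/split ops can hurt performance there.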
377 """ 378 instance_key = self._next_instance_key() 379 ordering_token = self._get_ordering_token(communication_hint) 380 with ops.device(self._device): 381 return collective_ops.all_gather_v2( 382 input_tensor, 383 self._group_size, 384 self._group_key, 385 instance_key, 386 communication_hint=communication_hint, 387 timeout=timeout, 388 ordering_token=ordering_token) 389 390 def batch_all_reduce(self, 391 input_tensor_packs, 392 communication_hint='AUTO', 393 timeout=0): 394 """Batch all-reduce dense tensors. 395 396 This takes a list of batches of tensors. Using multiple batches have the 397 benefit that it doesn't need to wait for all inputs to be ready to start the 398 all-reduce. 399 400 Args: 401 input_tensor_packs: a list of lists of dense tensors. 402 communication_hint: string providing hint to runtime for choosing 403 collective implementation. 404 timeout: a float. The timeout in seconds. 405 406 Returns: 407 A flat list of reduced tensors. 408 """ 409 outputs = [] 410 for pack in input_tensor_packs: 411 if context.executing_eagerly(): 412 # We don't batch in eager as it sometimes makes the performance worse 413 # due the concat/split ops. 414 for input_tensor in pack: 415 outputs.append( 416 self.all_reduce(input_tensor, None, communication_hint, timeout)) 417 else: 418 # TODO(b/169168846): inserts a parallel all_gather to verify packings 419 # are the same on each replica. 420 with ops.device(self._device): 421 flat_tensors = [array_ops.reshape(t, [-1]) for t in pack] 422 shapes = [array_ops.shape(t) for t in pack] 423 if communication_hint == 'NCCL' and outputs: 424 control_input = outputs[-1] 425 else: 426 control_input = None 427 reduced = self.all_reduce( 428 array_ops.concat(flat_tensors, axis=0), control_input, 429 communication_hint, timeout) 430 num_elements = [math_ops.reduce_prod(s) for s in shapes] 431 flat_outputs = array_ops.split(reduced, num_elements, axis=0) 432 for shape, flat_output in zip(shapes, flat_outputs): 433 outputs.append(array_ops.reshape(flat_output, shape)) 434 435 return outputs 436 437 def all_gather(self, 438 input_tensor, 439 axis, 440 communication_hint='AUTO', 441 timeout=0): 442 """All-gather a dense tensor. 443 444 This method must be called inside a tf.function. 445 446 Args: 447 input_tensor: a dense tensor. It must have the same rank on all replicas, 448 and dimensions other than `axis` need to be the same as well. 449 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 450 range [0, rank(value)). 451 communication_hint: string providing hint to runtime for choosing 452 collective implementation. Available options are `AUTO`, `NCCL`, and 453 `RING`. 454 timeout: a float. The timeout in seconds. 455 456 Returns: 457 The gathered Tensor. 458 459 Raises: 460 RuntimeError: if called in eager mode. 461 """ 462 if context.executing_eagerly(): 463 raise RuntimeError('all_gather in eager mode is not supported') 464 465 with ops.device(self._device), \ 466 ops.control_dependencies([array_ops.identity(input_tensor)]): 467 # 1. Transpose 468 # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, 469 # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which 470 # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to 471 # place it back. 472 perm_pre = array_ops.concat( 473 ([axis], math_ops.range(axis), 474 math_ops.range(axis + 1, array_ops.rank(input_tensor))), 475 axis=0) 476 input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre) 477 # 2. 
  def all_reduce_indexed_slices(self,
                                input_slices,
                                communication_hint='AUTO',
                                timeout=0):
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          'all_reduce_indexed_slices in eager mode is not supported')

    # Current CollectiveAllGather implementations require the input
    # IndexedSlices to have consistent length across the board, so we handle
    # the reduction of IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    with ops.device(self._device):

      def all_gather():
        """Use all_gather to aggregate `IndexedSlices`."""
        all_values = self._all_gather(
            input_slices.values, communication_hint, timeout=timeout)
        # Add control dependency to order the all-gather.
        control = [all_values] if communication_hint == 'NCCL' else []
        with ops.control_dependencies(control):
          all_indices = self._all_gather(
              input_slices.indices, communication_hint, timeout=timeout)
        return ops.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      def densify_and_all_reduce():
        """Use all_reduce to aggregate `IndexedSlices`."""
        densified = ops.convert_to_tensor(input_slices)
        reduced = self.all_reduce(
            densified, communication_hint=communication_hint, timeout=timeout)
        # We have to convert the dense grad back to IndexedSlices because
        # all_reduce() and all_gather() must have the same return type, as
        # required by control_flow_ops.cond.
        return ops.IndexedSlices(
            values=reduced,
            indices=math_ops.range(array_ops.shape(reduced)[0]),
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = self._all_gather(
          length, communication_hint, timeout=timeout)
      return control_flow_ops.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)), all_gather,
          densify_and_all_reduce)
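

# Note: a sketch of the decision made by `all_reduce_indexed_slices`; the row
# counts below are illustrative. If every replica holds, say, 3 indexed rows,
# the lengths gathered in step 1 all match and the cheaper all_gather branch
# simply concatenates values and indices across replicas. If one replica holds
# 5 rows instead, the lengths differ, so each replica densifies its
# IndexedSlices and a dense all_reduce is used.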


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, ops.IndexedSlices) for v in values):
    return backprop.aggregate_indexed_slices_gradients(values)
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  if isinstance(value, ops.IndexedSlices):
    value = backprop.flatten_nested_indexed_slices(value)
    return ops.IndexedSlices(
        value.values / n, value.indices, value.dense_shape)
  else:
    return value / n


def copy_tensor_or_indexed_slices_to_device(value, device):
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      copied_shape = array_ops.identity(value.dense_shape)
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  if isinstance(value, ops.IndexedSlices):
    return True
  assert isinstance(value, value_lib.DistributedValues)
  return all(isinstance(v, ops.IndexedSlices) for v in value.values)


def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists: a list of dense values, a list of their indices in `values`,
    a list of sparse values, and a list of their indices in `values`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  length = 0
  for values_and_indices in values_and_indices_list:
    length += len(values_and_indices[0])

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        assert result[i] is None
        result[i] = v
  return result


def group_by_size(input_tensors, bytes_per_pack):
  """Groups `input_tensors` into chunks of `bytes_per_pack`.

  The method preserves the original order of `input_tensors`. The grouping is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only groups values with known shape.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs of Tensor. All values are grouped into one pack if
    `bytes_per_pack` is zero or any of the values has unknown shape.
  """

  if bytes_per_pack == 0:
    return [input_tensors]
  packs = []
  last_pack_size = 0
  for value in input_tensors:
    num_elements = value.shape.num_elements()
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [input_tensors]
    size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. I.e. we err on the side of having
    # few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs
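

# Note: a worked example of `group_by_size`; the sizes are illustrative. With
# float32 tensors of shapes [256], [1024] and [2048] (1KB, 4KB and 8KB) and
# bytes_per_pack=4096:
#   - the 1KB and 4KB tensors go into the first pack, since a new pack only
#     starts once the previous one has exceeded bytes_per_pack,
#   - the 8KB tensor starts a second pack,
# giving [[t_1kb, t_4kb], [t_8kb]].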


def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat(
      [paddings_axis,
       array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)],
      axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor
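

# Note: a small illustration of `_pad_util`; the shapes are illustrative.
# Given an input of shape [2, 4] and full_axis_dim=5, the computed paddings
# are [[0, 3], [0, 0]], so the result has shape [5, 4] with the three extra
# rows filled with zeros.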