# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as pycoll
import threading

from tensorflow.python.distribute import all_reduce
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def aggregate_gradients_using_nccl(replica_grads):
  """Aggregate gradients using nccl allreduce."""
  agg_all_g_and_v = []
  for single_g_and_v in zip(*replica_grads):
    single_grads = [g for g, _ in single_g_and_v]
    agg_grads = nccl_ops.all_sum(single_grads)
    agg_all_g_and_v.append(
        [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

  agg_all_g_and_v = list(zip(*agg_all_g_and_v))

  return agg_all_g_and_v
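

# The aggregation helpers in this file operate on `replica_grads`: a list
# (over replicas) of lists of (gradient, variable) pairs.  An illustrative
# sketch for two replicas and two variables (names are hypothetical):
#
#   replica_grads = [
#       [(grad0_gpu0, var0_gpu0), (grad1_gpu0, var1_gpu0)],  # replica 0
#       [(grad0_gpu1, var0_gpu1), (grad1_gpu1, var1_gpu1)],  # replica 1
#   ]
#
# `zip(*replica_grads)` therefore yields, per variable, the tuple of that
# variable's gradients across all replicas, which is the unit the aggregation
# functions below reduce over.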


def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  Args:
    avail_devices: available GPU devices.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    The list of (aggregated_gradient, variable), where the gradient has been
    summed across all replicas and the variable is chosen from the first
    replica.
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  agg_grads = list(zip(*agg_grads))

  return agg_grads


def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the average gradient for a shared variable across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, the mean is taken, else the sum of gradients is taken.
    check_inf_nan: check grads for NaNs and Infs.

  Returns:
    The tuple ((average_gradient, variable), has_nan_or_inf) where the gradient
    has been averaged across all replicas. The variable is chosen from the
    first replica. has_nan_or_inf indicates whether any of the gradients
    contains NaN or Inf; it is None when check_inf_nan is False.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = array_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    # The NaN/Inf checks live in math_ops, not array_ops.
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None


def group_device_names(devices, group_size):
  """Group device names into groups of group_size.

  Args:
    devices: a list of canonical device strings.
    group_size: integer which is equal to or greater than 1.

  Returns:
    list of lists of devices, where each inner list is group_size long,
    and each device appears at least once in an inner list.  If
    len(devices) % group_size == 0 then each device will appear exactly once.

  Raises:
    ValueError: if group_size > len(devices)
  """
  num_devices = len(devices)
  if group_size > num_devices:
    raise ValueError(
        'only %d devices, but group_size=%d' % (num_devices, group_size))
  num_groups = (
      num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
  groups = [[] for i in range(num_groups)]
  for i in range(num_groups * group_size):
    groups[i % num_groups].append(devices[i % num_devices])
  return groups
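

# Example (illustrative): grouping three devices with group_size=2 wraps
# around, so every device appears at least once and the first device is reused
# to fill the final group:
#
#   group_device_names(['/gpu:0', '/gpu:1', '/gpu:2'], 2)
#   # -> [['/gpu:0', '/gpu:2'], ['/gpu:1', '/gpu:0']]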


def split_grads_by_size(threshold_size, device_grads):
  """Break gradients into two sets according to tensor size.

  Args:
    threshold_size: int size cutoff for small vs large tensor.
    device_grads: List of lists of (gradient, variable) tuples.  The outer
      list is over devices. The inner list is over individual gradients.

  Returns:
    small_grads: Subset of device_grads where shape is <= threshold_size
      elements.
    large_grads: Subset of device_grads where shape is > threshold_size
      elements.
  """
  small_grads = []
  large_grads = []
  for dl in device_grads:
    small_dl = []
    large_dl = []
    for (g, v) in dl:
      tensor_size = g.get_shape().num_elements()
      if tensor_size <= threshold_size:
        small_dl.append([g, v])
      else:
        large_dl.append([g, v])
    if small_dl:
      small_grads.append(small_dl)
    if large_dl:
      large_grads.append(large_dl)
  return small_grads, large_grads


# threading.Lock() and threading.local() cannot be pickled and therefore cannot
# be a field of CollectiveKeys. Right now _thread_local does not need to be an
# instance member of CollectiveKeys since we always create a new thread for
# each replica.
_lock = threading.Lock()
_thread_local = threading.local()


# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage three different keys for collective:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that work on the same set of devices must use the same group
  key.

  *Instance key*: an integer key to identify the set of corresponding tensors
  on different devices in a device group that need to be all-reduced together.

  *Graph key*: an integer key that is unique per graph. This is used to support
  multiple graphs per client session. It must be non-zero and set in the
  `config` argument of each call to `session.run`.
  """

  def __init__(self,
               group_key_start=1,
               instance_key_start=100,
               instance_key_with_id_start=10000):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
      instance_key_start: the starting integer of instance key.
      instance_key_with_id_start: the starting integer of instance key that is
        recorded with an id.
    """
    self._group_key = group_key_start
    self._group_key_table = dict()

    # For instance keys with ids.
    self._instance_key_id_to_key_table = dict()
    self._instance_key_with_id_counter = instance_key_with_id_start

    # For instance keys without ids.
    self._instance_key_start = instance_key_start

  def _get_thread_local_object(self):
    # We make instance keys without key ids thread local so that they work
    # with MirroredStrategy and distribute coordinator.
    if not hasattr(_thread_local, 'instance_key'):
      _thread_local.instance_key = self._instance_key_start
    return _thread_local

  def get_group_key(self, devices):
    """Returns a group key for the set of devices.

    Args:
      devices: list of strings naming devices in a collective group.

    Returns:
      int key uniquely identifying the set of device names.
    """
    parsed = [pydev.DeviceSpec.from_string(d) for d in devices]
    # In between-graph replicated training, different workers need to get the
    # same group key, so we remove the task_type and task_id from the devices.
    # TODO(yuefengz): in the in-graph replicated training, we need to include
    # task_type and task_id.
    names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
    key_id = ','.join(names)
    with _lock:
      if key_id not in self._group_key_table:
        new_key = self._group_key
        self._group_key += 1
        self._group_key_table[key_id] = new_key
      return self._group_key_table[key_id]

  def get_instance_key(self, key_id=None):
    """Returns a new instance key for use in defining a collective op.

    Args:
      key_id: optional string. If set, the key will be recorded and the same
        key will be returned when the same key_id is provided. If not, a new,
        increasing instance key will be returned.
    """
    if key_id:
      with _lock:
        if key_id not in self._instance_key_id_to_key_table:
          self._instance_key_with_id_counter += 1
          self._instance_key_id_to_key_table[key_id] = (
              self._instance_key_with_id_counter)
        return self._instance_key_id_to_key_table[key_id]
    else:
      v = self._get_thread_local_object().instance_key
      self._get_thread_local_object().instance_key += 1
      return v
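

# Example (illustrative; the device strings are hypothetical): because the
# task fields are stripped before the key lookup, workers in between-graph
# replication derive the same group key for their local GPUs:
#
#   keys = CollectiveKeys()
#   k0 = keys.get_group_key(['/job:worker/task:0/device:GPU:0',
#                            '/job:worker/task:0/device:GPU:1'])
#   k1 = keys.get_group_key(['/job:worker/task:1/device:GPU:0',
#                            '/job:worker/task:1/device:GPU:1'])
#   assert k0 == k1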


def build_collective_reduce(input_tensors,
                            num_workers,
                            collective_keys,
                            reduction_op='Add',
                            unary_op='Id'):
  """Build a subgraph that does one full all-reduce, using the collective Op.

  Args:
    input_tensors: tensors within a single worker graph that are to be reduced
      together; must be one per device.
    num_workers: total number of workers with identical independent graphs that
      will be doing this same reduction.  The reduction will actually include
      the corresponding tensors at all these workers.
    collective_keys: a CollectiveKeys object.
    reduction_op: string naming the reduction op.
    unary_op: string naming the unary final op.

  Returns:
    An array of final tensors, one per device, computed by the full reduction.
    If the total group size (devices times workers) is less than two, the
    input tensors are returned unchanged.
  """
  group_size = len(input_tensors) * num_workers
  if group_size < 2:
    return input_tensors
  devices = [t.device for t in input_tensors]
  num_devices = len(devices)
  group_key = collective_keys.get_group_key(devices)
  instance_key = collective_keys.get_instance_key()
  subdiv_offsets = [0]  # TODO(tucker): maybe support non-default subdiv spec

  def collective_all_reduce():
    """Call collective allreduce."""
    assert not context.executing_eagerly()
    out_tensors = []
    for d in range(num_devices):
      with ops.device(devices[d]):
        reduce_op = collective_ops.all_reduce(
            input_tensors[d], group_size, group_key, instance_key, reduction_op,
            unary_op, subdiv_offsets)
        out_tensors.append(reduce_op)
    return out_tensors

  if context.executing_eagerly():
    # Collective ops will block unless they are executed concurrently such as
    # in a graph or a defun.
    collective_all_reduce = def_function.function(collective_all_reduce)
  return collective_all_reduce()
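

# Illustrative sketch (not part of this module's API; tensor and device names
# are hypothetical): every one of the `num_workers` workers builds this same
# subgraph over its own local per-device tensors, and the collective runtime
# matches the pieces up across workers via matching group and instance keys:
#
#   collective_keys = CollectiveKeys()
#   # `local_grads` holds one tensor per local device on this worker.
#   reduced = build_collective_reduce(
#       local_grads, num_workers=2, collective_keys=collective_keys)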


def sum_grad_and_var_all_reduce(grad_and_vars,
                                num_workers,
                                alg,
                                gpu_indices,
                                aux_devices=None,
                                num_shards=1):
  """Apply all-reduce algorithm over specified gradient tensors."""
  with ops.name_scope('allreduce'):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    scaled_grads = [g for g, _ in grad_and_vars]
    if alg == 'nccl':
      summed_grads = nccl_ops.all_sum(scaled_grads)
    elif alg == 'xring':
      summed_grads = all_reduce.build_ring_all_reduce(
          scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)
    elif alg == 'nccl/xring':
      summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
                                                     math_ops.add)
    elif alg == 'nccl/rechd':
      summed_grads = all_reduce.build_nccl_then_recursive_hd(
          scaled_grads, math_ops.add)
    elif alg == 'nccl/pscpu':
      summed_grads = all_reduce.build_nccl_then_shuffle(
          scaled_grads, aux_devices, math_ops.add, math_ops.add_n)
    elif alg == 'pscpu/pscpu':
      second_gather_devices = aux_devices[:num_shards]
      summed_grads = all_reduce.build_shuffle_then_shuffle(
          scaled_grads, aux_devices, second_gather_devices, math_ops.add_n)
    elif alg in ['pscpu', 'psgpu']:
      summed_grads = all_reduce.build_shuffle_all_reduce(
          scaled_grads, aux_devices, math_ops.add_n)
    else:
      raise ValueError('unsupported all_reduce alg: %s' % alg)

  result = []
  for (_, v), g in zip(grad_and_vars, summed_grads):
    result.append([g, v])
  return result


def sum_gradients_all_reduce(dev_prefixes, replica_grads, num_workers, alg,
                             num_shards, gpu_indices):
  """Apply all-reduce algorithm over specified gradient tensors.

  Args:
    dev_prefixes: list of prefix strings to use to generate PS device names.
    replica_grads: the gradients to reduce.
    num_workers: number of worker processes across entire job.
    alg: the all-reduce algorithm to apply.
    num_shards: alg-specific sharding factor.
    gpu_indices: indices of local GPUs in order usable for ring-reduce.

  Returns:
    list of reduced tensors
  """
  alg_contains_shuffle = any(n in alg for n in ['pscpu', 'psgpu'])
  is_hierarchical = '/' in alg
  if 'pscpu' in alg:
    aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
  elif 'psgpu' in alg:
    aux_devices = [
        prefix + '/gpu:%d' % i
        for i in range(len(gpu_indices))
        for prefix in dev_prefixes
    ]
  else:
    aux_devices = ['/job:localhost/cpu:0']
  # Auxiliary devices for hierarchical all-reduces.
  aux_device_groups = group_device_names(
      aux_devices, num_shards if alg_contains_shuffle else 1)
  group_index = 0
  reduced_gv_list = []
  for grad_and_vars in zip(*replica_grads):
    reduced_gv_list.append(
        sum_grad_and_var_all_reduce(
            grad_and_vars, num_workers, alg, gpu_indices, aux_devices
            if is_hierarchical else aux_device_groups[group_index], num_shards))
    group_index = (group_index + 1) % len(aux_device_groups)
  new_replica_grads = [list(x) for x in zip(*reduced_gv_list)]
  return new_replica_grads


def extract_ranges(index_list, range_size_limit=32):
  """Extract consecutive ranges and singles from index_list.

  Args:
    index_list: List of monotone increasing non-negative integers.
    range_size_limit: Largest size range to return.  If a larger
      consecutive range exists, it will be returned as multiple
      ranges.

  Returns:
    (ranges, singles) where ranges is a list of [first, last] pairs of
      consecutive elements in index_list, and singles is all of the
      other elements, in original order.
  """
  if not index_list:
    return [], []
  first = index_list[0]
  last = first
  ranges = []
  singles = []
  for i in index_list[1:]:
    if i == last + 1 and (last - first) <= range_size_limit:
      last = i
    else:
      if last > first:
        ranges.append([first, last])
      else:
        singles.append(first)
      first = i
      last = i
  if last > first:
    ranges.append([first, last])
  else:
    singles.append(first)
  return ranges, singles
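

# Example (illustrative): consecutive index runs become [first, last] pairs
# and isolated indices are returned as singles, in their original order:
#
#   extract_ranges([0, 1, 2, 3, 8, 9, 11])
#   # -> ([[0, 3], [8, 9]], [11])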


GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes')


def pack_range(key, packing, grad_vars, rng):
  """Form the concatenation of a specified range of gradient tensors.

  Args:
    key: Value under which to store meta-data in packing that will be used
      later to restore the grad_var list structure.
    packing: Dict holding data describing packed ranges of small tensors.
    grad_vars: List of (grad, var) pairs for one replica.
    rng: A pair of integers giving the first, last indices of a consecutive
      range of tensors to be packed.

  Returns:
    A tensor that is the concatenation of all the specified small tensors.
  """
  to_pack = grad_vars[rng[0]:rng[1] + 1]
  members = []
  variables = []
  restore_shapes = []
  with ops.name_scope('pack'):
    for g, v in to_pack:
      variables.append(v)
      restore_shapes.append(g.shape)
      with ops.device(g.device):
        members.append(array_ops.reshape(g, [-1]))
    packing[key] = GradPackTuple(
        indices=range(rng[0], rng[1] + 1),
        vars=variables,
        shapes=restore_shapes)
    with ops.device(members[0].device):
      return array_ops.concat(members, 0)


def unpack_grad_tuple(gv, gpt):
  """Unpack a previously packed collection of gradient tensors.

  Args:
    gv: A (grad, var) pair to be unpacked.
    gpt: A GradPackTuple describing the packing operation that produced gv.

  Returns:
    A list of (grad, var) pairs corresponding to the values that were
    originally packed into gv, possibly following subsequent operations such
    as reduction.
  """
  elt_widths = [x.num_elements() for x in gpt.shapes]
  with ops.device(gv[0][0].device):
    with ops.name_scope('unpack'):
      splits = array_ops.split(gv[0], elt_widths)
      unpacked_gv = []
      for idx, s in enumerate(splits):
        unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]),
                            gpt.vars[idx]))
  return unpacked_gv
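

# Illustrative sketch (hypothetical names): pack_range and unpack_grad_tuple
# form a round trip.  The flattened concatenation produced by pack_range can
# be all-reduced as a single tensor and then split back into per-variable
# gradients using the shapes recorded in the GradPackTuple:
#
#   packing = {}
#   packed = pack_range('0:0', packing, gv_list, [0, 1])
#   # ... all-reduce `packed` across replicas ...
#   restored = unpack_grad_tuple((packed, 'packing_var_placeholder'),
#                                packing['0:0'])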


def pack_small_tensors(replica_grads, max_bytes=0, max_group=0):
  """Concatenate small gradient tensors together for reduction.

  Args:
    replica_grads: List of lists of (gradient, variable) tuples.
    max_bytes: Int giving max number of bytes in a tensor that
      may be considered small.
    max_group: Int giving max number of small tensors that may be
      concatenated into one new tensor.

  Returns:
    new_replica_grads, packing where new_replica_grads is identical to
      replica_grads except that all feasible small tensors have been removed
      from their places and concatenated into larger tensors that are
      now in the front of the list for each replica, and packing contains
      the data necessary to restore the replica_grads structure.

  Look through the first replica for gradients of the same type (float),
  and small size, that are all sequential.  For each such group,
  replace by a new tensor that is a flattened concatenation.  Note
  that the corresponding variable will be absent, which doesn't matter
  because it isn't used during all-reduce.

  Requires:
    Every gv_list in replica_grads must have isomorphic structure including
    identical tensor sizes and types.
  """
  small_indices = []
  large_indices = []
  for idx, (g, _) in enumerate(replica_grads[0]):
    if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes:
      small_indices.append(idx)
    else:
      large_indices.append(idx)
  small_ranges, small_singles = extract_ranges(
      small_indices, range_size_limit=max_group)
  large_indices = sorted(large_indices + small_singles)
  num_gv = len(replica_grads[0])
  packing = {}
  if small_ranges:
    new_replica_grads = []
    for dev_idx, gv_list in enumerate(replica_grads):
      assert len(gv_list) == num_gv
      new_gv_list = []
      for r in small_ranges:
        key = '%d:%d' % (dev_idx, len(new_gv_list))
        new_gv_list.append((pack_range(key, packing, gv_list, r),
                            'packing_var_placeholder'))
      for i in large_indices:
        new_gv_list.append(gv_list[i])
      new_replica_grads.append(new_gv_list)
    return new_replica_grads, packing
  else:
    return replica_grads, None
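

# Example (illustrative): with max_bytes=256, a float32 gradient counts as
# "small" when it has at most 64 elements (4 bytes each), so a [64] bias
# gradient would be packed while a [128, 128] weight gradient would not:
#
#   new_replica_grads, packing = pack_small_tensors(
#       replica_grads, max_bytes=256, max_group=16)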


def unpack_small_tensors(replica_grads, packing):
  """Undo the structure alterations to replica_grads done by pack_small_tensors.

  Args:
    replica_grads: List of List of (grad, var) tuples.
    packing: A dict generated by pack_small_tensors describing the changes
      it made to replica_grads.

  Returns:
    new_replica_grads: identical to replica_grads except that concatenations
      of small tensors have been split apart and returned to their original
      positions, paired with their original variables.
  """
  if not packing:
    return replica_grads
  new_replica_grads = []
  num_devices = len(replica_grads)
  num_packed = len(packing.keys()) // num_devices
  for dev_idx, gv_list in enumerate(replica_grads):
    gv_list = list(gv_list)
    new_gv_list = gv_list[num_packed:]
    for i in range(num_packed):
      k = '%d:%d' % (dev_idx, i)
      gpt = packing[k]
      gv = unpack_grad_tuple(gv_list[i], gpt)
      for gi, idx in enumerate(gpt.indices):
        assert idx == gpt.indices[gi]
        new_gv_list.insert(idx, gv[gi])
    new_replica_grads.append(new_gv_list)
  return new_replica_grads


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, ops.IndexedSlices) for v in values):
    return gradients_util._AggregateIndexedSlicesGradients(values)  # pylint: disable=protected-access
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divide a tensor or `IndexedSlices` elementwise by `n`."""
  if isinstance(value, ops.IndexedSlices):
    value = gradients_util._HandleNestedIndexedSlices(value)  # pylint: disable=protected-access
    return ops.IndexedSlices(
        value.values / n, value.indices, value.dense_shape)
  else:
    return value / n


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copy a tensor or `IndexedSlices` to the given device."""
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      copied_shape = array_ops.identity(value.dense_shape)
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def contains_indexed_slices(value):
  """Check whether the value is `IndexedSlices` or contains `IndexedSlices`."""
  if isinstance(value, ops.IndexedSlices):
    return True
  elif isinstance(value, (list, tuple)) and value:
    return any(contains_indexed_slices(v) for v in value)
  elif isinstance(value, value_lib.DistributedValues):
    return contains_indexed_slices(value.values)
  else:
    return False


def is_indexed_slices(value):
  """Return True if `value` (or each of its distributed values) is sparse."""
  if isinstance(value, ops.IndexedSlices):
    return True
  assert isinstance(value, value_lib.DistributedValues)
  return all([isinstance(v, ops.IndexedSlices) for v in value.values])


def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists in this order: dense values, their indices in `values`, sparse
    values, and their indices in `values`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  length = 0
  for values_and_indices in values_and_indices_list:
    length += len(values_and_indices[0])

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        assert result[i] is None
        result[i] = v
  return result
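

# Illustrative sketch: split_by_sparsity and stitch_values round-trip, which
# lets dense and sparse values be reduced separately and then reassembled in
# their original order:
#
#   dense, dense_idx, sparse, sparse_idx = split_by_sparsity(values)
#   # ... reduce `dense` and `sparse` independently ...
#   stitched = stitch_values([(dense, dense_idx), (sparse, sparse_idx)])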