# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import dataclasses
import itertools
import pprint
from typing import Any, Dict, Iterable, List, Optional, Protocol

import sympy

import torch

from .. import config, ir
from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
from ..virtualized import V
from .wrapper import (
    AllocateLine,
    FreeIfNotReusedLine,
    MemoryPlanningLine,
    NullLine,
    ReuseLine,
)


@dataclasses.dataclass
class LiveRange:
    """
    A range where a given tensor is live.  Begin and end are both counters
    representing points in the program of grouped memory operations.
    Begin is inclusive, end is exclusive.

    Invariant: begin <= end
    """

    begin: float  # int | +/-inf
    end: float  # int | +/-inf

    def contains(self, other: LiveRange):
        """Is other entirely within self"""
        return self.begin <= other.begin and other.end <= self.end

    def join(self, other: LiveRange):
        """Combine two ranges using a union operation"""
        return LiveRange(min(self.begin, other.begin), max(self.end, other.end))

    def __len__(self):
        """
        Width of the range (end - begin).

        NOTE: Python's len() requires an int result, so len() is only usable
        once both bounds are plain ints (not +/-inf).
        """
        return self.end - self.begin


class LiveRanges:
    """
    A collection of LiveRange regions, allowing for non-contiguous
    live regions.

    Invariant: LiveRanges.ranges is in sorted order and non-overlapping
    """

    def __init__(self, ranges: Iterable[LiveRange]):
        """Sort ``ranges`` by begin and merge any that touch or overlap."""
        ranges = [*sorted(ranges, key=lambda x: x.begin)]
        self.ranges = ranges[:1]
        for r in ranges[1:]:
            assert self.ranges[-1].begin <= r.begin
            if self.ranges[-1].end >= r.begin:
                # r starts at or before the end of the previous range: merge
                self.ranges[-1] = LiveRange.join(self.ranges[-1], r)
            else:
                self.ranges.append(r)

    def overlaps(self, other: LiveRanges):
        """Check if any pair of ranges in self and other overlap"""
        left = collections.deque(self.ranges)
        right = collections.deque(other.ranges)
        while left and right:
            # keep the range with the earlier begin at the head of ``left``
            if left[0].begin > right[0].begin:
                left, right = right, left
            assert left[0].begin <= right[0].begin
            if left[0].end > right[0].begin:
                return True
            # left[0] ends before right[0] begins, so it cannot overlap
            # anything still queued on the right; discard it
            left.popleft()
        return False

    @property
    def begin(self):
        """Begin of the earliest range (IndexError if there are no ranges)."""
        return self.ranges[0].begin

    @property
    def end(self):
        """End of the latest range (IndexError if there are no ranges)."""
        return self.ranges[-1].end

    def __repr__(self):
        return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])"


class AllocationTreeNode:
    """
    Abstract base class for nodes in allocation pool.
    """

    def allocate(self, block: Allocation, is_last: bool) -> bool:
        """
        Try to assign block to a memory location in this pool.  Return True if
        an assignment was made.
        """
        return False

    def get_live_ranges(self) -> LiveRanges:
        """Aggregate LiveRanges for all objects below this in tree"""
        raise NotImplementedError

    def get_size_hint(self) -> int:
        """Number of bytes used for example inputs"""
        raise NotImplementedError

    def get_symbolic_size(self) -> sympy.Expr:
        """Number of bytes needed at runtime"""
        raise NotImplementedError

    def finalize(self, pool, offset) -> AllocationTreeNode:
        """Called after all allocations have been made"""
        return self

    def is_empty(self):
        """True only for placeholder nodes; the base implementation says False."""
        return False


@dataclasses.dataclass
class Allocation(AllocationTreeNode):
    """
    Represents memory allocated to a given node in the allocation pool.
    """

    node: ir.Buffer
    live_range: LiveRange
    size_hint: int
    symbolic_size: sympy.Expr
    allocated: bool = False
    # pool/offset stay None until finalize() assigns them
    pool: Optional[AllocationPool] = None
    offset: Optional[sympy.Expr] = None

    @property
    def device(self):
        return self.node.get_device()

    def get_live_ranges(self):
        return LiveRanges([self.live_range])

    def get_size_hint(self):
        return self.size_hint

    def get_symbolic_size(self):
        return self.symbolic_size

    def mark_allocated(self):
        """Flag this block as placed; asserts it has not been placed before."""
        assert not self.allocated
        self.allocated = True

    def finalize(self, pool, offset):
        """Record the owning pool and offset; asserts neither was set before."""
        assert self.pool is None and self.offset is None
        self.pool = pool
        self.offset = offset
        return self

    def codegen_alloc_from_pool(self, wrapper):
        """
        Return whatever ``wrapper.codegen_alloc_from_pool`` produces for this
        allocation's pool name, offset, dtype, shape and stride.
        """
        assert self.pool
        node = self.node
        shape = tuple(node.get_size())
        stride = tuple(node.get_stride())
        return wrapper.codegen_alloc_from_pool(
            self.pool.name, self.offset, node.get_dtype(), shape, stride
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"node={self.node.get_name()}, "
            f"live_range={self.live_range}, "
            f"size_hint={self.size_hint}, "
            f"symbolic_size={self.symbolic_size}, "
            f"pool={self.pool.name if self.pool else None}, "
            f"offset={self.offset})"
        )


@dataclasses.dataclass
class Empty(AllocationTreeNode):
    """
    Placeholder to represent empty space in the allocation pool.
    Only exists to get the size_hint correct in parent nodes.
    """

    size_hint: int

    def get_live_ranges(self):
        return LiveRanges([])

    def get_size_hint(self):
        return self.size_hint

    def get_symbolic_size(self):
        return 0

    def is_empty(self):
        return True


class MemorySplitProtocol(Protocol):
    """
    Structural interface that ClearCacheOnAllocateMixin relies on: three
    cached accessors plus the ``_allocate`` hook.
    """

    get_live_ranges: CachedMethod[[], LiveRanges]
    get_size_hint: CachedMethod[[], int]
    get_symbolic_size: CachedMethod[[], sympy.Expr]

    def _allocate(self, block: Allocation, is_last: bool) -> bool:
        ...


class ClearCacheOnAllocateMixin(MemorySplitProtocol):
    """
    Helper to assist in caching get_live_ranges, get_size_hint, and
    get_symbolic_size.
    """

    def allocate(self, block: Allocation, is_last: bool):
        """Delegate to ``_allocate`` and drop cached values on success."""
        is_allocated = self._allocate(block, is_last)
        if is_allocated:
            self.clear_cache()
        return is_allocated

    def clear_cache(self):
        """Invalidate the three cached accessors for this instance."""
        self.get_live_ranges.clear_cache(self)
        self.get_size_hint.clear_cache(self)
        self.get_symbolic_size.clear_cache(self)


@dataclasses.dataclass
class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    Contains a list of allocations not overlapping in LiveRanges.

    Invariant: no pair (a,b) in self.allocations will have:
         a.get_live_ranges().overlaps(b.get_live_ranges())
    """

    allocations: List[AllocationTreeNode]

    def _allocate(self, block: Allocation, is_last: bool):
        """
        Try to place ``block`` in this slot.

        Returns False when the block is larger than the slot and growing is
        not permitted (``is_last`` is False), or when it overlaps in time
        with more than one existing child.  When it overlaps exactly one
        child, the request is forwarded to that child.  Otherwise the block
        is added here, splitting or growing the slot as needed.
        """
        slot_size = self.get_size_hint()
        block_size = block.get_size_hint()
        if not is_last and block_size > slot_size:
            return False  # doesn't fit

        block_live = block.get_live_ranges()
        overlapping = [
            s for s in self.allocations if s.get_live_ranges().overlaps(block_live)
        ]
        if len(overlapping) > 1:
            # TODO(jansel): we could try harder here by merging overlapping in space
            return False
        elif len(overlapping) == 1:
            return overlapping[0].allocate(block, is_last)
        else:
            block.mark_allocated()

            # a lone Empty placeholder is replaced rather than kept alongside
            if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty):
                self.allocations.pop()

            if slot_size == block_size:
                # perfect fit
                self.allocations.append(block)
            elif slot_size > block_size:
                self.allocations.append(
                    SpatialSplit.create(block, slot_size - block_size)
                )
            else:  # grow this allocation
                assert is_last
                self.allocations = [
                    *(
                        SpatialSplit.create(a, block_size - slot_size)
                        for a in self.allocations
                    ),
                    block,
                ]
            return True

    @cache_on_self
    def get_live_ranges(self) -> LiveRanges:
        return LiveRanges(
            itertools.chain.from_iterable(
                x.get_live_ranges().ranges for x in self.allocations
            )
        )

    @cache_on_self
    def get_size_hint(self) -> int:
        if not self.allocations:
            return 0
        return max(x.get_size_hint() for x in self.allocations)

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        if not self.allocations:
            return 0  # type: ignore[return-value]
        return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])

    def is_empty(self):
        return len(self.allocations) == 1 and self.allocations[0].is_empty()

    def finalize(self, pool, offset):
        """
        Finalize every child at the same ``offset``.  A split with a single
        child collapses to that child; otherwise ``self`` is returned.
        """
        self.allocations = [block.finalize(pool, offset) for block in self.allocations]
        self.clear_cache()
        if len(self.allocations) == 1:
            return self.allocations[0]
        return self


@dataclasses.dataclass
class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    Contains two allocations, left and right, that do not overlap in space.
    Right will be allocated immediately after left in memory.
    """

    left: TemporalSplit
    right: TemporalSplit

    @staticmethod
    def create(left, extra_space):
        """
        Wrap ``left`` in a split whose right side is an Empty placeholder of
        ``extra_space`` bytes (must be an int >= 1).
        """
        assert isinstance(left, AllocationTreeNode)
        assert isinstance(extra_space, int) and extra_space >= 1
        return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)]))

    def _allocate(self, block: Allocation, is_last: bool):
        # the left side is never allowed to grow; only the right side
        # receives the caller's is_last flag
        return self.left.allocate(block, False) or self.right.allocate(block, is_last)

    @cache_on_self
    def get_live_ranges(self):
        return LiveRanges(
            itertools.chain(
                self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges
            )
        )

    @cache_on_self
    def get_size_hint(self) -> int:
        return _align(self.left.get_size_hint()) + self.right.get_size_hint()

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size()

    def finalize(self, pool, offset):
        """
        Finalize left at ``offset`` and right at ``offset`` plus the aligned
        symbolic size of left.  Collapses to left when right is empty.
        """
        self.left = self.left.finalize(pool, offset)
        self.right = self.right.finalize(
            pool, offset + align(self.left.get_symbolic_size())
        )
        self.clear_cache()
        if self.right.is_empty():
            return self.left
        return self


@dataclasses.dataclass
class AllocationPool:
    """
    Represents a pool of allocations that will be generated by a single
    call to torch.empty.
    """

    device: torch.device
    root: TemporalSplit
    can_expand: bool = True
    restrict_live_range: Optional[LiveRange] = None
    name: Optional[str] = None
    names_to_del: List[str] = dataclasses.field(default_factory=list)
    # maps generated alloc-from-pool code to the name already bound to it
    creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict)

    def allocate(self, block: Allocation, is_last: bool):
        """
        Try to place ``block`` in this pool; return True on success.

        The block is rejected when ``restrict_live_range`` is set and does
        not contain the block's live range.  Growing (including appending at
        the end) is only attempted when both ``can_expand`` and ``is_last``
        are true.
        """
        if self.restrict_live_range and not self.restrict_live_range.contains(
            block.live_range
        ):
            return False

        is_last = self.can_expand and is_last
        if self.root.allocate(block, is_last):
            return True

        if is_last:
            return self.allocate_at_end(block)

        return False

    def allocate_at_end(self, block):
        """Place ``block`` spatially after everything currently in the pool."""
        block.mark_allocated()
        self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
        return True

    def finalize(self, name):
        """Name the pool (once) and finalize the tree starting at offset 0."""
        assert not self.name
        self.name = name
        self.names_to_del.append(name)
        self.root.finalize(self, 0)

    def codegen_create(self, wrapper, code: IndentedBuffer):
        """
        Write the line that creates this pool into ``code``.

        If a direct child Allocation of the root has the same symbolic size
        as the whole pool, the pool is created with that node's
        dtype/shape/stride and registered in ``creation_cache``; otherwise
        the pool is created as a 1-D uint8 buffer of ``nbytes`` elements.
        """
        assert self.name
        nbytes = self.root.get_symbolic_size()
        for block in self.root.allocations:
            if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
                # optimization: fuse first allocation and pool creation
                node = block.node
                code.writeline(
                    wrapper.make_allocation(
                        self.name,
                        device=self.device,
                        dtype=node.get_dtype(),
                        shape=tuple(node.get_size()),
                        stride=tuple(node.get_stride()),
                    )
                )
                self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
                return

        # no child spans the whole pool: allocate raw bytes
        code.writeline(
            wrapper.make_allocation(
                self.name,
                device=self.device,
                dtype=torch.uint8,
                shape=(nbytes,),
                stride=(1,),
            )
        )

    def codegen_destroy(self, wrapper, code: IndentedBuffer):
        """Write the line that frees every name in ``names_to_del``."""
        code.writeline(wrapper.make_free_by_names(self.names_to_del))

    # Identity-based equality/hash so pools can be stored in sets
    # (the dataclass-generated __eq__ would otherwise make them unhashable).
    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return id(self)


@dataclasses.dataclass
class AllocationPools:
    """
    Collection of many AllocationPool objects grouped by device.
    """

    device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field(
        default_factory=dict
    )

    def get_pools(self, block):
        """Return the pool list for ``block.device``, creating it if missing."""
        return self.device_to_pools.setdefault(block.device, [])

    def allocate(self, block: Allocation):
        """
        Place ``block`` in the first pool for its device that accepts it;
        only the last pool in the list is offered ``is_last=True``.  If no
        pool accepts the block, a new pool is created for it.
        """
        pools = self.get_pools(block)

        for pool in pools:
            if pool.allocate(block, is_last=pool is pools[-1]):
                return

        # everything is full, make a new pool
        pools.append(
            AllocationPool(
                block.device,
                TemporalSplit([block]),
                can_expand=config.memory_pool != "none",
            )
        )
        block.mark_allocated()

    def allocate_output(self, block: Allocation):
        """Outputs get different pools so memory gets freed properly"""
        pools = self.get_pools(block)
        if pools and config.memory_pool in ("outputs", "combined"):
            pools[-1].allocate_at_end(block)
        else:
            # create a new pool
            block.mark_allocated()
            pools.append(
                AllocationPool(
                    block.device,
                    TemporalSplit([block]),
                    can_expand=config.memory_pool == "combined",
                )
            )

    def finalize(self):
        """Called at the end of allocation process"""
        for i, pool in enumerate(
            itertools.chain.from_iterable(self.device_to_pools.values())
        ):
            pool.finalize(f"pool{i}")

    def pprint(self):
        """Debug helper: print each pool's name, live ranges and tree."""
        for pool in itertools.chain.from_iterable(self.device_to_pools.values()):
            print()
            print(pool.name)
            print(pool.root.get_live_ranges())
            pprint.pprint(pool.root)


class BufferGroup:
    """
    Due to inplace reuse an allocated buffer can have many names.
    This tracks these collections of buffers sharing underlying memory.
    """

    def __init__(self, node: ir.Buffer):
        self.node = node
        self.names = [node.get_name()]
        self.is_output = False
        self.allocation: Optional[Allocation] = None
        # starts as an inverted (+inf, -inf) range so the first
        # update_usage() call sets both bounds to that timestep
        self.live_range = LiveRange(float("inf"), -float("inf"))

    def update_usage(self, timestep: int):
        """Expand self.live_range to include timestep"""
        self.live_range = LiveRange(
            min(timestep, self.live_range.begin),
            max(timestep, self.live_range.end),
        )

    def sym_nbytes(self):
        """Storage size of the node's layout multiplied by its dtype itemsize."""
        return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize

    def make_allocation(self):
        """
        Create ``self.allocation`` from the node, live range and byte size.
        Asserts this is called at most once and only after the live range
        has been computed.
        """
        assert not self.allocation, "multiple allocations"
        assert isinstance(self.live_range.begin, int), "live ranges not computed"
        nbytes = self.sym_nbytes()
        # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have
        # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored.
        size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64)
        self.allocation = Allocation(
            self.node,
            self.live_range,
            size_hint=size_hint,
            symbolic_size=nbytes,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, "
            f"live_range={self.live_range})"
        )


@dataclasses.dataclass
class PoolMemoryPlanningLine(MemoryPlanningLine):
    """Abstract base class for {Alloc,Dealloc}FromPoolLine"""

    group: BufferGroup
    # assigned by MemoryPlanner.compute_live_ranges
    timestep: Optional[int] = None

    @property
    def node(self):
        return self.group.node


@dataclasses.dataclass
class AllocFromPoolLine(PoolMemoryPlanningLine):
    """Similar to AllocateLine, but takes memory from a pool"""

    is_first_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        """
        Write the code binding this buffer's name to memory from its pool,
        creating the pool first when this is the pool's first usage.
        """
        allocation = self.group.allocation
        assert allocation and allocation.pool
        pool = allocation.pool
        name = self.node.get_name()

        if self.is_first_pool_usage:
            pool.codegen_create(self.wrapper, code)

        pool.names_to_del.extend(self.group.names)
        alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper)
        if alloc_from_pool in pool.creation_cache:
            # identical allocation code was already emitted: alias that name
            code.writeline(
                self.wrapper.make_tensor_alias(
                    name, pool.creation_cache[alloc_from_pool], "alloc"
                )
            )
        else:
            pool.creation_cache[alloc_from_pool] = name
            code.writeline(
                f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}"
            )


@dataclasses.dataclass
class DeallocFromPoolLine(PoolMemoryPlanningLine):
    """Similar to FreeIfNotReusedLine, but takes memory from a pool"""

    is_last_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        """Destroy the pool when this is its last usage; otherwise emit nothing."""
        if self.is_last_pool_usage:
            assert self.group.allocation and self.group.allocation.pool
            self.group.allocation.pool.codegen_destroy(self.wrapper, code)


@dataclasses.dataclass
class MemoryPlanner:
    """
    Coordination object to run memory planning passes during wrapper
    codegen.
    """

    wrapper: Any
    pools: AllocationPools = dataclasses.field(default_factory=AllocationPools)
    buffer_groups: Optional[List[BufferGroup]] = None

    def plan(self, lines: List[Any]) -> List[Any]:
        """Call all the memory planning passes in sequence"""
        # work on a shallow copy so the caller's list is not modified
        lines = [*lines]
        self.drop_removed_buffers(lines)
        self.convert_to_pool_lines(lines)
        self.compute_live_ranges(lines)
        self.allocate_groups()
        self.mark_first_last_usage(lines)
        return lines

    def drop_removed_buffers(self, lines):
        """
        Replace any memory planning lines in V.graph.removed_buffers with NullLine
        """
        # drop any removed buffers
        for i, line in enumerate(lines):
            if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)):
                if line.node.get_name() in V.graph.removed_buffers:
                    lines[i] = NullLine(self.wrapper)

    def compute_buffer_groups(self, lines):
        """
        Populates self.buffer_groups with BufferGroup objects that join
        allocations with common storage (due to inplace reuse) into a
        single object.

        Returns the mapping from buffer name to its BufferGroup.
        """
        name_to_group = {}
        for line in lines:
            if isinstance(line, AllocateLine):
                name = line.node.get_name()
                assert name not in name_to_group
                name_to_group[name] = BufferGroup(line.node)
            elif isinstance(line, ReuseLine):
                old_name = line.node.get_name()
                new_name = line.reused_as.get_name()
                assert new_name not in name_to_group
                # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc
                if old_name in name_to_group:
                    name_to_group[old_name].names.append(new_name)
                    name_to_group[new_name] = name_to_group[old_name]

        outputs = set(V.graph.get_output_names())
        # several names can map to the same group; dedupe by identity
        unique_groups = [*{id(g): g for g in name_to_group.values()}.values()]
        for group in unique_groups:
            group.is_output = any(x in outputs for x in group.names)

        assert self.buffer_groups is None
        self.buffer_groups = unique_groups
        return name_to_group

    def convert_to_pool_lines(self, lines):
        """
        Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their
        pool-based counterparts.
        """
        name_to_group = self.compute_buffer_groups(lines)
        for i, line in enumerate(lines):
            if isinstance(line, AllocateLine):
                if line.node.get_name() in name_to_group:
                    lines[i] = AllocFromPoolLine(
                        self.wrapper, name_to_group[line.node.get_name()]
                    )
            elif isinstance(line, FreeIfNotReusedLine):
                assert not line.is_reused
                if line.node.get_name() in name_to_group:
                    lines[i] = DeallocFromPoolLine(
                        self.wrapper, name_to_group[line.node.get_name()]
                    )
            elif isinstance(line, ReuseLine):
                if line.node.get_name() in name_to_group:
                    line.delete_old = False

    def compute_live_ranges(self, lines):
        """Populate every BufferGroup.live_range field based on first/last usage"""
        timestep = 0
        worklist = collections.deque(lines)
        while worklist:
            if isinstance(worklist[0], MemoryPlanningLine):
                # a run of consecutive memory planning lines shares one timestep
                timestep += 1
                while worklist and isinstance(worklist[0], MemoryPlanningLine):
                    line = worklist.popleft()
                    if isinstance(line, PoolMemoryPlanningLine):
                        line.group.update_usage(timestep)
                        line.timestep = timestep
            else:
                worklist.popleft()

        # outputs are extended to one timestep past the last one assigned above
        timestep += 1
        assert self.buffer_groups is not None
        for group in self.buffer_groups:
            if group.is_output:
                group.update_usage(timestep)

    def allocate_groups(self):
        """
        Assign every allocation to a specific location in a specific AllocationPool.
        """
        assert config.memory_pool in ("none", "intermediates", "outputs", "combined")
        assert self.buffer_groups is not None

        for group in self.buffer_groups:
            group.make_allocation()

        outputs: List[Allocation] = []
        intermediates: List[Allocation] = []
        for group in self.buffer_groups:
            assert group.allocation
            if group.is_output and config.memory_pool != "combined":
                outputs.append(group.allocation)
            else:
                intermediates.append(group.allocation)

        # outputs: smallest size hint first, ties broken by longest live range
        for block in sorted(
            outputs,
            key=lambda x: (
                x.size_hint,
                -len(x.live_range),
            ),
        ):
            self.pools.allocate_output(block)

        # intermediates: largest size hint first, ties broken by longest live range
        for block in sorted(
            intermediates,
            key=lambda x: (
                -x.size_hint,
                -len(x.live_range),
            ),
        ):
            self.pools.allocate(block)

        self.pools.finalize()

    def mark_first_last_usage(self, lines):
        """
        Populate the AllocFromPoolLine.is_first_pool_usage and
        DeallocFromPoolLine.is_last_pool_usage fields so that pools
        are created/destroyed.
        """
        seen = set()
        for line in lines:
            if isinstance(line, AllocFromPoolLine):
                assert line.group.allocation
                pool = line.group.allocation.pool
                assert pool is not None
                if pool not in seen:
                    line.is_first_pool_usage = True
                    seen.add(pool)

        seen = set()
        for line in reversed(lines):
            if isinstance(line, DeallocFromPoolLine):
                assert line.group.allocation
                pool = line.group.allocation.pool
                assert pool is not None
                if pool not in seen:
                    # only the final dealloc line of each pool is considered,
                    # and only if the pool's live ranges have ended by then
                    line.is_last_pool_usage = (
                        pool.root.get_live_ranges().end <= line.timestep
                    )
                    seen.add(pool)