• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import collections
5import dataclasses
6import itertools
7import pprint
8from typing import Any, Dict, Iterable, List, Optional, Protocol
9
10import sympy
11
12import torch
13
14from .. import config, ir
15from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
16from ..virtualized import V
17from .wrapper import (
18    AllocateLine,
19    FreeIfNotReusedLine,
20    MemoryPlanningLine,
21    NullLine,
22    ReuseLine,
23)
24
25
@dataclasses.dataclass
class LiveRange:
    """
    Half-open interval ``[begin, end)`` over which a tensor is live.

    Both endpoints are counters identifying points in the program of
    grouped memory operations.

    Invariant: begin <= end
    """

    begin: float  # int | +/-inf
    end: float  # int | +/-inf

    def contains(self, other: LiveRange):
        """Return True when ``other`` lies entirely inside ``self``."""
        if self.begin <= other.begin:
            return other.end <= self.end
        return False

    def join(self, other: LiveRange):
        """Return the smallest single range covering both ``self`` and ``other``."""
        lo = min(self.begin, other.begin)
        hi = max(self.end, other.end)
        return LiveRange(lo, hi)

    def __len__(self):
        # Distance between the two endpoints (end is exclusive).
        span = self.end - self.begin
        return span
49
50
class LiveRanges:
    """
    A set of LiveRange regions, which lets a node be live over several
    disjoint stretches of the program.

    Invariant: LiveRanges.ranges is in sorted order and non-overlapping
    """

    def __init__(self, ranges: Iterable[LiveRange]):
        # Sort by start, then fold every range that touches or overlaps the
        # previously kept one into it.
        merged = []
        for current in sorted(ranges, key=lambda r: r.begin):
            if not merged:
                merged.append(current)
                continue
            previous = merged[-1]
            assert previous.begin <= current.begin
            if previous.end >= current.begin:
                merged[-1] = LiveRange.join(previous, current)
            else:
                merged.append(current)
        self.ranges = merged

    def overlaps(self, other: LiveRanges):
        """Return True if some range of self intersects some range of other"""
        mine, theirs = self.ranges, other.ranges
        i = j = 0
        while i < len(mine) and j < len(theirs):
            a, b = mine[i], theirs[j]
            if a.begin > b.begin:
                # b starts first: it collides if it runs past the start of a
                if b.end > a.begin:
                    return True
                j += 1
            else:
                # a starts first (or together with b)
                if a.end > b.begin:
                    return True
                i += 1
        return False

    @property
    def begin(self):
        first = self.ranges[0]
        return first.begin

    @property
    def end(self):
        last = self.ranges[-1]
        return last.end

    def __repr__(self):
        body = ", ".join(repr(r) for r in self.ranges)
        return f"{self.__class__.__name__}([{body}])"
92
93
class AllocationTreeNode:
    """
    Abstract base class for nodes in allocation pool.

    The defaults defined here decline every allocation, report the node as
    non-empty, return ``self`` from ``finalize`` and raise
    NotImplementedError for the size / live-range queries.
    """

    def allocate(self, block: Allocation, is_last: bool) -> bool:
        """
        Try to assign block to a memory location in this pool.  Return True if
        an assignment was made.
        """
        return False

    def get_live_ranges(self) -> LiveRanges:
        """Aggregate LiveRanges for all objects below this in tree"""
        raise NotImplementedError

    def get_size_hint(self) -> int:
        """Number of bytes used for example inputs"""
        raise NotImplementedError

    def get_symbolic_size(self) -> sympy.Expr:
        """Number of bytes needed at runtime"""
        raise NotImplementedError

    def finalize(self, pool, offset) -> AllocationTreeNode:
        """Called after all allocations have been made"""
        return self

    def is_empty(self):
        """Whether this node is only placeholder space; False by default."""
        return False
124
125
@dataclasses.dataclass
class Allocation(AllocationTreeNode):
    """
    Leaf of the allocation tree: the memory assigned to one buffer node.

    ``pool`` and ``offset`` start out as None and are filled in by
    ``finalize``.
    """

    node: ir.Buffer
    live_range: LiveRange
    size_hint: int
    symbolic_size: sympy.Expr
    allocated: bool = False
    pool: Optional[AllocationPool] = None
    offset: Optional[sympy.Expr] = None

    @property
    def device(self):
        """Device reported by the underlying node."""
        return self.node.get_device()

    def get_live_ranges(self):
        """Wrap the single live range of this block in a LiveRanges."""
        return LiveRanges([self.live_range])

    def get_size_hint(self):
        return self.size_hint

    def get_symbolic_size(self):
        return self.symbolic_size

    def mark_allocated(self):
        """Flag this block as placed; a block may only be placed once."""
        assert not self.allocated
        self.allocated = True

    def finalize(self, pool, offset):
        """Record the pool and offset this block ended up at."""
        assert self.pool is None
        assert self.offset is None
        self.pool = pool
        self.offset = offset
        return self

    def codegen_alloc_from_pool(self, wrapper):
        """Ask ``wrapper`` for the code that carves this block out of its pool."""
        assert self.pool
        buf = self.node
        dims = tuple(buf.get_size())
        strides = tuple(buf.get_stride())
        return wrapper.codegen_alloc_from_pool(
            self.pool.name, self.offset, buf.get_dtype(), dims, strides
        )

    def __repr__(self):
        pool_name = self.pool.name if self.pool else None
        parts = (
            f"node={self.node.get_name()}",
            f"live_range={self.live_range}",
            f"size_hint={self.size_hint}",
            f"symbolic_size={self.symbolic_size}",
            f"pool={pool_name}",
            f"offset={self.offset}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"
182
183
@dataclasses.dataclass
class Empty(AllocationTreeNode):
    """
    Placeholder to represent empty space in the allocation pool.
    Only exists to get the size_hint correct in parent nodes.
    """

    size_hint: int  # size of the gap, returned unchanged by get_size_hint()

    def get_live_ranges(self):
        """An empty slot holds nothing, so it reports no live ranges."""
        return LiveRanges([])

    def get_size_hint(self):
        """Return the size of the gap this placeholder stands for."""
        return self.size_hint

    def get_symbolic_size(self):
        """Always 0: the placeholder contributes nothing to the symbolic size."""
        return 0

    def is_empty(self):
        """Always True; this is what identifies the node as placeholder space."""
        return True
204
205
class MemorySplitProtocol(Protocol):
    """
    Structural interface relied on by ClearCacheOnAllocateMixin: three cached
    query methods plus the ``_allocate`` hook that does the actual placement.
    """

    # Cached accessors; each exposes a clear_cache() that the mixin calls.
    get_live_ranges: CachedMethod[[], LiveRanges]
    get_size_hint: CachedMethod[[], int]
    get_symbolic_size: CachedMethod[[], sympy.Expr]

    def _allocate(self, block: Allocation, is_last: bool) -> bool:
        """Try to place ``block``; return True if an assignment was made."""
        ...
213
214
class ClearCacheOnAllocateMixin(MemorySplitProtocol):
    """
    Mixin that keeps the cached get_live_ranges, get_size_hint and
    get_symbolic_size results valid by dropping them after every
    successful allocation.
    """

    def allocate(self, block: Allocation, is_last: bool):
        """Delegate to ``_allocate`` and invalidate the caches on success."""
        placed = self._allocate(block, is_last)
        if not placed:
            return placed
        self.clear_cache()
        return placed

    def clear_cache(self):
        """Forget the cached results of all three query methods."""
        cached_methods = (
            self.get_live_ranges,
            self.get_size_hint,
            self.get_symbolic_size,
        )
        for method in cached_methods:
            method.clear_cache(self)
231
232
@dataclasses.dataclass
class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    Holds allocations that share the same space at different times.

    Invariant: no pair (a,b) in self.allocations will have:
         a.get_live_ranges().overlaps(b.get_live_ranges())
    """

    allocations: List[AllocationTreeNode]

    def _allocate(self, block: Allocation, is_last: bool):
        """
        Try to place ``block`` in this slot.  Growing the slot is only
        permitted when ``is_last`` is set.
        """
        slot_size = self.get_size_hint()
        block_size = block.get_size_hint()
        if block_size > slot_size and not is_last:
            return False  # doesn't fit

        block_live = block.get_live_ranges()
        conflicts = []
        for child in self.allocations:
            if child.get_live_ranges().overlaps(block_live):
                conflicts.append(child)

        if len(conflicts) > 1:
            # TODO(jansel): we could try harder here by merging overlapping in space
            return False
        if conflicts:
            # exactly one child is live at the same time: push the block down into it
            return conflicts[0].allocate(block, is_last)

        # Nothing here is live at the same time, so the block can share this slot.
        block.mark_allocated()

        only_placeholder = len(self.allocations) == 1 and isinstance(
            self.allocations[-1], Empty
        )
        if only_placeholder:
            self.allocations.pop()

        if slot_size == block_size:
            # perfect fit
            self.allocations.append(block)
        elif slot_size > block_size:
            # pad the block with empty space up to the slot size
            leftover = slot_size - block_size
            self.allocations.append(SpatialSplit.create(block, leftover))
        else:  # grow this allocation
            assert is_last
            shortfall = block_size - slot_size
            grown = [SpatialSplit.create(a, shortfall) for a in self.allocations]
            grown.append(block)
            self.allocations = grown
        return True

    @cache_on_self
    def get_live_ranges(self) -> LiveRanges:
        """Union of the live ranges of every child."""
        pieces = []
        for child in self.allocations:
            pieces.extend(child.get_live_ranges().ranges)
        return LiveRanges(pieces)

    @cache_on_self
    def get_size_hint(self) -> int:
        """Largest child size hint, or 0 when there are no children."""
        return max((child.get_size_hint() for child in self.allocations), default=0)

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        """sympy.Max over the children's symbolic sizes, or 0 without children."""
        if not self.allocations:
            return 0  # type: ignore[return-value]
        sizes = [child.get_symbolic_size() for child in self.allocations]
        return sympy.Max(*sizes)

    def is_empty(self):
        if len(self.allocations) != 1:
            return False
        return self.allocations[0].is_empty()

    def finalize(self, pool, offset):
        """
        Finalize every child at the same offset, then collapse to the child
        itself when it is the only one.
        """
        finalized = []
        for child in self.allocations:
            finalized.append(child.finalize(pool, offset))
        self.allocations = finalized
        self.clear_cache()
        if len(finalized) == 1:
            return finalized[0]
        return self
312
313
@dataclasses.dataclass
class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    A pair of allocations, left and right, occupying disjoint space.
    Right is placed immediately after (aligned) left in memory.
    """

    left: TemporalSplit
    right: TemporalSplit

    @staticmethod
    def create(left, extra_space):
        """Wrap ``left`` and follow it with ``extra_space`` of empty space."""
        assert isinstance(left, AllocationTreeNode)
        assert isinstance(extra_space, int) and extra_space >= 1
        occupied = TemporalSplit([left])
        padding = TemporalSplit([Empty(extra_space)])
        return SpatialSplit(occupied, padding)

    def _allocate(self, block: Allocation, is_last: bool):
        # The left half may never grow; only the right half inherits is_last.
        in_left = self.left.allocate(block, False)
        if in_left:
            return in_left
        return self.right.allocate(block, is_last)

    @cache_on_self
    def get_live_ranges(self):
        """Union of the live ranges of both halves."""
        left_ranges = self.left.get_live_ranges().ranges
        right_ranges = self.right.get_live_ranges().ranges
        return LiveRanges([*left_ranges, *right_ranges])

    @cache_on_self
    def get_size_hint(self) -> int:
        """Aligned size hint of left plus the size hint of right."""
        left_bytes = _align(self.left.get_size_hint())
        right_bytes = self.right.get_size_hint()
        return left_bytes + right_bytes

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        """Aligned symbolic size of left plus the symbolic size of right."""
        left_bytes = align(self.left.get_symbolic_size())
        right_bytes = self.right.get_symbolic_size()
        return left_bytes + right_bytes

    def finalize(self, pool, offset):
        """
        Finalize left at ``offset`` and right just past the aligned left
        half; drop the split when the right half is only empty space.
        """
        self.left = self.left.finalize(pool, offset)
        right_offset = offset + align(self.left.get_symbolic_size())
        self.right = self.right.finalize(pool, right_offset)
        self.clear_cache()
        return self.left if self.right.is_empty() else self
358
359
@dataclasses.dataclass
class AllocationPool:
    """
    A pool of allocations backed by one call to torch.empty.

    Pools compare and hash by identity.
    """

    device: torch.device
    root: TemporalSplit
    can_expand: bool = True
    restrict_live_range: Optional[LiveRange] = None
    name: Optional[str] = None
    names_to_del: List[str] = dataclasses.field(default_factory=list)
    creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict)

    def allocate(self, block: Allocation, is_last: bool):
        """
        Try to place ``block`` in this pool, appending it at the end when
        the pool may expand and ``is_last`` is set.  Return True on success.
        """
        restriction = self.restrict_live_range
        if restriction and not restriction.contains(block.live_range):
            return False

        is_last = self.can_expand and is_last
        if self.root.allocate(block, is_last):
            return True

        return self.allocate_at_end(block) if is_last else False

    def allocate_at_end(self, block):
        """Unconditionally grow the pool by placing ``block`` after the current root."""
        block.mark_allocated()
        tail = TemporalSplit([block])
        self.root = TemporalSplit([SpatialSplit(self.root, tail)])
        return True

    def finalize(self, name):
        """Name the pool (once) and finalize the allocation tree from offset 0."""
        assert not self.name
        self.name = name
        self.names_to_del.append(name)
        self.root.finalize(self, 0)

    def codegen_create(self, wrapper, code: IndentedBuffer):
        """Write the line that creates the pool's backing allocation into ``code``."""
        assert self.name
        nbytes = self.root.get_symbolic_size()
        for block in self.root.allocations:
            if not isinstance(block, Allocation):
                continue
            if not (nbytes == block.get_symbolic_size()):
                continue
            # optimization: fuse first allocation and pool creation
            node = block.node
            fused = wrapper.make_allocation(
                self.name,
                device=self.device,
                dtype=node.get_dtype(),
                shape=tuple(node.get_size()),
                stride=tuple(node.get_stride()),
            )
            code.writeline(fused)
            self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
            return

        # No block spans the whole pool: allocate it as a flat uint8 buffer.
        generic = wrapper.make_allocation(
            self.name,
            device=self.device,
            dtype=torch.uint8,
            shape=(nbytes,),
            stride=(1,),
        )
        code.writeline(generic)

    def codegen_destroy(self, wrapper, code: IndentedBuffer):
        """Write the line that frees every name recorded in ``names_to_del``."""
        free_line = wrapper.make_free_by_names(self.names_to_del)
        code.writeline(free_line)

    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return id(self)
438
439
@dataclasses.dataclass
class AllocationPools:
    """
    All AllocationPool objects, bucketed by the device they live on.
    """

    device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field(
        default_factory=dict
    )

    def get_pools(self, block):
        """Return the (possibly freshly created) pool list for ``block.device``."""
        return self.device_to_pools.setdefault(block.device, [])

    def allocate(self, block: Allocation):
        """Place ``block`` in the first pool that accepts it, else open a new pool."""
        pools = self.get_pools(block)
        tail = pools[-1] if pools else None

        for pool in pools:
            if pool.allocate(block, is_last=pool is tail):
                return

        # everything is full, make a new pool
        fresh = AllocationPool(
            block.device,
            TemporalSplit([block]),
            can_expand=config.memory_pool != "none",
        )
        pools.append(fresh)
        block.mark_allocated()

    def allocate_output(self, block: Allocation):
        """Outputs get different pools so memory gets freed properly"""
        pools = self.get_pools(block)
        share_last_pool = bool(pools) and config.memory_pool in (
            "outputs",
            "combined",
        )
        if share_last_pool:
            pools[-1].allocate_at_end(block)
            return

        # create a new pool
        block.mark_allocated()
        fresh = AllocationPool(
            block.device,
            TemporalSplit([block]),
            can_expand=config.memory_pool == "combined",
        )
        pools.append(fresh)

    def finalize(self):
        """Called at the end of allocation process"""
        # Pools are numbered consecutively across all devices.
        counter = itertools.count()
        for pools in self.device_to_pools.values():
            for pool in pools:
                pool.finalize(f"pool{next(counter)}")

    def pprint(self):
        """Print the name, live ranges and tree of every pool."""
        for pools in self.device_to_pools.values():
            for pool in pools:
                print()
                print(pool.name)
                print(pool.root.get_live_ranges())
                pprint.pprint(pool.root)
501
502
class BufferGroup:
    """
    Due to inplace reuse an allocated buffer can have many names.
    This tracks these collections of buffers sharing underlying memory.
    """

    def __init__(self, node: ir.Buffer):
        self.node = node
        self.names = [node.get_name()]
        self.is_output = False
        self.allocation: Optional[Allocation] = None
        # Starts as an inverted (inf, -inf) range so the first update_usage()
        # call collapses it onto that timestep.
        self.live_range = LiveRange(float("inf"), -float("inf"))

    def update_usage(self, timestep: int):
        """Expand self.live_range to include timestep"""
        self.live_range = LiveRange(
            min(timestep, self.live_range.begin),
            max(timestep, self.live_range.end),
        )

    def sym_nbytes(self):
        """Storage size of the node's layout multiplied by its dtype itemsize."""
        return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize

    def make_allocation(self):
        """
        Create the Allocation for this group.  May only be called once, and
        only after the live range has been computed.
        """
        assert not self.allocation, "multiple allocations"
        assert isinstance(self.live_range.begin, int), "live ranges not computed"
        nbytes = self.sym_nbytes()
        # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have
        # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored.
        size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64)
        self.allocation = Allocation(
            self.node,
            self.live_range,
            size_hint=size_hint,
            symbolic_size=nbytes,
        )

    def __repr__(self):
        # The closing parenthesis balances the one opened after the class name.
        return (
            f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, "
            f"live_range={self.live_range})"
        )
545
546
@dataclasses.dataclass
class PoolMemoryPlanningLine(MemoryPlanningLine):
    """Abstract base class for {Alloc,Dealloc}FromPoolLine"""

    group: BufferGroup  # buffer group this line allocates or frees
    timestep: Optional[int] = None  # None until MemoryPlanner.compute_live_ranges sets it

    @property
    def node(self):
        """The node of the buffer group this line refers to."""
        return self.group.node
557
558
@dataclasses.dataclass
class AllocFromPoolLine(PoolMemoryPlanningLine):
    """Pool-backed counterpart of AllocationLine"""

    is_first_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        """
        Write the allocation for this line into ``code``, creating the pool
        first when this is its first usage.
        """
        allocation = self.group.allocation
        assert allocation and allocation.pool
        pool = allocation.pool
        name = self.node.get_name()

        if self.is_first_pool_usage:
            pool.codegen_create(self.wrapper, code)

        pool.names_to_del.extend(self.group.names)
        view_expr = allocation.codegen_alloc_from_pool(self.wrapper)
        cache = pool.creation_cache
        if view_expr not in cache:
            # first time this exact expression is produced: declare and remember it
            cache[view_expr] = name
            code.writeline(
                f"{self.wrapper.declare}{name} = {view_expr}{self.wrapper.ending}"
            )
            return

        # the same expression was already emitted under another name: alias it
        code.writeline(self.wrapper.make_tensor_alias(name, cache[view_expr], "alloc"))
587
588
@dataclasses.dataclass
class DeallocFromPoolLine(PoolMemoryPlanningLine):
    """Pool-backed counterpart of FreeIfNotReusedLine"""

    is_last_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        """Write the pool teardown, but only for the pool's last usage."""
        if not self.is_last_pool_usage:
            return
        assert self.group.allocation and self.group.allocation.pool
        owning_pool = self.group.allocation.pool
        owning_pool.codegen_destroy(self.wrapper, code)
599
600
@dataclasses.dataclass
class MemoryPlanner:
    """
    Drives the memory planning passes that run during wrapper codegen.
    """

    wrapper: Any
    pools: AllocationPools = dataclasses.field(default_factory=AllocationPools)
    buffer_groups: Optional[List[BufferGroup]] = None

    def plan(self, lines: List[Any]) -> List[Any]:
        """Run every memory planning pass, in order, on a copy of ``lines``"""
        lines = list(lines)
        self.drop_removed_buffers(lines)
        self.convert_to_pool_lines(lines)
        self.compute_live_ranges(lines)
        self.allocate_groups()
        self.mark_first_last_usage(lines)
        return lines

    def drop_removed_buffers(self, lines):
        """
        Replace any memory planning lines in V.graph.removed_buffers with NullLine
        """
        planning_types = (AllocateLine, FreeIfNotReusedLine, ReuseLine)
        for index, line in enumerate(lines):
            if not isinstance(line, planning_types):
                continue
            if line.node.get_name() in V.graph.removed_buffers:
                lines[index] = NullLine(self.wrapper)

    def compute_buffer_groups(self, lines):
        """
        Populates self.buffer_groups with BufferGroup objects that join
        allocations with common storage (due to inplace reuse) into a
        single object.  Returns the name -> group mapping.
        """
        name_to_group = {}
        for line in lines:
            if isinstance(line, AllocateLine):
                name = line.node.get_name()
                assert name not in name_to_group
                name_to_group[name] = BufferGroup(line.node)
            elif isinstance(line, ReuseLine):
                old_name = line.node.get_name()
                new_name = line.reused_as.get_name()
                assert new_name not in name_to_group
                # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc
                shared = name_to_group.get(old_name)
                if shared is not None:
                    shared.names.append(new_name)
                    name_to_group[new_name] = shared

        outputs = set(V.graph.get_output_names())

        # Several names may map to one group; keep each group once, in
        # first-seen order.
        unique_groups = []
        seen_ids = set()
        for group in name_to_group.values():
            if id(group) in seen_ids:
                continue
            seen_ids.add(id(group))
            unique_groups.append(group)

        for group in unique_groups:
            group.is_output = not outputs.isdisjoint(group.names)

        assert self.buffer_groups is None
        self.buffer_groups = unique_groups
        return name_to_group

    def convert_to_pool_lines(self, lines):
        """
        Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their
        pool-based counterparts.
        """
        name_to_group = self.compute_buffer_groups(lines)
        for index, line in enumerate(lines):
            if isinstance(line, AllocateLine):
                group = name_to_group.get(line.node.get_name())
                if group is not None:
                    lines[index] = AllocFromPoolLine(self.wrapper, group)
            elif isinstance(line, FreeIfNotReusedLine):
                assert not line.is_reused
                group = name_to_group.get(line.node.get_name())
                if group is not None:
                    lines[index] = DeallocFromPoolLine(self.wrapper, group)
            elif isinstance(line, ReuseLine):
                if line.node.get_name() in name_to_group:
                    # the old name is kept, not deleted, for pooled buffers
                    line.delete_old = False

    def compute_live_ranges(self, lines):
        """Populate every BufferGroup.live_range field based on first/last usage"""
        # Every maximal run of consecutive MemoryPlanningLine entries shares
        # one timestep; the counter advances when a new run starts.
        timestep = 0
        inside_run = False
        for line in lines:
            if not isinstance(line, MemoryPlanningLine):
                inside_run = False
                continue
            if not inside_run:
                timestep += 1
                inside_run = True
            if isinstance(line, PoolMemoryPlanningLine):
                line.group.update_usage(timestep)
                line.timestep = timestep

        # Outputs are additionally marked as used one step past the last run.
        timestep += 1
        assert self.buffer_groups is not None
        for group in self.buffer_groups:
            if group.is_output:
                group.update_usage(timestep)

    def allocate_groups(self):
        """
        Assign every allocation to a specific location in a specific AllocationPool.
        """
        assert config.memory_pool in ("none", "intermediates", "outputs", "combined")
        assert self.buffer_groups is not None

        for group in self.buffer_groups:
            group.make_allocation()

        outputs: List[Allocation] = []
        intermediates: List[Allocation] = []
        for group in self.buffer_groups:
            assert group.allocation
            separate_output = group.is_output and config.memory_pool != "combined"
            bucket = outputs if separate_output else intermediates
            bucket.append(group.allocation)

        # outputs: ascending size hint, longer live ranges first among equals
        outputs.sort(key=lambda a: (a.size_hint, -len(a.live_range)))
        for block in outputs:
            self.pools.allocate_output(block)

        # intermediates: descending size hint, longer live ranges first among equals
        intermediates.sort(key=lambda a: (-a.size_hint, -len(a.live_range)))
        for block in intermediates:
            self.pools.allocate(block)

        self.pools.finalize()

    def mark_first_last_usage(self, lines):
        """
        Populate the AllocFromPoolLine.is_first_pool_usage and
        DeallocFromPoolLine.is_last_pool_usage fields so that pools
        are created/destroyed.
        """
        created = set()
        for line in lines:
            if not isinstance(line, AllocFromPoolLine):
                continue
            assert line.group.allocation
            pool = line.group.allocation.pool
            assert pool is not None
            if pool in created:
                continue
            line.is_first_pool_usage = True
            created.add(pool)

        # Walk backwards so the first dealloc met for a pool is its final one.
        destroyed = set()
        for line in reversed(lines):
            if not isinstance(line, DeallocFromPoolLine):
                continue
            assert line.group.allocation
            pool = line.group.allocation.pool
            assert pool is not None
            if pool in destroyed:
                continue
            line.is_last_pool_usage = (
                pool.root.get_live_ranges().end <= line.timestep
            )
            destroyed.add(pool)
771