# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig
from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# Get the number of memories indexed from 1..N, compatible with EXIR's spec.mem_id
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# The memory_space module provides num_memories indexed 0..num_memories-1.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in topological order.
    """
    # Collect the specs from all the nodes in the graph module, and return them
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )


# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available.
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs in sorted order (largest first)
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes
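

# A worked sketch of the placement above, using made-up sizes for illustration
# (not part of the pass itself). Assume a hypothetical two-bank MemoryConfig with
# memory_sizes=[1024, 4096], so get_num_memories returns 3 and mem_id 0 stays
# reserved. Three 512-byte specs with overlapping lifetimes are visited
# largest-first:
#
#   spec A -> mem_id 1, offset 0      (bank 1 usage: 512)
#   spec B -> mem_id 1, offset 512    (bumped past A; bank 1 usage: 1024)
#   spec C -> offset 1024 in mem_id 1 exceeds the 1024-byte bank, so it spills
#             to mem_id 2, offset 0   (bank 2 usage: 512)
#
# The returned bufsizes would then be [0, 1024, 512].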


# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers = [[] for _ in range(num_memories)]

    # Iterate over all the specs in sorted order (largest first)
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that, unlike the paper, which updates prev_offset for all
                    # allocated tensors, we only update it for tensors with an
                    # overlapping lifetime. Updating prev_offset outside this if
                    # statement would include tensors without overlapping lifetimes,
                    # wasting memory and making the gap calculation incorrect. Moving
                    # the update out of the loop entirely would degenerate the
                    # algorithm to the naive one that reuses no tensor memory. The
                    # paper may have a typo here.
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            # Keep the allocated tensors ordered by offset; this is the data
            # structure named ordered_allocated_ids in the paper.
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes
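

# A worked sketch of the gap search above, with illustrative numbers only. Suppose
# mem_id 1 already holds two specs whose lifetimes overlap the new one, occupying
# offsets [0, 100) and [160, 260), and a new 50-byte spec is being placed:
#
#   - after the first spec, prev_offset becomes 100;
#   - the second spec leaves a gap of 160 - 100 = 60 bytes, which fits 50 bytes and
#     is the smallest such gap, so the new spec lands at offset 100;
#   - a spec whose lifetime does not overlap is skipped entirely, so it neither
#     blocks the gap nor advances prev_offset.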


def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usage of each
    memory in the memory hierarchy.
    """
    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}.
    # Use a defaultdict here because we don't know how many unique mem_ids the
    # memory plan uses.
    usages = collections.defaultdict(int)

    # Go through all the nodes in the graph and collect the memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert the usages dictionary into a list of length max_mem_id, where index i
    # holds the usage of mem_id i + 1.
    # Ex: {1: 20, 3: 30} -> [20, 0, 30]
    #                        ^   ^  ^
    #                        |   |  |_ mem_id 3
    #                        |   |_ mem_id 2
    #                        |_ mem_id 1
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
        1. >= min(find_peak_memory_usages_per_memory(graph_module))
        2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # Memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # The lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # Accumulate the bytes allocated/deallocated to get the memory usage over time
    memory_usages = list(itertools.accumulate(byte_allocated))

    # Find the peak memory usage and the node index at which it occurs
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx
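

# A worked sketch of the sweep above, with made-up sizes and lifetimes. In a 4-node
# graph, a 100-byte spec live over nodes [0, 2] and a 40-byte spec live over nodes
# [2, 3] produce the delta array [100, 0, 40, -100, -40]; itertools.accumulate turns
# that into the usage curve [100, 100, 140, 40, 0], so the reported peak is 140 bytes
# at node index 2.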


# Print two tables with relevant memory planning information.
#
# Per Memory Space Usage Table:
# +------------------+----------------+-----------------------+-----------------------------+
# | Memory Space     | Base Address   |   Memory Size (Bytes) |   Peak Memory Usage (Bytes) |
# +==================+================+=======================+=============================+
# | MEMORY SPACE A   | 0x57be0000     |                 65213 |                       64544 |
# | MEMORY SPACE B   | 0x57bf0000     |                 65521 |                       36864 |
# | MEMORY SPACE ... | ...            |                   ... |                         ... |
# +------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces and their base addresses, total memory sizes,
    # and peak memory usage
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with the total peak memory usage and the node at which it occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    logging.info(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )


class CadenceMemoryPlanning:
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)
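

# A minimal usage sketch, left as a comment so this module keeps no import-time side
# effects. The MemoryConfig keyword arguments and the `executorch_prog` /
# `exported_program` variables below are illustrative assumptions, not values defined
# in this file:
#
#   memory_config = MemoryConfig(
#       memory_sizes=[65213, 65521],                       # one entry per memory space
#       memory_names=["MEMORY SPACE A", "MEMORY SPACE B"],
#       base_addrs=[0x57BE0000, 0x57BF0000],
#   )
#   # Index 1 selects greedy_by_size_for_offset_calculation_with_hierarchy
#   # (see CadenceMemoryPlanning._init_mem_algos).
#   planning = CadenceMemoryPlanning(memory_config, mem_algo=1)
#   planning(exported_program.graph_module)  # runs MemoryPlanningPass in place
#   print_memory_planning_info(
#       executorch_prog,
#       memory_config,
#       alloc_graph_input=True,
#       alloc_graph_output=True,
#   )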