# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

import re
from typing import List

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.exported_program import ExportedProgram
from torch.library import impl, Library

fallback_op_lib = Library("llama", "DEF")
# Register the fallback operator. It is an identity op used only to mark
# shard boundaries in the graph.
fallback_op_lib.define("fallback(Tensor input) -> Tensor")


@impl(fallback_op_lib, "fallback")
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
    return a


# Register the out variant.
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")


@impl(fallback_op_lib, "fallback.out")
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    out.copy_(a)
    return out


class SplitGraph(ExportPass):
    """
    Pass that splits the model into multiple partitions. Device memory is
    limited, so the whole llama model cannot be loaded from a single pte.
    """

    def __init__(self, shard_layers: List[int]):
        super().__init__()
        self.shard_layers = shard_layers

    def _insert_fallback_op(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Insert a fallback op before each layer that starts a new shard.

        Example:
            For a 12-layer llama model with num_sharding = 3:
            The first partition contains the embedding and layers [0, 4).
            The second partition contains layers [4, 8).
            The third partition contains layers [8, 12) and the output.
        """
        pattern = r"layers.(\d+)"
        prev_node = None
        prev_layer = None
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or "nn_module_stack" not in node.meta:
                continue

            module_values_list = list(node.meta["nn_module_stack"].values())
            full_qualified_name = module_values_list[-1][0]
            # Determine which decoder layer this node belongs to.
            match = re.search(pattern, full_qualified_name)
            if match is None:
                continue

            cur_layer = int(match.group(1))
            # If this node starts a shard layer and the previous node belongs to
            # the preceding layer, prev_node is the last node of that layer, so
            # insert the fallback op after it to mark the shard boundary.
            if cur_layer in self.shard_layers and prev_layer == cur_layer - 1:
                with graph_module.graph.inserting_after(prev_node):
                    users = list(prev_node.users.keys())
                    inserted_node = graph_module.graph.create_node(
                        "call_function",
                        exir_ops.edge.llama.fallback.default,
                        (prev_node,),
                    )
                    inserted_node.meta["val"] = prev_node.meta["val"]
                    if prev_node.meta.get(QCOM_QUANT_ATTRS, None):
                        inserted_node.meta[QCOM_QUANT_ATTRS] = prev_node.meta[
                            QCOM_QUANT_ATTRS
                        ]
                    for user in users:
                        user.replace_input_with(prev_node, inserted_node)

            prev_layer = cur_layer
            prev_node = node

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_fallback_op(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
    graph_module = edge_program.graph_module
    # Shard boundaries fall every num_layers // shares layers, e.g. 12 layers
    # split into 3 shares gives boundaries at layers 0, 4, and 8.
    shard_layers = list(range(0, num_layers, num_layers // shares))
    return SplitGraph(shard_layers)(graph_module)
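

# A minimal usage sketch (not part of the original file): it illustrates how the
# split_graph helper above might be applied to an already-lowered edge program.
# The names llama_edge_program, NUM_LAYERS, and NUM_SHARDS are hypothetical
# placeholders; producing the ExportedProgram for the model is assumed to have
# happened elsewhere in the export pipeline.
#
#   NUM_LAYERS = 12  # decoder layers in the model (assumption)
#   NUM_SHARDS = 3   # desired number of partitions (assumption)
#   # shard_layers becomes [0, 4, 8]; fallback ops are inserted before
#   # layers 4 and 8, marking where the graph can later be cut into ptes.
#   result = split_graph(llama_edge_program, NUM_LAYERS, NUM_SHARDS)
#   sharded_graph_module = result.graph_module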