# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from executorch.devtools import parse_etrecord
from executorch.exir import ExportedProgram
from executorch.exir.backend.backend_api import LoweredBackendModule


def _get_tensor_data(node: torch.fx.Node, tensor: torch.Tensor) -> Dict[str, Any]:
    """Summarize a constant tensor attribute as a json-serializable dict."""
    return {
        "name": node.name,
        "numel": tensor.numel(),
        "dtype": str(tensor.dtype)[6:],  # Remove "torch." prefix
        "element_size": tensor.element_size(),
        "shape": list(tensor.shape),
        "num_bytes": tensor.element_size() * tensor.numel(),
        "nn_module_stack": (
            str(node.meta["nn_module_stack"])
            if "nn_module_stack" in node.meta
            else None
        ),
    }
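# Illustrative example (not from the source): a float32 tensor of shape
# [512, 768] would be reported with numel=393216, element_size=4,
# dtype="float32", and num_bytes=1572864 (393216 * 4).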


def _get_delegate_blob_data(
    node: torch.fx.Node,
    lowered_backend_module: LoweredBackendModule,
    delegate_deserializers: Optional[
        Dict[str, Callable[[bytes], Dict[str, Any]]]
    ] = None,
) -> Dict[str, Any]:
    """Summarize a lowered delegate blob, optionally augmenting it with
    backend-specific fields from a matching deserializer."""
    delegate_blob_data = {
        "name": node.name,
        "backend_id": lowered_backend_module.backend_id,
        "num_bytes": len(lowered_backend_module.processed_bytes),
    }
    if (
        delegate_deserializers is not None
        and lowered_backend_module.backend_id in delegate_deserializers
    ):
        delegate_blob_data.update(
            delegate_deserializers[lowered_backend_module.backend_id](
                lowered_backend_module.processed_bytes
            )
        )

    return delegate_blob_data
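

# Illustrative sketch only: an entry in delegate_deserializers maps a backend_id
# to a callable that turns the delegate's processed bytes into extra
# json-serializable fields. The backend name and fields below are hypothetical
# placeholders, not a real backend payload format.
def _example_delegate_deserializer(processed_bytes: bytes) -> Dict[str, Any]:
    # A real deserializer would parse the backend-specific payload; this one
    # only records its length.
    return {"payload_num_bytes": len(processed_bytes)}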


def _get_nested_model_data(
    graph_module: torch.fx.GraphModule,
    delegate_deserializers: Optional[
        Dict[str, Callable[[bytes], Dict[str, Any]]]
    ] = None,
    tensor_data: Optional[List[Dict[str, Any]]] = None,
    delegate_blob_data: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Recursively walk the get_attr nodes of a graph module (and any nested
    graph modules) to collect tensor and delegate blob size data."""
    if tensor_data is None:
        tensor_data = []

    if delegate_blob_data is None:
        delegate_blob_data = []

    for node in graph_module.graph.nodes:
        if node.op == "get_attr":
            node_attr = getattr(node.graph.owning_module, node.target)
            if isinstance(node_attr, torch.Tensor):
                tensor_data.append(_get_tensor_data(node, node_attr))
            elif isinstance(node_attr, torch.fx.GraphModule):
                _get_nested_model_data(
                    node_attr, delegate_deserializers, tensor_data, delegate_blob_data
                )
            elif isinstance(node_attr, LoweredBackendModule):
                delegate_blob_data.append(
                    _get_delegate_blob_data(node, node_attr, delegate_deserializers)
                )

    return (tensor_data, delegate_blob_data)


def generate_model_size_information(
    model: ExportedProgram,
    delegate_deserializers: Optional[
        Dict[str, Callable[[bytes], Dict[str, Any]]]
    ] = None,
    flatbuffer: Optional[bytes] = None,
) -> Dict[str, Any]:
    """
    Generate a json-serializable Dict containing information about a model's
    size. This includes data about individual tensors and delegate blobs.
    Optionally:
    - delegate_deserializers can be provided to manually specify additional
      information to include for delegate blobs for specific backends.
    - flatbuffer can be provided to include a comparison of total tensor data
      size to overall model size.
    """

    tensor_and_delegate_blob_data = _get_nested_model_data(
        model.graph_module, delegate_deserializers
    )

    for data_list in tensor_and_delegate_blob_data:
        data_list.sort(key=lambda data: data["num_bytes"], reverse=True)

    (tensor_data, delegate_blob_data) = tensor_and_delegate_blob_data

    total_tensor_data_size = sum(data["num_bytes"] for data in tensor_data)
    total_delegate_blob_data_size = sum(
        data["num_bytes"] for data in delegate_blob_data
    )
    overview = {
        "total_tensor_data_size": total_tensor_data_size,
        "total_delegate_blob_data_size": total_delegate_blob_data_size,
    }
    if flatbuffer is not None:
        model_size = len(flatbuffer)
        overview.update(
            {
                "serialization_metadata_size": (
                    model_size - total_tensor_data_size - total_delegate_blob_data_size
                ),
                "model_size": model_size,
            }
        )

    return {
        "tensor_data": tensor_data,
        "delegate_blob_data": delegate_blob_data,
        "overview": overview,
    }


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--etrecord_path",
        required=True,
        help="The path to the ETRecord for the model to generate size information for",
    )

    parser.add_argument(
        "--output_path",
        default="model_size_information.json",
        help="The output path for the model size information as a json file",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    etrecord = parse_etrecord(args.etrecord_path)

    all_model_size_information = [
        generate_model_size_information(
            model=exported_program,
            delegate_deserializers=None,
            flatbuffer=None,
        )
        for (name, exported_program) in etrecord.graph_map.items()
    ]

    with open(args.output_path, "w") as f:
        f.write(json.dumps(all_model_size_information))


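# Example invocation (script name and input path are placeholders):
#   python <this_script>.py --etrecord_path /path/to/etrecord.bin \
#       --output_path model_size_information.json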
if __name__ == "__main__":
    main()