# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to an ExecuTorch flatbuffer
# (.pte) program.

import argparse
import logging

import torch

from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.extension.export_util.utils import (
    export_to_edge,
    export_to_exec_prog,
    save_pte_program,
)

from ...models import MODEL_NAME_TO_MODEL
from ...models.model_factory import EagerModelFactory


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _parse_hex(value: str) -> int:
    """Parse a hexadecimal command-line value such as ``0x4000`` or ``4000``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not valid hexadecimal, so
            argparse reports a usage error instead of a raw traceback.
    """
    try:
        # int(..., 16) accepts an optional "0x"/"0X" prefix.
        return int(value, 16)
    except ValueError as err:
        raise argparse.ArgumentTypeError(f"invalid hex value: {value!r}") from err


def main() -> None:
    """Export a registered example model to an ExecuTorch program and save it.

    Command-line arguments:
        -m/--model_name: key into ``MODEL_NAME_TO_MODEL`` (required).
        -s/--strict, --no-strict: forwarded as ``strict`` to the export
            helpers; defaults to True.
        -a/--segment_alignment: hex value assigned to
            ``ExecutorchBackendConfig.segment_alignment`` when given.
        -o/--output_dir: directory passed to ``save_pte_program``
            (default: current directory).

    Raises:
        RuntimeError: if ``--model_name`` is not a key of
            ``MODEL_NAME_TO_MODEL``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )
    parser.add_argument(
        "-s",
        "--strict",
        action=argparse.BooleanOptionalAction,
        # Without an explicit default, argparse yields None when neither
        # --strict nor --no-strict is given, and that None was forwarded as
        # strict=None even though the help text promises True.
        default=True,
        help="whether to export with strict mode. Default is True",
    )
    parser.add_argument(
        "-a",
        "--segment_alignment",
        required=False,
        # Validate at parse time so bad input yields a usage error.
        type=_parse_hex,
        help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )

    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    backend_config = ExecutorchBackendConfig()
    if args.segment_alignment is not None:
        # Already converted from hex text to int by _parse_hex.
        backend_config.segment_alignment = args.segment_alignment

    if dynamic_shapes is not None:
        # Models with dynamic shapes are lowered through export_to_edge with
        # IR validity checking disabled, then converted to an ExecuTorch
        # program.
        # NOTE(review): presumably the validity check rejects graphs from
        # dynamic-shape export — confirm before re-enabling it.
        edge_manager = export_to_edge(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
            strict=args.strict,
        )
        prog = edge_manager.to_executorch(config=backend_config)
    else:
        prog = export_to_exec_prog(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            backend_config=backend_config,
            strict=args.strict,
        )

    save_pte_program(prog, args.model_name, args.output_dir)


if __name__ == "__main__":
    with torch.no_grad():
        main()  # pragma: no cover