• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) MediaTek Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9
10import torch
11from executorch.backends.mediatek import Precision
12from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
13    build_executorch_binary,
14)
15from executorch.examples.models.inception_v4 import InceptionV4Model
16
17
class NhwcWrappedModel(torch.nn.Module):
    """Adapter that lets the InceptionV4 eager model consume NHWC input.

    The wrapped network is fed a channels-first tensor, so this module
    permutes its 4-D input from (N, H, W, C) to (N, C, H, W) and then
    delegates to it unchanged.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``inception`` determines the submodule /
        # state_dict key of the wrapped network; keep it stable.
        self.inception = InceptionV4Model().get_eager_model()

    def forward(self, input1):
        """Reorder ``input1`` from NHWC to NCHW and run the wrapped model."""
        return self.inception(input1.permute(0, 3, 1, 2))
27
28
def get_dataset(dataset_path, data_size):
    """Load up to ``data_size`` shuffled ImageNet-style samples as NHWC tensors.

    Images are read with torchvision's ``ImageFolder`` from ``dataset_path``,
    resized to 299x299, converted to tensors, normalized with the given
    per-channel mean/std, and finally permuted to channels-last (NHWC).

    Args:
        dataset_path: Root folder in the layout expected by torchvision's
            ``ImageFolder``.
        data_size: Maximum number of samples to collect. Values <= 0 yield
            no samples.

    Returns:
        A tuple ``(inputs, targets, input_list)``:
        ``inputs`` is a list of 1-tuples, each holding one NHWC feature tensor;
        ``targets`` is the matching list of target tensors yielded by the
        DataLoader; ``input_list`` is a string with one newline-terminated
        ``input_<index>_0.bin`` file name per collected sample.
    """
    from itertools import islice

    from torchvision import datasets, transforms

    def get_data_loader():
        preprocess = transforms.Compose(
            [
                transforms.Resize((299, 299)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
        # shuffle=True: a different random subset is drawn on each run unless
        # torch's RNG is seeded by the caller.
        return torch.utils.data.DataLoader(
            imagenet_data,
            shuffle=True,
        )

    # prepare input data
    inputs, targets, file_names = [], [], []
    data_loader = get_data_loader()
    # islice stops before requesting another batch, so no extra image is
    # decoded and transformed only to be thrown away (the previous
    # enumerate-then-break loop always fetched one sample too many).
    for index, (feature, target) in enumerate(
        islice(data_loader, max(data_size, 0))
    ):
        feature = feature.permute(0, 2, 3, 1)  # NCHW -> NHWC
        inputs.append((feature,))
        targets.append(target)
        file_names.append(f"input_{index}_0.bin\n")

    # Join once instead of growing a string inside the loop.
    return inputs, targets, "".join(file_names)
61
62
if __name__ == "__main__":
    # Script entry point: dump NHWC input tensors, golden labels and an input
    # list for on-device inference, then export the quantized InceptionV4 .pte.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "(for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./inceptionV4",
        default="./inceptionV4",
        type=str,
    )

    args = parser.parse_args()

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    # Number of samples to export; the same inputs are also handed to
    # build_executorch_binary below (presumably as quantization calibration
    # data -- confirm against that helper).
    data_num = 100
    inputs, targets, input_list = get_dataset(
        dataset_path=args.dataset,
        data_size=data_num,
    )

    # save data to inference on device
    input_list_file = os.path.join(args.artifact, "input_list.txt")
    with open(input_list_file, "w", encoding="utf-8") as f:
        f.write(input_list)  # closing the file flushes it
    for idx, data in enumerate(inputs):
        for i, d in enumerate(data):
            # Raw tensor bytes with no header; names match input_list entries.
            file_name = os.path.join(args.artifact, f"input_{idx}_{i}.bin")
            d.detach().numpy().tofile(file_name)
    for idx, data in enumerate(targets):
        file_name = os.path.join(args.artifact, f"golden_{idx}_0.bin")
        data.detach().numpy().tofile(file_name)

    # build pte
    pte_filename = "inceptionV4_mtk"
    instance = NhwcWrappedModel()
    build_executorch_binary(
        instance.eval(),
        # Example input in NHWC layout, matching the tensors from get_dataset.
        (torch.randn(1, 299, 299, 3),),
        os.path.join(args.artifact, pte_filename),
        inputs,
        quant_dtype=Precision.A8W8,
    )
121