# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Export InceptionV4 as an A8W8 MediaTek ExecuTorch program.

The script samples images from an ImageFolder-style dataset and writes them,
together with their labels, to the artifact directory for on-device runs.
It then builds the ``.pte`` program from an NHWC-input wrapper of the model.
"""

import argparse
import itertools
import os

import torch
from executorch.backends.mediatek import Precision
from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
    build_executorch_binary,
)
from executorch.examples.models.inception_v4 import InceptionV4Model

# Spatial size images are resized to. The same value is used for the example
# input handed to ``build_executorch_binary`` so the two cannot drift apart.
_IMAGE_SIZE = 299


class NhwcWrappedModel(torch.nn.Module):
    """Wrap the eager InceptionV4 model so it accepts channels-last input.

    ``forward`` permutes a 4-D ``(N, H, W, C)`` tensor to ``(N, C, H, W)``
    before calling the wrapped model. This lets the exported program consume
    the NHWC tensors produced by :func:`get_dataset`.
    """

    def __init__(self):
        super().__init__()
        self.inception = InceptionV4Model().get_eager_model()

    def forward(self, input1):
        # NHWC -> NCHW, the layout the wrapped eager model is called with.
        nchw_input1 = input1.permute(0, 3, 1, 2)
        return self.inception(nchw_input1)


def get_dataset(dataset_path, data_size):
    """Load up to ``data_size`` preprocessed images from an ImageFolder tree.

    Each image is resized to 299x299, converted with ``ToTensor``, normalized
    with the mean/std below, and permuted to NHWC. The loader shuffles, so a
    different subset/order may be drawn on every run.

    Args:
        dataset_path: Root directory in ``torchvision.datasets.ImageFolder``
            layout.
        data_size: Maximum number of samples to take. Values <= 0 yield no
            samples.

    Returns:
        A tuple ``(inputs, targets, input_list)``:

        - ``inputs``: a list of 1-tuples, each holding one NHWC image tensor
          (expected ``(1, 299, 299, 3)``, matching the export example input).
        - ``targets``: the matching label tensors.
        - ``input_list``: the ``input_{index}_0.bin`` file names, one per
          sample, each terminated by a newline.
    """
    from torchvision import datasets, transforms

    preprocess = transforms.Compose(
        [
            transforms.Resize((_IMAGE_SIZE, _IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
    # DataLoader's default batch_size is 1, so every batch is a single image.
    data_loader = torch.utils.data.DataLoader(imagenet_data, shuffle=True)

    inputs, targets, file_names = [], [], []
    # islice rejects negative stops; clamping keeps "data_size <= 0 -> no
    # samples" instead of raising.
    samples = itertools.islice(data_loader, max(data_size, 0))
    for index, (feature, target) in enumerate(samples):
        inputs.append((feature.permute(0, 2, 3, 1),))  # NCHW -> NHWC
        targets.append(target)
        file_names.append(f"input_{index}_0.bin\n")

    return inputs, targets, "".join(file_names)


def _parse_args():
    """Parse the command-line options of this example."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset "
            "(e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./inceptionV4",
        default="./inceptionV4",
        type=str,
    )

    parser.add_argument(
        "--num_data",
        help="number of dataset samples to dump and pass to the export. "
        "Default 100",
        default=100,
        type=int,
    )

    return parser.parse_args()


def _dump_data(artifact_dir, inputs, targets, input_list):
    """Write the input list, raw input tensors and golden labels to disk."""
    with open(f"{artifact_dir}/input_list.txt", "w", encoding="utf-8") as f:
        f.write(input_list)
    for idx, data in enumerate(inputs):
        for i, d in enumerate(data):
            d.detach().numpy().tofile(f"{artifact_dir}/input_{idx}_{i}.bin")
    for idx, target in enumerate(targets):
        target.detach().numpy().tofile(f"{artifact_dir}/golden_{idx}_0.bin")


def main():
    """Sample the dataset, dump it for on-device inference, and build the pte."""
    args = _parse_args()

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    inputs, targets, input_list = get_dataset(
        dataset_path=args.dataset,
        data_size=args.num_data,
    )

    # save data to inference on device
    _dump_data(args.artifact, inputs, targets, input_list)

    # build pte
    pte_filename = "inceptionV4_mtk"
    instance = NhwcWrappedModel()
    build_executorch_binary(
        instance.eval(),
        (torch.randn(1, _IMAGE_SIZE, _IMAGE_SIZE, 3),),
        f"{args.artifact}/{pte_filename}",
        # NOTE(review): presumably used as quantization calibration data for
        # A8W8 — confirm against build_executorch_binary.
        inputs,
        quant_dtype=Precision.A8W8,
    )


if __name__ == "__main__":
    main()