• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import confu
10from confu import arm, x86
11
12
# Module-level argument parser obtained from confu.standard_parser(); main()
# uses it to parse the command-line arguments it is given.
parser = confu.standard_parser()
14
15
def main(args):
    """Describe the library, unit-test and benchmark targets on a confu build.

    Args:
        args: Sequence of command-line argument strings, handed unchanged to
            the module-level ``parser``.

    Returns:
        The ``confu.Build`` object created from the parsed options, with the
        "qnnpack" static library, the unit tests and the benchmarks
        registered on it. The caller is expected to call ``generate()`` on it.
    """
    build = confu.Build.from_options(parser.parse_args(args))

    def compile_c(*sources):
        # Run build.cc on every path, in the order given, and collect the
        # results. Meant to be called inside a `with build.options(...)` block.
        return [build.cc(source) for source in sources]

    # Target predicates that several sections below share.
    on_arm = build.target.is_arm
    on_arm64 = build.target.is_arm64
    arm_family = on_arm or on_arm64
    x86_family = build.target.is_x86 or build.target.is_x86_64

    # ISA option for the psimd/NEON sections: arm.neon only when `is_arm` is
    # set (the same flag that selects the aarch32 assembly below), else None.
    neon_or_default = arm.neon if on_arm else None

    build.export_cpath("include", ["q8gemm.h"])

    with build.options(
        source_dir="src",
        deps=[
            build.deps.cpuinfo,
            build.deps.clog,
            build.deps.psimd,
            build.deps.fxdiv,
            build.deps.pthreadpool,
            build.deps.FP16,
        ],
        extra_include_dirs="src",
    ):
        # Requantization kernels: not part of the static library, only linked
        # into the requantization unit test and benchmark further down.
        requant_objs = compile_c(
            "requantization/precise-scalar.c",
            "requantization/fp32-scalar.c",
            "requantization/q31-scalar.c",
            "requantization/gemmlowp-scalar.c",
        )
        with build.options(isa=neon_or_default):
            requant_objs.extend(
                compile_c(
                    "requantization/precise-psimd.c",
                    "requantization/fp32-psimd.c",
                )
            )
        if x86_family:
            with build.options(isa=x86.sse2):
                requant_objs.extend(
                    compile_c(
                        "requantization/precise-sse2.c",
                        "requantization/fp32-sse2.c",
                        "requantization/q31-sse2.c",
                        "requantization/gemmlowp-sse2.c",
                    )
                )
            with build.options(isa=x86.ssse3):
                requant_objs.extend(
                    compile_c(
                        "requantization/precise-ssse3.c",
                        "requantization/q31-ssse3.c",
                        "requantization/gemmlowp-ssse3.c",
                    )
                )
            with build.options(isa=x86.sse4_1):
                requant_objs.extend(
                    compile_c(
                        "requantization/precise-sse4.c",
                        "requantization/q31-sse4.c",
                        "requantization/gemmlowp-sse4.c",
                    )
                )
        if arm_family:
            with build.options(isa=neon_or_default):
                requant_objs.extend(
                    compile_c(
                        "requantization/precise-neon.c",
                        "requantization/fp32-neon.c",
                        "requantization/q31-neon.c",
                        "requantization/gemmlowp-neon.c",
                    )
                )

        # Objects that make up the "qnnpack" static library.
        library_objs = compile_c(
            # Common parts
            "init.c",
            "operator-delete.c",
            "operator-run.c",
            # Operators
            "add.c",
            "average-pooling.c",
            "channel-shuffle.c",
            "clamp.c",
            "convolution.c",
            "indirection.c",
            "deconvolution.c",
            "fully-connected.c",
            "global-average-pooling.c",
            "hardsigmoid.c",
            "hardswish.c",
            "leaky-relu.c",
            "max-pooling.c",
            "sigmoid.c",
            "softargmax.c",
            "tanh.c",
            # Scalar micro-kernels
            "u8lut32norm/scalar.c",
            "x8lut/scalar.c",
        )

        # psimd micro-kernels are compiled for every target.
        with build.options(isa=neon_or_default):
            library_objs.extend(
                compile_c(
                    "sconv/6x8-psimd.c",
                    "sdwconv/up4x9-psimd.c",
                    "sgemm/6x8-psimd.c",
                )
            )

        with build.options(isa=neon_or_default):
            if arm_family:
                # NEON micro-kernels
                library_objs.extend(
                    compile_c(
                        "q8avgpool/mp8x9p8q-neon.c",
                        "q8avgpool/up8x9-neon.c",
                        "q8avgpool/up8xm-neon.c",
                        "q8conv/4x8-neon.c",
                        "q8conv/8x8-neon.c",
                        "q8dwconv/mp8x25-neon.c",
                        "q8dwconv/mp8x27-neon.c",
                        "q8dwconv/up8x9-neon.c",
                        "q8gavgpool/mp8x7p7q-neon.c",
                        "q8gavgpool/up8x7-neon.c",
                        "q8gavgpool/up8xm-neon.c",
                        "q8gemm/4x-sumrows-neon.c",
                        "q8gemm/4x8-neon.c",
                        "q8gemm/4x8c2-xzp-neon.c",
                        "q8gemm/6x4-neon.c",
                        "q8gemm/8x8-neon.c",
                        "q8vadd/neon.c",
                        "sgemm/5x8-neon.c",
                        "sgemm/6x8-neon.c",
                        "u8clamp/neon.c",
                        "u8maxpool/16x9p8q-neon.c",
                        "u8maxpool/sub16-neon.c",
                        "u8rmax/neon.c",
                        "x8zip/x2-neon.c",
                        "x8zip/x3-neon.c",
                        "x8zip/x4-neon.c",
                        "x8zip/xm-neon.c",
                    )
                )
            if on_arm:
                # aarch32 assembly micro-kernels
                library_objs.extend(
                    compile_c(
                        "hgemm/8x8-aarch32-neonfp16arith.S",
                        "q8conv/4x8-aarch32-neon.S",
                        "q8dwconv/up8x9-aarch32-neon.S",
                        "q8gemm/4x8-aarch32-neon.S",
                        "q8gemm/4x8c2-xzp-aarch32-neon.S",
                    )
                )
            if on_arm64:
                # aarch64 assembly micro-kernels
                library_objs.extend(
                    compile_c(
                        "q8gemm/8x8-aarch64-neon.S",
                        "q8conv/8x8-aarch64-neon.S",
                    )
                )
            if x86_family:
                # SSE2 micro-kernels
                with build.options(isa=x86.sse2):
                    library_objs.extend(
                        compile_c(
                            "q8avgpool/mp8x9p8q-sse2.c",
                            "q8avgpool/up8x9-sse2.c",
                            "q8avgpool/up8xm-sse2.c",
                            "q8conv/4x4c2-sse2.c",
                            "q8dwconv/mp8x25-sse2.c",
                            "q8dwconv/mp8x27-sse2.c",
                            "q8dwconv/up8x9-sse2.c",
                            "q8gavgpool/mp8x7p7q-sse2.c",
                            "q8gavgpool/up8x7-sse2.c",
                            "q8gavgpool/up8xm-sse2.c",
                            "q8gemm/2x4c8-sse2.c",
                            "q8gemm/4x4c2-sse2.c",
                            "q8vadd/sse2.c",
                            "u8clamp/sse2.c",
                            "u8maxpool/16x9p8q-sse2.c",
                            "u8maxpool/sub16-sse2.c",
                            "u8rmax/sse2.c",
                            "x8zip/x2-sse2.c",
                            "x8zip/x3-sse2.c",
                            "x8zip/x4-sse2.c",
                            "x8zip/xm-sse2.c",
                        )
                    )
            build.static_library("qnnpack", library_objs)

    # (target name, source file) for every unit test built from one source.
    unit_tests = (
        # Micro-kernel tests
        ("hgemm-test", "hgemm.cc"),
        ("q8avgpool-test", "q8avgpool.cc"),
        ("q8conv-test", "q8conv.cc"),
        ("q8dwconv-test", "q8dwconv.cc"),
        ("q8gavgpool-test", "q8gavgpool.cc"),
        ("q8gemm-test", "q8gemm.cc"),
        ("q8vadd-test", "q8vadd.cc"),
        ("sconv-test", "sconv.cc"),
        ("sgemm-test", "sgemm.cc"),
        ("u8clamp-test", "u8clamp.cc"),
        ("u8lut32norm-test", "u8lut32norm.cc"),
        ("u8maxpool-test", "u8maxpool.cc"),
        ("u8rmax-test", "u8rmax.cc"),
        ("x8lut-test", "x8lut.cc"),
        ("x8zip-test", "x8zip.cc"),
        # Operator tests
        ("add-test", "add.cc"),
        ("average-pooling-test", "average-pooling.cc"),
        ("channel-shuffle-test", "channel-shuffle.cc"),
        ("clamp-test", "clamp.cc"),
        ("convolution-test", "convolution.cc"),
        ("deconvolution-test", "deconvolution.cc"),
        ("fully-connected-test", "fully-connected.cc"),
        ("global-average-pooling-test", "global-average-pooling.cc"),
        ("leaky-relu-test", "leaky-relu.cc"),
        ("max-pooling-test", "max-pooling.cc"),
        ("sigmoid-test", "sigmoid.cc"),
        ("softargmax-test", "softargmax.cc"),
        ("tanh-test", "tanh.cc"),
        ("hardsigmoid-test", "hardsigmoid.cc"),
        ("hardswish-test", "hardswish.cc"),
    )

    with build.options(
        source_dir="test",
        deps={
            # The value `any` is Python's builtin, exactly as in the original
            # mapping; "log" is keyed on the Android check.
            (
                build,
                build.deps.cpuinfo,
                build.deps.clog,
                build.deps.pthreadpool,
                build.deps.FP16,
                build.deps.googletest,
            ): any,
            "log": build.target.is_android,
        },
        extra_include_dirs=["src", "test"],
    ):
        for test_name, test_source in unit_tests:
            build.unittest(test_name, build.cxx(test_source))
        # The requantization test links the kernels compiled above directly.
        build.unittest(
            "requantization-test",
            [build.cxx("requantization.cc")] + requant_objs,
        )

    # NOTE(review): unlike the kernel sections above, this only checks
    # `is_x86` and not `is_x86_64`, so any other target gets isa=None here.
    # Confirm that is intended.
    if on_arm:
        benchmark_isa = arm.neon
    elif build.target.is_x86:
        benchmark_isa = x86.sse4_1
    else:
        benchmark_isa = None

    # (target name, source file) for every benchmark built from one source.
    benchmarks = (
        # Operator benchmarks
        ("add-bench", "add.cc"),
        ("average-pooling-bench", "average-pooling.cc"),
        ("channel-shuffle-bench", "channel-shuffle.cc"),
        ("convolution-bench", "convolution.cc"),
        ("global-average-pooling-bench", "global-average-pooling.cc"),
        ("max-pooling-bench", "max-pooling.cc"),
        ("sigmoid-bench", "sigmoid.cc"),
        ("softargmax-bench", "softargmax.cc"),
        ("tanh-bench", "tanh.cc"),
        ("hardsigmoid-bench", "hardsigmoid.cc"),
        ("hardswish-bench", "hardswish.cc"),
        # Micro-kernel benchmarks
        ("q8gemm-bench", "q8gemm.cc"),
        ("hgemm-bench", "hgemm.cc"),
        ("sgemm-bench", "sgemm.cc"),
    )

    with build.options(
        source_dir="bench",
        deps={
            (
                build,
                build.deps.cpuinfo,
                build.deps.clog,
                build.deps.pthreadpool,
                build.deps.FP16,
                build.deps.googlebenchmark,
            ): any,
            "log": build.target.is_android,
        },
        isa=benchmark_isa,
        extra_include_dirs="src",
    ):
        for bench_name, bench_source in benchmarks:
            build.benchmark(bench_name, build.cxx(bench_source))
        build.benchmark(
            "requantization-bench",
            [build.cxx("requantization.cc")] + requant_objs,
        )

    return build
276
277
if __name__ == "__main__":
    import sys

    # Build the configuration from the command-line arguments (program name
    # dropped), then call generate() on the build object main() returns.
    configured_build = main(sys.argv[1:])
    configured_build.generate()
282