• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright © 2020 Arm Ltd. All rights reserved.
3# Copyright © 2020 NXP and Contributors. All rights reserved.
4# SPDX-License-Identifier: MIT
5"""Python bindings for Arm NN
6
7PyArmNN is a python extension for Arm NN SDK providing an interface similar to Arm NN C++ API.
8"""
# Placeholders for the package and Arm NN versions; the `_version.py` file that is
# executed further down in this script is expected to provide the real values.
__version__ = __arm_ml_version__ = None
11
12import logging
13import os
14import sys
15import subprocess
16from functools import lru_cache
17from pathlib import Path
18from itertools import chain
19
20from setuptools import setup
21from distutils.core import Extension
22from setuptools.command.build_py import build_py
23from setuptools.command.build_ext import build_ext
24
# Obtain the logger through the logging hierarchy (not by instantiating Logger
# directly) so that handlers and levels configured on ancestor/root loggers apply.
logger = logging.getLogger(__name__)

# The module docstring supplies the package description. Guard against it being
# stripped (e.g. ``python -OO``), which leaves ``__doc__`` as None.
DOCLINES = (__doc__ or "").split("\n")
# Environment variable consulted for the Arm NN library location.
LIB_ENV_NAME = "ARMNN_LIB"
# Environment variable consulted for the Arm NN include directories.
INCLUDE_ENV_NAME = "ARMNN_INCLUDE"
30
31
def check_armnn_version(*args):
    """Placeholder for an Arm NN version compatibility check.

    The extension builder calls it with the version reported by the freshly built
    version wrapper and ``__arm_ml_version__``. It accepts any positional
    arguments and intentionally performs no check.
    """
    return None
34
__current_dir = os.path.dirname(os.path.realpath(__file__))

# Execute the version module in this script's namespace. It is expected to define
# ``__version__`` / ``__arm_ml_version__``, which setup() and the build step read.
# The context manager closes the file deterministically, and compiling with the real
# path makes tracebacks point at _version.py if it is malformed.
_version_file = os.path.join(__current_dir, 'src', 'pyarmnn', '_version.py')
with open(_version_file, encoding="utf-8") as _version_fh:
    exec(compile(_version_fh.read(), _version_file, "exec"))
38
39
class ExtensionPriorityBuilder(build_py):
    """``build_py`` command that triggers ``build_ext`` before anything else.

    The extension build produces generated files; running it first ensures those
    files are present when the Python packages are collected for the distribution.
    """

    def run(self):
        # Build the extensions first, then perform the regular build_py work.
        self.run_command('build_ext')
        outcome = super().run()
        return outcome
47
48
class ArmnnVersionCheckerExtBuilder(build_ext):
    """``build_ext`` command that tolerates failing optional extensions and runs the
    Arm NN version check once the version wrapper has been built.

    Optional extensions that raise while building are logged and recorded in
    ``failed_ext``. ``copy_extensions_to_source`` then skips them.
    """

    def __init__(self, dist):
        super().__init__(dist)
        # Optional extensions whose build raised an exception.
        self.failed_ext = []

    def build_extension(self, ext):
        if not ext.optional:
            super().build_extension(ext)
            if ext.name == 'pyarmnn._generated._pyarmnn_version':
                # Make the directory of the freshly built module importable, then
                # query it for the Arm NN version.
                module_dir = os.path.join(self.build_lib, str(Path(ext._file_name).parent))
                sys.path.append(os.path.abspath(module_dir))
                from _pyarmnn_version import GetVersion
                check_armnn_version(GetVersion(), __arm_ml_version__)
            return

        # Optional extension: a failure is logged and remembered instead of aborting.
        try:
            super().build_extension(ext)
        except Exception as err:
            self.failed_ext.append(ext)
            logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))

    def copy_extensions_to_source(self):
        # Extensions that failed to build have nothing to copy.
        for broken in self.failed_ext:
            self.extensions.remove(broken)
        super().copy_extensions_to_source()
76
77
def linux_gcc_name():
    """Returns the name of the `gcc` compiler. Might happen that we are cross-compiling and the
    compiler has a longer name.

    The compiler named by the ``CC`` environment variable is preferred. If it is
    unset or cannot be run, plain ``gcc`` is tried.

    Args:
        None

    Returns:
        str: Name of the `gcc` compiler or None if no usable compiler was found.
    """
    def _runs(compiler):
        # A Popen object is always truthy, and a missing executable raises OSError.
        # So actually run ``--version``, wait for it, and check the exit status.
        try:
            completed = subprocess.run([compiler, "--version"], stdout=subprocess.DEVNULL,
                                       stderr=subprocess.DEVNULL, check=False)
        except OSError:
            return False
        return completed.returncode == 0

    cc_env = os.getenv('CC')
    if cc_env is not None and _runs(cc_env):
        return cc_env
    return "gcc" if _runs("gcc") else None
93
94
def linux_gcc_lib_search(gcc_compiler_name: str = linux_gcc_name()):
    """Calls the `gcc` to get linker default system paths.

    Runs ``<gcc> --print-search-dirs`` and parses its ``libraries: =dir1:dir2:...`` line.

    Args:
        gcc_compiler_name(str): Name of the GCC compiler

    Returns:
        tuple: Library search directories, or None if the compiler output has no
        ``libraries`` entry.

    Raises:
        RuntimeError: If unable to find or run GCC.
    """
    if gcc_compiler_name is None:
        raise RuntimeError("Unable to find gcc compiler")
    try:
        completed = subprocess.run([gcc_compiler_name, "--print-search-dirs"],
                                   stdout=subprocess.PIPE, check=False)
    except OSError as err:
        raise RuntimeError("Unable to find gcc compiler") from err

    # Parse the output here instead of piping through `grep`. This leaves no child
    # process to reap, needs no extra tool, and strips the trailing newline that
    # would otherwise corrupt the last directory.
    for line in completed.stdout.decode("utf-8").splitlines():
        if not line.lstrip().startswith("libraries"):
            continue
        _, separator, dirs = line.partition("=")
        if separator:
            return tuple(path for path in dirs.strip().split(":") if path)
    return None
116
117
def find_includes(armnn_include_env: str = INCLUDE_ENV_NAME):
    """Searches for ArmNN includes.

    Reads ``armnn_include_env`` from the environment, splits it on ':' and keeps the
    entries that exist on disk. Falls back to the standard system include
    directories when the variable is unset or none of its entries exist.

    Args:
        armnn_include_env(str): Environmental variable to use as path.

    Returns:
        list: A list of paths to include.
    """
    raw_paths = os.getenv(armnn_include_env)
    # Work on a local list; an unset variable simply contributes no candidates
    # (no reliance on module-level state that may never have been assigned).
    candidates = raw_paths.split(":") if raw_paths is not None else []

    # Keep only the paths that actually exist.
    existing = [path for path in candidates if os.path.exists(path)]

    # If none exist, revert to the conventional system include locations.
    return existing or ['/usr/local/include', '/usr/include']
145
146
147
@lru_cache(maxsize=None)
def _default_armnn_lib_search():
    """Return the compiler's default library directories, computed once on first use."""
    return linux_gcc_lib_search()


# The key space is the handful of library patterns requested by the extensions, so an
# unbounded cache is safe. It avoids repeated globbing, and repeated "not found"
# warnings, when the lazily evaluated extension properties alternate between libraries.
@lru_cache(maxsize=None)
def find_armnn(lib_name: str,
               optional: bool = False,
               armnn_libs_env: str = LIB_ENV_NAME,
               default_lib_search: tuple = None):
    """Searches for ArmNN installation on the local machine.

    Args:
        lib_name(str): Lib name (glob pattern) to find.
        optional(bool): Do not fail if optional. Default is False - fail if library was not found.
        armnn_libs_env(str): Custom environment variable pointing to ArmNN libraries location
                             (one directory or several ':'-separated ones), default is 'ARMNN_LIB'.
        default_lib_search(tuple): Directories searched when that environment variable is not set.
                            Defaults to the compiler's library search paths. These are resolved on
                            first use, so no compiler is needed when the variable is set.
    Returns:
        tuple: ``(names, dirs)``: the ':'-prefixed file names of the matched libraries
        and the unique directories that contain them.

    Raises:
        RuntimeError: If armnn libs are not found.
    """
    armnn_lib_path = os.getenv(armnn_libs_env)
    if armnn_lib_path is not None:
        # Same ':'-separated convention that find_includes() applies to ARMNN_INCLUDE.
        lib_search = armnn_lib_path.split(":")
    elif default_lib_search is not None:
        lib_search = default_lib_search
    else:
        lib_search = _default_armnn_lib_search() or ()

    # Map ':<file name>' -> path. A later match with the same file name replaces an earlier one.
    armnn_libs = {}
    for search_dir in lib_search:
        for lib_file in Path(search_dir).glob(lib_name):
            armnn_libs[':{}'.format(lib_file.name)] = lib_file

    if not optional and not armnn_libs:
        raise RuntimeError("""ArmNN library {} was not found in {}. Please install ArmNN to one of the standard
                           locations or set correct ARMNN_INCLUDE and ARMNN_LIB env variables.""".format(lib_name,
                                                                                                         lib_search))
    if optional and not armnn_libs:
        logger.warning("""Optional parser library %s was not found in %s and will not be installed.""", lib_name,
                       lib_search)

    # Names of the libs and the unique directories that contain them.
    return list(armnn_libs.keys()), list({str(path.absolute().parent) for path in armnn_libs.values()})
183
184
class LazyArmnnFinderExtension(Extension):
    """``Extension`` whose ArmNN header and library settings are resolved on access.

    ``include_dirs``, ``library_dirs``, ``runtime_library_dirs`` and ``libraries`` are
    properties. Each read returns the caller-supplied value extended with the results
    of ``find_includes()`` (headers) or ``find_armnn()`` (one lookup per entry of
    ``armnn_libs``). Nothing is searched for until a build step reads them.

    Two instances are equal, and hash alike, when they share a class and a name.
    """

    def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
                 library_dirs=None,
                 libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
                 extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
        # Backing storage for the properties. Extension.__init__ assigns these
        # attributes, which routes the caller's values through the setters below.
        self._include_dirs = None
        self._library_dirs = None
        self._runtime_library_dirs = None
        # File-name patterns handed to find_armnn(), e.g. 'libarmnn.so'.
        self._armnn_libs = armnn_libs
        # An unspecified `optional` is treated as "required" for the ArmNN lookups.
        self._optional = optional if optional is not None else False

        super().__init__(name=name, sources=sources,
                         include_dirs=include_dirs, library_dirs=library_dirs,
                         runtime_library_dirs=runtime_library_dirs, libraries=libraries,
                         define_macros=define_macros, undef_macros=undef_macros,
                         extra_objects=extra_objects, extra_compile_args=extra_compile_args,
                         extra_link_args=extra_link_args, export_symbols=export_symbols,
                         language=language, optional=optional, **kw)

    def _extend_with_armnn(self, base, position):
        """Append element ``position`` of find_armnn() (0: link names, 1: directories)
        for every ArmNN library to ``base`` and return the result."""
        combined = base
        for pattern in self._armnn_libs:
            combined = combined + find_armnn(pattern, self._optional)[position]
        return combined

    @property
    def include_dirs(self):
        base = self._include_dirs
        return base + find_includes()

    @include_dirs.setter
    def include_dirs(self, include_dirs):
        self._include_dirs = include_dirs

    @property
    def library_dirs(self):
        return self._extend_with_armnn(self._library_dirs, 1)

    @library_dirs.setter
    def library_dirs(self, library_dirs):
        self._library_dirs = library_dirs

    @property
    def runtime_library_dirs(self):
        return self._extend_with_armnn(self._runtime_library_dirs, 1)

    @runtime_library_dirs.setter
    def runtime_library_dirs(self, runtime_library_dirs):
        self._runtime_library_dirs = runtime_library_dirs

    @property
    def libraries(self):
        return self._extend_with_armnn(self._libraries, 0)

    @libraries.setter
    def libraries(self, libraries):
        self._libraries = libraries

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)
265
266
if __name__ == '__main__':
    def _wrapper_extension(suffix, armnn_libs, optional):
        """Describe one SWIG-generated wrapper module under ``pyarmnn._generated``.

        ``suffix`` is appended to both the module name (``_pyarmnn<suffix>``) and the
        wrapper source file name (``armnn<suffix>_wrap.cpp``).
        """
        return LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn' + suffix,
                                        sources=['src/pyarmnn/_generated/armnn' + suffix + '_wrap.cpp'],
                                        extra_compile_args=['-std=c++14'],
                                        language='c++',
                                        armnn_libs=armnn_libs,
                                        optional=optional)

    # Mandatory extensions; the version wrapper comes first in the build list.
    core_ext = _wrapper_extension('', ['libarmnn.so'], optional=False)
    version_ext = _wrapper_extension('_version', ['libarmnn.so'], optional=False)
    extensions_to_build = [version_ext, core_ext]

    # Optional parser extensions: ArmnnVersionCheckerExtBuilder tolerates their failure.
    for parser in ('OnnxParser', 'TfLiteParser', 'Deserializer'):
        extensions_to_build.append(
            _wrapper_extension('_' + parser.lower(),
                               ['libarmnn.so', 'libarmnn{}.so'.format(parser)],
                               optional=True))

    setup(
        name='pyarmnn',
        version=__version__,
        author='Arm Ltd, NXP Semiconductors',
        author_email='support@linaro.org',
        description=DOCLINES[0],
        long_description="\n".join(DOCLINES[2:]),
        url='https://mlplatform.org/',
        license='MIT',
        keywords='armnn neural network machine learning',
        classifiers=[
            'Development Status :: 3 - Alpha',
            'Intended Audience :: Developers',
            'Intended Audience :: Education',
            'Intended Audience :: Science/Research',
            'License :: OSI Approved :: MIT License',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3 :: Only',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Topic :: Scientific/Engineering',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Topic :: Software Development',
            'Topic :: Software Development :: Libraries',
            'Topic :: Software Development :: Libraries :: Python Modules',
        ],
        package_dir={'': 'src'},
        packages=[
            'pyarmnn',
            'pyarmnn._generated',
            'pyarmnn._quantization',
            'pyarmnn._tensor',
            'pyarmnn._utilities'
        ],
        data_files=[('', ['LICENSE'])],
        python_requires='>=3.5',
        install_requires=['numpy'],
        cmdclass={
            'build_py': ExtensionPriorityBuilder,
            'build_ext': ArmnnVersionCheckerExtBuilder
        },
        ext_modules=extensions_to_build
    )
347