#!/usr/bin/env python3
# Copyright © 2020 Arm Ltd. All rights reserved.
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Python bindings for Arm NN

PyArmNN is a python extension for Arm NN SDK providing an interface similar to Arm NN C++ API.
"""
__version__ = None
__arm_ml_version__ = None

import logging
import os
import sys
import subprocess
from functools import lru_cache
from pathlib import Path
from itertools import chain

# setuptools.Extension replaces the deprecated distutils.core.Extension (distutils is
# gone from the stdlib in Python 3.12); setuptools already patches the latter to it.
from setuptools import setup, Extension
from setuptools.command.build_py import build_py
from setuptools.command.build_ext import build_ext

# getLogger attaches to the standard logger hierarchy, unlike instantiating
# logging.Logger directly, which creates an orphan logger.
logger = logging.getLogger(__name__)

DOCLINES = __doc__.split("\n")
LIB_ENV_NAME = "ARMNN_LIB"
INCLUDE_ENV_NAME = "ARMNN_INCLUDE"

# C++ compiler flags shared by every generated wrapper extension.
CXX_COMPILE_ARGS = ['-std=c++14']

# Sentinel that distinguishes "argument omitted" from an explicit None.
_UNSET = object()


def check_armnn_version(*args):
    """Placeholder version check (no-op).

    Presumably overridden by the definition executed from ``src/pyarmnn/_version.py``
    below — TODO confirm; if that file does not define it, the version check does nothing.
    """
    pass


__current_dir = os.path.dirname(os.path.realpath(__file__))

# _version.py is executed in this module's namespace; it is expected to set __version__
# and __arm_ml_version__ (both used by setup()/the version check below).
with open(os.path.join(__current_dir, 'src', 'pyarmnn', '_version.py'), encoding="utf-8") as _version_file:
    exec(_version_file.read())


class ExtensionPriorityBuilder(build_py):
    """Runs extension builder before other stages. Otherwise generated files are not included to the distribution.
    """

    def run(self):
        self.run_command('build_ext')
        return super().run()


class ArmnnVersionCheckerExtBuilder(build_ext):
    """Builds an extension (i.e. wrapper). Additionally checks for version.

    Optional extensions that fail to build are logged and recorded in ``failed_ext`` so
    they are skipped when built extensions are copied back into the source tree.
    """

    def __init__(self, dist):
        super().__init__(dist)
        self.failed_ext = []

    def build_extension(self, ext):
        if ext.optional:
            try:
                super().build_extension(ext)
            except Exception as err:
                # Deliberate best-effort: a missing optional parser must not abort the build.
                self.failed_ext.append(ext)
                logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))
        else:
            super().build_extension(ext)
            if ext.name == 'pyarmnn._generated._pyarmnn_version':
                # Import the freshly built version module straight from the build dir and
                # compare the ArmNN version it reports with the one this package expects.
                sys.path.append(os.path.abspath(os.path.join(self.build_lib, str(Path(ext._file_name).parent))))
                from _pyarmnn_version import GetVersion
                check_armnn_version(GetVersion(), __arm_ml_version__)

    def copy_extensions_to_source(self):
        # Extensions that failed to build have no output file to copy.
        for ext in self.failed_ext:
            self.extensions.remove(ext)
        super().copy_extensions_to_source()


def _compiler_runs(compiler: str) -> bool:
    """Return True if ``<compiler> --version`` can be executed and exits with status 0."""
    try:
        completed = subprocess.run([compiler, "--version"],
                                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # e.g. FileNotFoundError / PermissionError: the executable is not usable.
        return False
    return completed.returncode == 0


def linux_gcc_name():
    """Returns the name of the `gcc` compiler. Might happen that we are cross-compiling and the
    compiler has a longer name.

    The ``CC`` environment variable is preferred when it names a runnable compiler;
    otherwise plain ``gcc`` is tried.

    Args:
        None

    Returns:
        str: Name of the `gcc` compiler or None
    """
    cc_env = os.getenv('CC')
    if cc_env is not None and _compiler_runs(cc_env):
        return cc_env
    return "gcc" if _compiler_runs("gcc") else None


def linux_gcc_lib_search(gcc_compiler_name=_UNSET):
    """Calls the `gcc` to get linker default system paths.

    Args:
        gcc_compiler_name(str): Name of the GCC compiler. When omitted, it is resolved
            with ``linux_gcc_name()`` at call time.

    Returns:
        tuple: The library search paths reported by ``--print-search-dirs``
        (empty if the output contains no ``libraries: =...`` line).

    Raises:
        RuntimeError: If unable to find GCC.
    """
    if gcc_compiler_name is _UNSET:
        gcc_compiler_name = linux_gcc_name()
    if gcc_compiler_name is None:
        raise RuntimeError("Unable to find gcc compiler")
    completed = subprocess.run([gcc_compiler_name, "--print-search-dirs"], stdout=subprocess.PIPE)
    for line in completed.stdout.decode("utf-8").splitlines():
        if "libraries" not in line:
            continue
        # Expected form: "libraries: =/path/one:/path/two"; splitlines() already
        # dropped the trailing newline that would otherwise corrupt the last path.
        _, separator, paths = line.partition('=')
        if separator:
            return tuple(path for path in paths.split(':') if path)
    return ()


@lru_cache(maxsize=1)
def _default_lib_search():
    """Return the gcc default library search paths, computed once on first use."""
    return linux_gcc_lib_search()


def find_includes(armnn_include_env: str = INCLUDE_ENV_NAME):
    """Searches for ArmNN includes.

    Args:
        armnn_include_env(str): Environmental variable to use as path. It may contain
            several directories separated by ':'.

    Returns:
        list: The existing directories listed in the variable, or
        ['/usr/local/include', '/usr/include'] if it is unset or none of them exist.
    """
    armnn_include_path_raw = os.getenv(armnn_include_env)
    candidates = armnn_include_path_raw.split(":") if armnn_include_path_raw is not None else []

    # keep only paths that exist ('' never exists)
    armnn_include_path_result = [path for path in candidates if path and os.path.exists(path)]

    # if none exist revert to default
    return armnn_include_path_result or ['/usr/local/include', '/usr/include']


# Unbounded is safe: the key space is the handful of libraries requested in __main__.
# A size-1 cache thrashed as extensions alternate lib names, re-globbing and re-warning.
@lru_cache(maxsize=None)
def find_armnn(lib_name: str,
               optional: bool = False,
               armnn_libs_env: str = LIB_ENV_NAME,
               default_lib_search: tuple = None):
    """Searches for ArmNN installation on the local machine.

    Results are cached per argument combination; later changes to the environment
    variable are not observed by cached calls.

    Args:
        lib_name(str): Lib name to find (used as a glob pattern), e.g. 'libarmnn.so'.
        optional(bool): Do not fail if optional. Default is False - fail if library was not found.
        armnn_libs_env(str): Custom environment variable pointing to ArmNN libraries location, default is
            'ARMNN_LIB'. Several directories may be given separated by ':'.
        default_lib_search(tuple): paths to search for ArmNN if the variable named by armnn_libs_env is
            not set. Defaults to the gcc default library search paths, computed on first use.

    Returns:
        tuple: (list of linker library names of the form ':<file name>',
        list of unique directories containing those libraries).

    Raises:
        RuntimeError: If armnn libs are not found (and not optional), or if gcc is needed
            to compute the default search paths but cannot be found.
    """
    armnn_lib_path = os.getenv(armnn_libs_env)
    if armnn_lib_path is not None:
        lib_search = armnn_lib_path.split(":")
    elif default_lib_search is not None:
        lib_search = default_lib_search
    else:
        lib_search = _default_lib_search()

    # ':<file>' makes the linker use the exact file name (-l:<file>). The first match in
    # search order wins, in line with how -L directories are searched in order.
    armnn_libs = {}
    for lib_file in chain.from_iterable(Path(search_dir).glob(lib_name) for search_dir in lib_search):
        armnn_libs.setdefault(':{}'.format(lib_file.name), lib_file)

    if not armnn_libs:
        if not optional:
            raise RuntimeError(
                "ArmNN library {} was not found in {}. Please install ArmNN to one of the standard "
                "locations or set correct ARMNN_INCLUDE and ARMNN_LIB env variables.".format(lib_name,
                                                                                           lib_search))
        logger.warning("""Optional parser library %s was not found in %s and will not be installed.""", lib_name,
                       lib_search)

    # names of the libs and their unique locations, in deterministic (discovery) order
    lib_dirs = dict.fromkeys(str(lib_file.absolute().parent) for lib_file in armnn_libs.values())
    return list(armnn_libs.keys()), list(lib_dirs)


class LazyArmnnFinderExtension(Extension):
    """Derived from `Extension` this class adds ArmNN libraries search on the user's machine.
    Compilation flags are updated with relevant ArmNN libraries files locations (-L) and headers (-I).

    Search for ArmNN is executed only when attributes include_dirs, library_dirs, runtime_library_dirs or
    libraries are queried.

    """

    def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
                 library_dirs=None,
                 libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
                 extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
        self._include_dirs = None
        self._library_dirs = None
        self._runtime_library_dirs = None
        self._libraries = None
        self._armnn_libs = armnn_libs
        self._optional = False if optional is None else optional

        # The base initializer assigns the user-supplied values through the setters below.
        super().__init__(name=name, sources=sources, include_dirs=include_dirs, define_macros=define_macros,
                         undef_macros=undef_macros, library_dirs=library_dirs, libraries=libraries,
                         runtime_library_dirs=runtime_library_dirs, extra_objects=extra_objects,
                         extra_compile_args=extra_compile_args, extra_link_args=extra_link_args,
                         export_symbols=export_symbols, language=language, optional=optional, **kw)

    def _armnn_lib_names(self):
        """Return the linker names of every ArmNN library this extension links against."""
        names = []
        for lib in self._armnn_libs:
            lib_names, _ = find_armnn(lib, self._optional)
            names += lib_names
        return names

    def _armnn_lib_dirs(self):
        """Return the directories containing every ArmNN library this extension links against."""
        dirs = []
        for lib in self._armnn_libs:
            _, lib_dirs = find_armnn(lib, self._optional)
            dirs += lib_dirs
        return dirs

    @property
    def include_dirs(self):
        return self._include_dirs + find_includes()

    @include_dirs.setter
    def include_dirs(self, include_dirs):
        self._include_dirs = include_dirs

    @property
    def library_dirs(self):
        return self._library_dirs + self._armnn_lib_dirs()

    @library_dirs.setter
    def library_dirs(self, library_dirs):
        self._library_dirs = library_dirs

    @property
    def runtime_library_dirs(self):
        return self._runtime_library_dirs + self._armnn_lib_dirs()

    @runtime_library_dirs.setter
    def runtime_library_dirs(self, runtime_library_dirs):
        self._runtime_library_dirs = runtime_library_dirs

    @property
    def libraries(self):
        return self._libraries + self._armnn_lib_names()

    @libraries.setter
    def libraries(self, libraries):
        self._libraries = libraries

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return self.name.__hash__()


if __name__ == '__main__':
    # mandatory extensions
    pyarmnn_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn',
                                              sources=['src/pyarmnn/_generated/armnn_wrap.cpp'],
                                              extra_compile_args=list(CXX_COMPILE_ARGS),
                                              language='c++',
                                              armnn_libs=['libarmnn.so'],
                                              optional=False
                                              )
    pyarmnn_v_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn_version',
                                                sources=['src/pyarmnn/_generated/armnn_version_wrap.cpp'],
                                                extra_compile_args=list(CXX_COMPILE_ARGS),
                                                language='c++',
                                                armnn_libs=['libarmnn.so'],
                                                optional=False
                                                )
    # The version module is listed first so the version check runs before the main wrapper builds.
    extensions_to_build = [pyarmnn_v_module, pyarmnn_module]


    # optional extensions
    def add_parsers_ext(name: str, ext_list: list):
        """Append an optional wrapper extension for the ArmNN component ``name`` to ``ext_list``."""
        pyarmnn_optional_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn_{}'.format(name.lower()),
                                                           sources=['src/pyarmnn/_generated/armnn_{}_wrap.cpp'.format(
                                                               name.lower())],
                                                           extra_compile_args=list(CXX_COMPILE_ARGS),
                                                           language='c++',
                                                           armnn_libs=['libarmnn.so', 'libarmnn{}.so'.format(name)],
                                                           optional=True
                                                           )
        ext_list.append(pyarmnn_optional_module)


    add_parsers_ext('OnnxParser', extensions_to_build)
    add_parsers_ext('TfLiteParser', extensions_to_build)
    add_parsers_ext('Deserializer', extensions_to_build)

    setup(
        name='pyarmnn',
        version=__version__,
        author='Arm Ltd, NXP Semiconductors',
        author_email='support@linaro.org',
        description=DOCLINES[0],
        long_description="\n".join(DOCLINES[2:]),
        url='https://mlplatform.org/',
        license='MIT',
        keywords='armnn neural network machine learning',
        classifiers=[
            'Development Status :: 3 - Alpha',
            'Intended Audience :: Developers',
            'Intended Audience :: Education',
            'Intended Audience :: Science/Research',
            'License :: OSI Approved :: MIT License',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3 :: Only',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Topic :: Scientific/Engineering',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Topic :: Software Development',
            'Topic :: Software Development :: Libraries',
            'Topic :: Software Development :: Libraries :: Python Modules',
        ],
        package_dir={'': 'src'},
        packages=[
            'pyarmnn',
            'pyarmnn._generated',
            'pyarmnn._quantization',
            'pyarmnn._tensor',
            'pyarmnn._utilities'
        ],
        data_files=[('', ['LICENSE'])],
        python_requires='>=3.5',
        install_requires=['numpy'],
        cmdclass={
            'build_py': ExtensionPriorityBuilder,
            'build_ext': ArmnnVersionCheckerExtBuilder
        },
        ext_modules=extensions_to_build
    )