# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" configure cropper tool """
16
17from functools import lru_cache
18import glob
19import json
20import os
21import queue
22import re
23import shlex
24import subprocess
25
26from mindspore import log as logger
27
# preprocessor definitions appended to every gcc dependency query
DEFINE_STR = " ".join([
    "-DENABLE_ANDROID",
    "-DENABLE_ARM",
    "-DENABLE_ARM64",
    "-DENABLE_NEON",
    "-DNO_DLIB",
    "-DUSE_ANDROID_LOG",
    "-DANDROID",
])

# names of the generated files and the directory they are written to
ASSOCIATIONS_FILENAME = "associations.txt"
DEPENDENCIES_FILENAME = "dependencies.txt"
ERRORS_FILENAME = "debug.txt"
OUTPUT_LOCATION = "mindspore/lite/tools/dataset/cropper"

# project include directories handed to gcc (each one becomes a -I flag)
MANUAL_HEADERS = [
    '.',
    'mindspore',
    'mindspore/ccsrc',
    'mindspore/ccsrc/minddata/dataset',
    'mindspore/ccsrc/minddata/dataset/kernels/image',
    'mindspore/core',
    'mindspore/lite',
]

# external include directories: a header whose path contains one of these is
# treated as external and is not followed any further
# (not all of them may be used now in MindData lite)
EXTERNAL_DEPS = [
    'graphengine/inc/external',
    'akg/third_party/fwkacllib/inc',
    'third_party',
    'third_party/securec/include',
    'build/mindspore/_deps/sqlite-src',
    'build/mindspore/_deps/pybind11-src/include',
    'build/mindspore/_deps/tinyxml2-src',
    'build/mindspore/_deps/jpeg_turbo-src',
    'build/mindspore/_deps/jpeg_turbo-src/_build',
    'build/mindspore/_deps/icu4c-src/icu4c/source/i18n',
    'build/mindspore/_deps/icu4c-src/icu4c/source/common',
    'mindspore/lite/build/_deps/tinyxml2-src',
    'mindspore/lite/build/_deps/jpeg_turbo-src',
    'mindspore/lite/build/_deps/jpeg_turbo-src/_build',
    'mindspore/lite/build/_deps/nlohmann_json-src',
]

# API sources that are always kept together with everything they depend on
ESSENTIAL_FILES_1 = [
    'api/data_helper.cc',
    'api/datasets.cc',
    'api/execute.cc',
    'api/iterator.cc',
]

# API sources that are always kept themselves; their IR files are selected
# according to the ops the user needs
ESSENTIAL_FILES_2 = [
    'api/text.cc',
    'api/transforms.cc',
    'api/samplers.cc',
    'api/vision.cc',
]

# root of the dataset sources; the entries above and below are relative to it
DATASET_PATH = "mindspore/ccsrc/minddata/dataset"

# directories that are scanned for IR-level op definitions
OPS_DIRS = [
    'engine/ir/datasetops',
    'engine/ir/datasetops/source',
    'engine/ir/datasetops/source/samplers',
    'kernels/ir/vision',
    'kernels/ir/data',
    'text/ir/kernels',
]
93
94
def extract_classname_samplers(header_content):
    """
    Collect sampler names from the text of a sampler IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Obj : "; the "Obj" suffix is not part of the result.

    :param header_content: string containing header of a sampler IR file
    :return: list of sampler names, in order of appearance
    """
    sampler_pattern = r"(?<=class )[\w\d_]+(?=Obj : )"
    return [found.group(0) for found in re.finditer(sampler_pattern, header_content)]
103
104
def extract_classname_source_node(header_content):
    """
    Collect source-node names from the text of a source node IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Node : "; the "Node" suffix is not part of the result.

    :param header_content: string containing header of a source node IR file
    :return: list of source node names, in order of appearance
    """
    node_regex = re.compile(r"(?<=class )[\w\d_]+(?=Node : )")
    return node_regex.findall(header_content)
113
114
def extract_classname_nonsource_node(header_content):
    """
    Collect non-source node names from the text of a non-source IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Node : "; the "Node" suffix is not part of the result.

    :param header_content: string containing header of a non-source IR file
    :return: list of non-source node names, in order of appearance
    """
    names = []
    for found in re.finditer(r"(?<=class )[\w\d_]+(?=Node : )", header_content):
        names.append(found.group(0))
    return names
123
124
def extract_classname_vision(header_content):
    """
    Collect vision op names from the text of a vision op IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Operation : "; the "Operation" suffix is dropped.

    :param header_content: string containing header of a vision op IR file
    :return: list of vision op names, in order of appearance
    """
    vision_pattern = r"(?<=class )[\w\d_]+(?=Operation : )"
    return [found.group(0) for found in re.finditer(vision_pattern, header_content)]
133
134
def extract_classname_data(header_content):
    """
    Collect data op names from the text of a data op IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Operation : "; the "Operation" suffix is dropped.

    :param header_content: string containing header of a data op IR file
    :return: list of data op names, in order of appearance
    """
    data_regex = re.compile(r"(?<=class )[\w\d_]+(?=Operation : )")
    return data_regex.findall(header_content)
143
144
def extract_classname_text(header_content):
    """
    Collect text op names from the text of a text op IR header.

    A name is the run of word characters that directly follows "class " and is
    directly followed by "Operation : "; the "Operation" suffix is dropped.

    :param header_content: string containing header of a text op IR file
    :return: list of text op names, in order of appearance
    """
    names = []
    for found in re.finditer(r"(?<=class )[\w\d_]+(?=Operation : )", header_content):
        names.append(found.group(0))
    return names
153
154
# Dispatch table: op directory (relative entry joined onto DATASET_PATH) -> the
# function that extracts op names from the headers found in that directory
registered_functions = {
    os.path.join(DATASET_PATH, sub_dir): extractor
    for sub_dir, extractor in (
        ('engine/ir/datasetops/source/samplers', extract_classname_samplers),
        ('engine/ir/datasetops/source', extract_classname_source_node),
        ('engine/ir/datasetops', extract_classname_nonsource_node),
        ('kernels/ir/vision', extract_classname_vision),
        ('kernels/ir/data', extract_classname_data),
        ('text/ir/kernels', extract_classname_text),
    )
}
164
165
def get_headers():
    """
    Build the include flags for the compiler: "-Ixx/yy/ -Ixx/zz/ ..."

    Every directory in MANUAL_HEADERS followed by every directory in
    EXTERNAL_DEPS is prefixed with "-I" and suffixed with "/".

    :return: a string to be passed to compiler
    """
    include_dirs = MANUAL_HEADERS + EXTERNAL_DEPS
    return "-I" + "/ -I".join(include_dirs) + "/"
177
178
@lru_cache(maxsize=1024)
def get_dependencies_of_file(headers_flag, filename):
    """
    Ask gcc (-MM -MG) for the direct dependency list of a file (file0.cc):
    file0.cc.o: file1.h, file2.h, ...

    Results are memoized by lru_cache, so the returned list object is shared
    between calls with the same arguments and must not be mutated by callers.

    :param headers_flag: string containing headers include paths with -I prepended to them.
    :param filename: a string containing path of a file.
    :return: a list of file names [file0.cc, file1.h, file2.h, file3.h] and error string
    :raises OSError: if the gcc executable cannot be started.
    """
    # Build argv directly: only the flag strings are tokenized, so `filename` always
    # reaches gcc as exactly one argument even if it contains spaces or quotes.
    command = ['gcc', '-MM', '-MG', filename]
    command.extend(shlex.split(DEFINE_STR))
    command.extend(shlex.split(headers_flag))
    completed = subprocess.run(command, shell=False, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, check=False)
    # gcc prints a make rule ("target: dep dep \<newline> dep ..."); split on whitespace
    # and line-continuation backslashes, then drop the leading target token.
    tokens = re.split(r'[\s\\]+', completed.stdout.decode('utf-8').strip(), flags=re.MULTILINE)
    deps = tokens[1:]

    return deps, completed.stderr.decode('utf-8')
196
197
def needs_processing(dep_cc, processed_cc, queue_cc_set):
    """
    Decide whether a candidate file still has to be handed to gcc.

    :param dep_cc: the candidate file to be processed by gcc
    :param processed_cc: set of files that have been already processed.
    :param queue_cc_set: files currently in the queue (to be processed)
    :return: boolean, whether the file should be further processed by gcc.
    """
    # already handled, or already waiting in the queue: nothing to do
    if dep_cc in processed_cc or dep_cc in queue_cc_set:
        return False
    # a missing file would only make gcc fail (may happen for cache)
    return os.path.isfile(dep_cc)
217
218
def build_source_file_path(dep_h):
    """
    Given the path to a header file, find the path for the associated source file.
    - if an external dependency, return "EXTERNAL"
    - if not found, keep the header file's path

    Only the file extension (.h or .hpp) is swapped for .cc; any other ".h"
    occurring inside the path (directory names, "foo.helper.h", ...) is left alone.

    :param dep_h: a string containing path to the header file
    :return: dep_cc: a string containing path to the source file
    """
    # a path containing any external include directory is never followed
    if any(external_dir in dep_h for external_dir in EXTERNAL_DEPS):
        return "EXTERNAL"
    # the implementation of this header is not located next to the header
    if 'include/api/types.h' in dep_h:
        return "mindspore/ccsrc/cxx_api/types.cc"
    stem, extension = os.path.splitext(dep_h)
    if extension not in ('.h', '.hpp'):
        # not a header (e.g. the .cc file itself): nothing to translate
        return dep_h
    dep_cc = stem + '.cc'
    return dep_cc if os.path.isfile(dep_cc) else dep_h
239
240
def get_all_dependencies_of_file(headers_flag, filename):
    """
    Collect the transitive dependencies of a file (incl. all source files needed).

    Starting from `filename`, each file is passed to gcc once; every reported
    header is mapped to its source file and queued in turn, while headers marked
    as external are recorded but not followed.

    :param headers_flag: string containing headers include paths with -I prepended to them.
    :param filename: a string containing path of a file.
    :return: all dependencies of that file and the error string
    """
    collected_errors = []
    # files waiting for gcc, plus a set of everything that was ever queued (fast lookup)
    pending = queue.SimpleQueue()
    ever_queued = {filename}
    # files already handled (also receives the external headers that are skipped)
    visited = set()

    pending.put(filename)

    while not pending.empty():
        current = pending.get()
        direct_deps, stderr_text = get_dependencies_of_file(headers_flag, current)
        collected_errors.append(stderr_text)
        visited.add(current)
        for header in direct_deps:
            candidate = build_source_file_path(header)
            if candidate == "EXTERNAL":
                # external dependency: remember the header itself, do not follow it
                visited.add(header)
            elif needs_processing(candidate, visited, ever_queued):
                pending.put(candidate)
                ever_queued.add(candidate)
    logger.debug('file: {} | deps: {}'.format(os.path.basename(filename), len(visited)))

    return list(visited), "".join(collected_errors)
281
282
def get_deps_essential(headers_flag):
    """
    Gather the files that are needed for any run (essential).

    :param headers_flag: string containing headers include paths with -I prepended to them.
    :return: a list of essential files, and the error string
    """
    collected = []
    error_chunks = []

    # ESSENTIAL_FILES_1: the file itself and everything it depends on
    for api_file in ESSENTIAL_FILES_1:
        api_path = os.path.join(DATASET_PATH, api_file)
        file_deps, stderr_text = get_all_dependencies_of_file(headers_flag, api_path)
        error_chunks.append(stderr_text)
        collected += file_deps
        collected.append(api_path)

    # ESSENTIAL_FILES_2: only the files themselves (IR files are split)
    collected += [os.path.join(DATASET_PATH, api_file) for api_file in ESSENTIAL_FILES_2]

    # drop duplicates
    return list(set(collected)), "".join(error_chunks)
305
306
def get_deps_non_essential(headers_flag):
    """
    Find the entry points (IR Level) for each op and write them in associations dict.
    Starting from these entry point, recursively find the dependencies for each file and write in a dict.

    Source files of a directory are visited in sorted order so that the resulting
    dicts are built in a reproducible order.

    :param headers_flag: string containing headers include paths with -I prepended to them.
    :return: dependencies dict, associations dict, the error string
    :raises ValueError: if a directory in OPS_DIRS has no registered extraction function,
        or a source file has no header with the same name next to it.
    """
    dependencies = {}  # what files each file imports
    associations = {}  # what file each op is defined in (IR level)
    errors = []
    for dirname in [os.path.join(DATASET_PATH, x) for x in OPS_DIRS]:
        # Get the proper regex function for this directory
        if dirname not in registered_functions:
            raise ValueError("Directory has no registered regex function: {}".format(dirname))
        extract_classname = registered_functions[dirname]
        # iterate over source files in the directory
        for src_filename in sorted(glob.glob("{}/*.cc".format(dirname))):
            # get the dependencies of source file
            deps, err = get_all_dependencies_of_file(headers_flag, src_filename)
            dependencies[src_filename] = deps
            errors.append(err)
            # the header shares the source file's name; swap the extension only
            header_filename = os.path.splitext(src_filename)[0] + '.h'
            if not os.path.isfile(header_filename):
                raise ValueError("Header file doesn't exist: {}".format(header_filename))
            with open(header_filename, 'r', encoding='utf-8') as f:
                content = f.read().strip()
            # register every op found in the header under a normalized key
            # (lower case, underscores removed)
            for raw_op in extract_classname(content):
                op = raw_op.lower().replace('_', '')
                associations[op] = src_filename
    return dependencies, associations, "".join(errors)
342
343
def _open_output_file(filename):
    """
    Open a file inside OUTPUT_LOCATION for writing, discarding any previous content.

    The file is created with mode 0o660 if it does not exist. O_TRUNC is required
    because wrapping an existing descriptor with os.fdopen does not truncate it;
    without it a shorter result would leave stale bytes from an earlier run behind.

    :param filename: name of the file inside OUTPUT_LOCATION
    :return: a text file object opened for writing
    """
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    return os.fdopen(os.open(os.path.join(OUTPUT_LOCATION, filename), flags, 0o660), "w")


def main():
    """
    Configure the cropper tool by creating associations.txt and dependencies.txt
    (and debug.txt, which holds the collected gcc error output).
    """
    errors = []
    dependencies = {}

    # convert to a single string with '-I' prepended to each dir name
    headers_flag = get_headers()

    # get dependencies for essential files
    all_deps, err = get_deps_essential(headers_flag)
    dependencies['ESSENTIAL'] = all_deps
    errors.append(err)
    logger.debug('len(ESSENTIAL): {}'.format(len(dependencies['ESSENTIAL'])))

    # get dependencies for other files (non-essentials)
    other_dependencies, all_associations, err = get_deps_non_essential(headers_flag)
    dependencies.update(other_dependencies)
    errors.append(err)

    with _open_output_file(DEPENDENCIES_FILENAME) as f:
        json.dump(dependencies, f)

    with _open_output_file(ASSOCIATIONS_FILENAME) as f:
        json.dump(all_associations, f)

    with _open_output_file(ERRORS_FILENAME) as f:
        f.write("".join(errors))
375
376
if __name__ == "__main__":

    logger.info('STARTING: cropper_configure.py ')

    original_path = os.getcwd()
    script_path = os.path.dirname(os.path.abspath(__file__))

    try:
        # change directory to mindspore directory (five levels above this script)
        os.chdir(os.path.join(script_path, "../../../../.."))
        main()
    except (OSError, IndexError, KeyError, ValueError):
        # ValueError is what get_deps_non_essential raises for a missing header or an
        # unregistered directory; log the failure and let the exception propagate
        logger.error('FAILED: cropper_configure.py!')
        raise
    else:
        logger.info('SUCCESS: cropper_configure.py ')
    finally:
        # always restore the caller's working directory
        os.chdir(original_path)
395