• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import os
4import subprocess
5
6from ..util.setting import CompilerType, TestType, TOOLS_FOLDER
7from ..util.utils import print_error, remove_file
8
9
def get_oss_binary_folder(test_type: TestType) -> str:
    """Return the folder that holds the runnable tests of ``test_type``.

    C++ test binaries are looked up in ``<pytorch>/build/bin``; Python test
    scripts in ``<pytorch>/test``. Only ``TestType.CPP`` and ``TestType.PY``
    are accepted (enforced by an assert).
    """
    assert test_type in {TestType.CPP, TestType.PY}
    pytorch_root = get_pytorch_folder()
    if test_type == TestType.CPP:
        # TODO: revisit how the binary location is found -- C++ binaries may
        # not always end up in build/bin.
        relative_dir = "build/bin"
    else:
        relative_dir = "test"
    return os.path.join(pytorch_root, relative_dir)
16
17
def get_oss_shared_library() -> list[str]:
    """Return full paths of the ``.dylib`` files directly inside ``<pytorch>/build/lib``.

    Only the top level of that directory is scanned (no recursion), and the
    result follows ``os.listdir`` order. Any error from ``os.listdir`` (e.g. a
    missing directory) propagates to the caller.
    """
    library_folder = os.path.join(get_pytorch_folder(), "build", "lib")
    shared_libs: list[str] = []
    for entry in os.listdir(library_folder):
        if not entry.endswith(".dylib"):
            continue
        shared_libs.append(os.path.join(library_folder, entry))
    return shared_libs
25
26
def get_oss_binary_file(test_name: str, test_type: TestType) -> str:
    """Build the command string that runs the test ``test_name``.

    For C++ tests this is simply the binary's path. For Python tests the
    script path is prefixed with ``"python "`` so the returned string can be
    executed as a shell command without further processing.
    """
    assert test_type in {TestType.CPP, TestType.PY}
    target = os.path.join(get_oss_binary_folder(test_type), test_name)
    if test_type != TestType.PY:
        return target
    # Launch Python scripts through the interpreter so callers can run the
    # returned string directly.
    return f"python {target}"
35
36
def get_llvm_tool_path() -> str:
    """Return the directory containing the LLVM command-line tools.

    The ``LLVM_TOOL_PATH`` environment variable takes precedence (even if set
    to an empty string). When it is not set, fall back to
    ``/usr/local/opt/llvm/bin``, which is a macOS-style install location.
    Set the variable explicitly on other platforms.
    """
    default_path = "/usr/local/opt/llvm/bin"
    return os.environ.get("LLVM_TOOL_PATH", default_path)
41
42
def get_pytorch_folder() -> str:
    """Return the absolute path of the PyTorch source checkout.

    The ``PYTORCH_FOLDER`` environment variable wins when set. Otherwise the
    path two levels above ``TOOLS_FOLDER`` is used; in OSS, ``TOOLS_FOLDER``
    is ``pytorch/tools/code_coverage``.
    """
    fallback = os.path.join(TOOLS_FOLDER, os.path.pardir, os.path.pardir)
    chosen = os.environ.get("PYTORCH_FOLDER", fallback)
    return os.path.abspath(chosen)
50
51
def detect_compiler_type() -> CompilerType | None:
    """Determine whether the active C/C++ compiler is Clang or GCC.

    If the ``CXX`` environment variable is set and non-empty, it decides the
    answer. Only its basename is inspected, so full paths
    (``/usr/bin/clang++``) and versioned or target-prefixed names
    (``g++-11``, ``x86_64-linux-gnu-g++``) are recognised. The bare
    ``clang``/``clang++``/``gcc``/``g++`` names are still accepted.

    Otherwise ``cc -v`` is executed and its combined stdout/stderr is searched
    for "clang" or "gcc".

    Raises:
        RuntimeError: if ``CXX`` (or the ``cc -v`` output) names neither
            compiler.
    """
    # check if user specifies the compiler type
    user_specify = os.environ.get("CXX", None)
    if user_specify:
        # Use only the executable's name so directory components
        # (e.g. "/opt/gcc-toolset/bin/clang") cannot cause a misdetection.
        compiler_name = os.path.basename(user_specify.strip())
        # Check clang first: "clang++" contains the substring "g++".
        if "clang" in compiler_name:
            return CompilerType.CLANG
        if "gcc" in compiler_name or "g++" in compiler_name:
            return CompilerType.GCC

        raise RuntimeError(f"User specified compiler is not valid {user_specify}")

    # auto detect: `cc -v` prints its version banner to stderr, so merge it
    # into the captured output.
    auto_detect_result = subprocess.check_output(
        ["cc", "-v"], stderr=subprocess.STDOUT
    ).decode("utf-8")
    if "clang" in auto_detect_result:
        return CompilerType.CLANG
    elif "gcc" in auto_detect_result:
        return CompilerType.GCC
    raise RuntimeError(f"Auto detected compiler is not valid {auto_detect_result}")
72
73
def clean_up_gcda() -> None:
    """Delete every ``.gcda`` coverage file listed by ``get_gcda_files``.

    Each path is handed to ``remove_file``; nothing is returned.
    """
    for gcda_path in get_gcda_files():
        remove_file(gcda_path)
78
79
def get_gcda_files() -> list[str]:
    """Return the paths of all ``.gcda`` files under ``<pytorch>/build``.

    The search is recursive and matches the extension case-insensitively
    (``find -iname``). Returns an empty list when the build folder does not
    exist or when no file matches.
    """
    folder_has_gcda = os.path.join(get_pytorch_folder(), "build")
    if not os.path.isdir(folder_has_gcda):
        return []
    # TODO use glob (note: glob is case-sensitive, unlike `find -iname`)
    # output = glob.glob(f"{folder_has_gcda}/**/*.gcda")
    output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"])
    # `find` ends every path with "\n". A plain split("\n") therefore left a
    # trailing "" entry (and returned [""] when nothing matched), which callers
    # such as clean_up_gcda would treat as a file path. Drop empty entries.
    return [path for path in output.decode("utf-8").split("\n") if path]
89
90
def run_oss_python_test(binary_file: str) -> None:
    """Execute a Python test command and report failure without raising.

    ``binary_file`` is a complete shell command string, as built by
    ``get_oss_binary_file`` (``"python <script path>"``). It runs through the
    shell with the Python test folder as the working directory. A non-zero
    exit status is reported via ``print_error`` instead of being propagated.
    """
    test_dir = get_oss_binary_folder(TestType.PY)
    try:
        subprocess.check_call(binary_file, shell=True, cwd=test_dir)
    except subprocess.CalledProcessError:
        print_error(f"Binary failed to run: {binary_file}")
99