• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import collections
3import os
4import shutil
5import subprocess
6
7try:
8    # no type stub for conda command line interface
9    import conda.cli.python_api  # type: ignore[import]
10    from conda.cli.python_api import Commands as conda_commands
11except ImportError:
12    # blas_compare.py will fail to import these when it's inside a conda env,
13    # but that's fine as it only wants the constants.
14    pass
15
16
# Scratch directory that holds one conda environment per BLAS configuration.
# `main()` deletes and recreates it on every run.
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"

# Sub-environment names. Each one is used as a key of `SUB_ENVS` and as the
# directory name of its environment under `WORKING_ROOT`.
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"
EIGEN = "eigen"  # Currently has no active entry in `SUB_ENVS`.


# "NAME=value" settings appended to every sub-environment's variables.
GENERIC_ENV_VARS = (
    "USE_CUDA=0",
    "USE_ROCM=0",
)

# Packages installed into every sub-environment before its own packages.
BASE_PKG_DEPS = (
    "cmake", "hypothesis", "ninja", "numpy",
    "pyyaml", "setuptools", "typing_extensions",
)
34
35
# Recipe for one conda sub-environment: what to install, which variables to
# set, and what the resulting PyTorch build is expected to report.
SubEnvSpec = collections.namedtuple(
    "SubEnvSpec",
    [
        # Extra package names installed together with `BASE_PKG_DEPS`.
        "generic_installs",
        # Either empty, or a `(channel, packages)` pair installed with
        # `-c channel`.
        "special_installs",
        # "NAME=value" strings set with `conda env config vars set`.
        "environment_variables",

        # Validate install.
        # Substrings that must appear in the configuration check output.
        "expected_blas_symbols",
        # MKL version string expected in the check output, or None to skip.
        "expected_mkl_version",
    ],
)
46
47
# Maps each sub-environment name to the recipe `main()` uses to create,
# build and validate it.
SUB_ENVS = {
    MKL_2020_3: SubEnvSpec(
        generic_installs=(),
        special_installs=(
            "intel",
            ("mkl=2020.3", "mkl-include=2020.3"),
        ),
        environment_variables=("BLAS=MKL", *GENERIC_ENV_VARS),
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.3",
    ),
    MKL_2020_0: SubEnvSpec(
        generic_installs=(),
        special_installs=(
            "intel",
            ("mkl=2020.0", "mkl-include=2020.0"),
        ),
        environment_variables=("BLAS=MKL", *GENERIC_ENV_VARS),
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.0",
    ),
    OPEN_BLAS: SubEnvSpec(
        generic_installs=("openblas",),
        special_installs=(),
        environment_variables=("BLAS=OpenBLAS", *GENERIC_ENV_VARS),
        expected_blas_symbols=("exec_blas",),
        expected_mkl_version=None,
    ),

    # The Eigen configuration is disabled. NOTE: the entry below predates the
    # `expected_mkl_version` field and must supply it if re-enabled.
    # EIGEN: SubEnvSpec(
    #     generic_installs=(),
    #     special_installs=(),
    #     environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
    #     expected_blas_symbols=(),
    # ),
}
80
81
def conda_run(*args):
    """Run a conda command via the conda Python API and return its stdout.

    Args:
        *args: Forwarded unchanged to `conda.cli.python_api.run_command`.

    Returns:
        The stdout value reported by `run_command`.

    Raises:
        OSError: If `run_command` reports a non-zero return code. The message
            includes the arguments, the return code and the stderr value.
    """
    output, errors, exit_code = conda.cli.python_api.run_command(*args)
    if not exit_code:
        return output

    raise OSError(f"conda error: {args!s}  retcode: {exit_code}\n{errors}")
89
90
def _raise_on_failure(result, summary):
    """Raise `OSError` if a finished subprocess exited with a non-zero code.

    Args:
        result: A `subprocess.CompletedProcess` that was run with
            `capture_output=True`.
        summary: First line of the error message.

    Raises:
        OSError: If `result.returncode` is non-zero. The message contains
            `summary` followed by the decoded stdout and stderr.
    """
    if result.returncode:
        raise OSError(
            f"{summary}\n"
            f"  stdout: {result.stdout.decode('utf-8')}\n"
            f"  stderr: {result.stderr.decode('utf-8')}"
        )


def main():
    """Create a conda env for every entry in `SUB_ENVS` and build PyTorch in it.

    The working root is removed and recreated first. Then, for each
    sub-environment: create the env, check that it can be activated, install
    the packages, set and verify the environment variables, build PyTorch
    from the enclosing git checkout, and validate the build by inspecting the
    output of `torch.__config__.show()` and a callgrind profile of `torch.mm`.

    Raises:
        OSError: If activating the env, setting its variables, or the
            configuration check fails (also raised by `conda_run`).
        subprocess.CalledProcessError: If `git rev-parse`, the environment
            dump, or the PyTorch build exits with a non-zero code.
        AssertionError: If a validation check on the env or build fails.
    """
    if os.path.exists(WORKING_ROOT):
        print("Cleaning: removing old working root.")
        shutil.rmtree(WORKING_ROOT)
    os.makedirs(WORKING_ROOT)

    # No shell is needed here: pass the argument vector directly.
    git_root = subprocess.check_output(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=os.path.dirname(os.path.realpath(__file__)),
    ).decode("utf-8").strip()

    for env_name, env_spec in SUB_ENVS.items():
        env_path = os.path.join(WORKING_ROOT, env_name)
        print(f"Creating env: {env_name}: ({env_path})")
        conda_run(
            conda_commands.CREATE,
            "--no-default-packages",
            "--prefix", env_path,
            "python=3",
        )

        print("Testing that env can be activated:")
        base_source = subprocess.run(
            f"source activate {env_path}",
            shell=True,
            capture_output=True,
            check=False,
        )
        _raise_on_failure(base_source, "Failed to source base environment:")

        print("Installing packages:")
        conda_run(
            conda_commands.INSTALL,
            "--prefix", env_path,
            *(BASE_PKG_DEPS + env_spec.generic_installs)
        )

        if env_spec.special_installs:
            channel, channel_deps = env_spec.special_installs
            print(f"Installing packages from channel: {channel}")
            conda_run(
                conda_commands.INSTALL,
                "--prefix", env_path,
                "-c", channel, *channel_deps
            )

        if env_spec.environment_variables:
            print("Setting environment variables.")

            # This does not appear to be possible using the python API.
            env_set = subprocess.run(
                f"source activate {env_path} && "
                f"conda env config vars set {' '.join(env_spec.environment_variables)}",
                shell=True,
                capture_output=True,
                check=False,
            )
            _raise_on_failure(env_set, "Failed to set environment variables:")

            # Check that they were actually set correctly.
            actual_env_vars = subprocess.run(
                f"source activate {env_path} && env",
                shell=True,
                capture_output=True,
                check=True,
            ).stdout.decode("utf-8").strip().splitlines()
            for env_var in env_spec.environment_variables:
                assert env_var in actual_env_vars, f"{env_var} not in envs"

        print(f"Building PyTorch for env: `{env_name}`")
        # We have to re-run during each build to pick up the new
        # build config settings.
        try:
            subprocess.run(
                f"source activate {env_path} && "
                "python setup.py install --cmake",
                shell=True,
                # Use `cwd` rather than an interpolated `cd` so that a
                # checkout path containing spaces does not break the command.
                cwd=git_root,
                capture_output=True,
                check=True,
            )
        except subprocess.CalledProcessError as err:
            # The output is captured, so surface it before propagating;
            # otherwise a failed build gives no diagnostics.
            print(
                f"Build failed for env: `{env_name}`\n"
                f"  stdout: {err.stdout.decode('utf-8', errors='replace')}\n"
                f"  stderr: {err.stderr.decode('utf-8', errors='replace')}"
            )
            raise

        print("Checking configuration:")
        check_run = subprocess.run(
            # Shameless abuse of `python -c ...`
            f"source activate {env_path} && "
            'python -c "'
            "import torch;"
            "from torch.utils.benchmark import Timer;"
            "print(torch.__config__.show());"
            "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
            "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
            "stats = counts.as_standardized().stats(inclusive=True);"
            "print(stats.filter(lambda l: 'blas' in l.lower()))\"",
            shell=True,
            capture_output=True,
            check=False,
        )
        _raise_on_failure(check_run, "Failed to check configuration:")
        check_run_stdout = check_run.stdout.decode('utf-8')
        print(check_run_stdout)

        for env_var in env_spec.environment_variables:
            if "BLAS" in env_var:
                assert env_var in check_run_stdout, (
                    f"PyTorch build did not respect `BLAS=...`: {env_var}"
                )

        for symbol in env_spec.expected_blas_symbols:
            assert symbol in check_run_stdout, (
                f"Expected BLAS symbol not found in profile: {symbol}"
            )

        if env_spec.expected_mkl_version is not None:
            expected_mkl_line = (
                "- Intel(R) Math Kernel Library Version "
                f"{env_spec.expected_mkl_version}"
            )
            assert expected_mkl_line in check_run_stdout, (
                f"Expected MKL version not reported: {env_spec.expected_mkl_version}"
            )

        print(f"Build complete: {env_name}")


if __name__ == "__main__":
    main()
223