# mypy: allow-untyped-defs
import collections
import os
import shlex
import shutil
import subprocess

try:
    # no type stub for conda command line interface
    import conda.cli.python_api  # type: ignore[import]
    from conda.cli.python_api import Commands as conda_commands
except ImportError:
    # blas_compare.py will fail to import these when it's inside a conda env,
    # but that's fine as it only wants the constants.
    pass


WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"
EIGEN = "eigen"


GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0")
BASE_PKG_DEPS = (
    "cmake",
    "hypothesis",
    "ninja",
    "numpy",
    "pyyaml",
    "setuptools",
    "typing_extensions",
)


# Per-environment build recipe and validation expectations.
#   generic_installs:      extra packages installed alongside BASE_PKG_DEPS.
#   special_installs:      () or a (channel, package_specs) pair installed
#                          with `-c channel`.
#   environment_variables: "KEY=VALUE" strings set on the conda env before
#                          building; entries containing "BLAS" must also
#                          appear in the built torch's configuration output.
#   expected_blas_symbols: substrings that must appear in the validation
#                          output (config + blas-filtered callgrind stats).
#   expected_mkl_version:  MKL version string expected in
#                          `torch.__config__.show()`, or None to skip.
SubEnvSpec = collections.namedtuple(
    "SubEnvSpec", (
        "generic_installs",
        "special_installs",
        "environment_variables",

        # Validate install.
        "expected_blas_symbols",
        "expected_mkl_version",
    ))


SUB_ENVS = {
    MKL_2020_3: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.3",
    ),

    MKL_2020_0: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.0",
    ),

    OPEN_BLAS: SubEnvSpec(
        generic_installs=("openblas",),
        special_installs=(),
        environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("exec_blas",),
        expected_mkl_version=None,
    ),

    # NOTE(review): Eigen is currently disabled; kept for reference.
    # EIGEN: SubEnvSpec(
    #     generic_installs=(),
    #     special_installs=(),
    #     environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
    #     expected_blas_symbols=(),
    #     expected_mkl_version=None,
    # ),
}


def conda_run(*args):
    """Run a conda command through conda's Python API and return its stdout.

    Args:
        *args: Forwarded verbatim to ``conda.cli.python_api.run_command``
            (e.g. ``conda_commands.CREATE, "--prefix", path``).

    Returns:
        The stdout reported by conda.

    Raises:
        OSError: if conda reports a non-zero return code; the message includes
            the arguments and conda's stderr.
    """
    stdout, stderr, retcode = conda.cli.python_api.run_command(*args)
    if retcode:
        raise OSError(f"conda error: {args} retcode: {retcode}\n{stderr}")

    return stdout


def _run_shell(command, error_message):
    """Run ``command`` via the shell; raise OSError with its output on failure.

    Args:
        command: Shell command string (run with ``shell=True``).
        error_message: Prefix for the OSError message if the command fails.

    Returns:
        The ``subprocess.CompletedProcess`` (stdout/stderr captured as bytes).

    Raises:
        OSError: if the command exits non-zero. The message includes the
            captured stdout and stderr. They are decoded with
            ``errors="replace"`` so undecodable bytes cannot mask the failure.
    """
    # NOTE(review): callers use `source activate ...`; `source` is a bash
    # builtin, and shell=True runs /bin/sh, which on some distros (e.g. dash)
    # does not provide it. Confirm /bin/sh is bash-compatible on build hosts.
    result = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        check=False,
    )
    if result.returncode:
        raise OSError(
            f"{error_message}:\n"
            f"  stdout: {result.stdout.decode('utf-8', errors='replace')}\n"
            f"  stderr: {result.stderr.decode('utf-8', errors='replace')}"
        )
    return result


def main():
    """Build one PyTorch install per SUB_ENVS entry and validate its BLAS.

    WORKING_ROOT is wiped and recreated first. For each sub-environment this
    creates a conda env, installs dependencies, sets the spec's environment
    variables on the env, runs ``python setup.py install --cmake`` from the
    git root of this file, and checks ``torch.__config__.show()`` plus
    blas-filtered callgrind stats against the spec's expectations.

    Raises:
        OSError: if any conda or shell step (including the build) fails.
        AssertionError: if a built environment does not match its SubEnvSpec.
    """
    if os.path.exists(WORKING_ROOT):
        print("Cleaning: removing old working root.")
        shutil.rmtree(WORKING_ROOT)
    os.makedirs(WORKING_ROOT)

    git_root = subprocess.check_output(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=os.path.dirname(os.path.realpath(__file__)),
    ).decode("utf-8").strip()

    for env_name, env_spec in SUB_ENVS.items():
        env_path = os.path.join(WORKING_ROOT, env_name)
        # Prefix for every shell command that must run inside this env.
        activate = f"source activate {shlex.quote(env_path)}"

        print(f"Creating env: {env_name}: ({env_path})")
        conda_run(
            conda_commands.CREATE,
            "--no-default-packages",
            "--prefix", env_path,
            "python=3",
        )

        print("Testing that env can be activated:")
        _run_shell(activate, "Failed to source base environment")

        print("Installing packages:")
        conda_run(
            conda_commands.INSTALL,
            "--prefix", env_path,
            *(BASE_PKG_DEPS + env_spec.generic_installs)
        )

        if env_spec.special_installs:
            channel, channel_deps = env_spec.special_installs
            print(f"Installing packages from channel: {channel}")
            conda_run(
                conda_commands.INSTALL,
                "--prefix", env_path,
                "-c", channel, *channel_deps
            )

        if env_spec.environment_variables:
            print("Setting environment variables.")

            # This does not appear to be possible using the python API.
            assignments = " ".join(
                shlex.quote(v) for v in env_spec.environment_variables
            )
            _run_shell(
                f"{activate} && conda env config vars set {assignments}",
                "Failed to set environment variables",
            )

            # Check that they were actually set correctly.
            actual_env_vars = _run_shell(
                f"{activate} && env",
                "Failed to list environment variables",
            ).stdout.decode("utf-8").strip().splitlines()
            for e in env_spec.environment_variables:
                assert e in actual_env_vars, f"{e} not in envs"

        print(f"Building PyTorch for env: `{env_name}`")
        # We have to re-run during each build to pick up the new
        # build config settings. Output is captured and surfaced on failure.
        _run_shell(
            f"{activate} && "
            f"cd {shlex.quote(git_root)} && "
            "python setup.py install --cmake",
            f"Failed to build PyTorch for env `{env_name}`",
        )

        print("Checking configuration:")
        check_run = _run_shell(
            # Shameless abuse of `python -c ...`
            f"{activate} && "
            'python -c "'
            "import torch;"
            "from torch.utils.benchmark import Timer;"
            "print(torch.__config__.show());"
            "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
            "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
            "stats = counts.as_standardized().stats(inclusive=True);"
            "print(stats.filter(lambda l: 'blas' in l.lower()))\"",
            "Failed to check configuration",
        )
        check_run_stdout = check_run.stdout.decode('utf-8')
        print(check_run_stdout)

        for e in env_spec.environment_variables:
            if "BLAS" in e:
                assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"

        for s in env_spec.expected_blas_symbols:
            assert s in check_run_stdout, f"Expected BLAS symbol `{s}` not found in output"

        if env_spec.expected_mkl_version is not None:
            expected_line = (
                f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}"
            )
            assert expected_line in check_run_stdout, (
                f"Expected MKL version line not found: {expected_line!r}"
            )

        print(f"Build complete: {env_name}")


if __name__ == "__main__":
    main()