• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Configure build environment for certain Intel platforms."""
16
17import argparse
18import os
19import subprocess
20
# Options that set_build_args adds to every generated bazel build command.
BASIC_BUILD_OPTS = [
    "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0",
    "--copt=-O3",
]

# Extra compiler/linker hardening options; set_build_args adds them only when
# the --secure-build command line switch is given.
SECURE_BUILD_OPTS = [
    "--copt=-Wformat",
    "--copt=-Wformat-security",
    "--copt=-fstack-protector",
    "--copt=-fPIC",
    "--copt=-fpic",
    "--linkopt=-znoexecstack",
    "--linkopt=-zrelro",
    "--linkopt=-znow",
    "--linkopt=-fstack-protector",
]
28
29
class IntelPlatform(object):
  """Base class holding the gcc requirements of one Intel target platform.

  An instance stores the minimum gcc version the platform needs and, once
  set_host_gcc_version() has accepted it, the gcc version found on the host.
  Subclasses override get_bazel_gcc_flags() to build the bazel options.
  """
  # Class-level defaults; a host version of 0.0 means "not recorded yet".
  min_gcc_major_version_ = 0
  min_gcc_minor_version_ = 0
  host_gcc_major_version_ = 0
  host_gcc_minor_version_ = 0
  # Building blocks for options such as "--copt=-march=<arch>" and
  # "--copt=-m<flag>".
  BAZEL_PREFIX_ = "--copt="
  ARCH_PREFIX_ = "-march="
  FLAG_PREFIX_ = "-m"

  def __init__(self, min_gcc_major_version, min_gcc_minor_version):
    """Remember the oldest gcc version (major, minor) this platform accepts."""
    self.min_gcc_major_version_ = min_gcc_major_version
    self.min_gcc_minor_version_ = min_gcc_minor_version

  def set_host_gcc_version(self, gcc_major_version, gcc_minor_version):
    """Record the host gcc version if it satisfies the platform minimum.

    Prints a diagnostic and returns False, leaving the stored host version
    untouched, when the given version is older than the minimum. Otherwise
    stores the version on the instance and returns True.
    """
    required_major = self.min_gcc_major_version_
    required_minor = self.min_gcc_minor_version_

    if gcc_major_version < required_major:
      complaint = ("Your MAJOR version of GCC is too old: {}; "
                   "it must be at least {}.{}")
      print(complaint.format(gcc_major_version, required_major,
                             required_minor))
      return False

    # Only when the major versions tie does the minor version decide.
    majors_tie = gcc_major_version == required_major
    if majors_tie and gcc_minor_version < required_minor:
      complaint = ("Your MINOR version of GCC is too old: {}; "
                   "it must be at least {}.{}")
      print(complaint.format(gcc_minor_version, required_major,
                             required_minor))
      return False

    print("gcc version OK: {}.{}".format(gcc_major_version, gcc_minor_version))
    self.host_gcc_major_version_ = gcc_major_version
    self.host_gcc_minor_version_ = gcc_minor_version
    return True

  def get_bazel_gcc_flags(self):
    """Return the bazel-formatted flags for this platform (subclass hook)."""
    raise NotImplementedError(self)

  def use_old_arch_names(self, gcc_new_march_major_version,
                         gcc_new_march_minor_version):
    """Tell whether the host gcc predates a newer -march name.

    The arguments give the gcc version in which the new name became usable.
    Returns True when the recorded host gcc version is strictly older than
    that version, False otherwise.
    """
    recorded_host = (self.host_gcc_major_version_,
                     self.host_gcc_minor_version_)
    new_name_since = (gcc_new_march_major_version,
                      gcc_new_march_minor_version)
    # Tuples compare major first and fall back to minor on a tie.
    return recorded_host < new_name_since
83
84
class NehalemPlatform(IntelPlatform):
  """Nehalem platform; its minimum gcc version is 4.8."""

  def __init__(self):
    IntelPlatform.__init__(self, 4, 8)

  def get_bazel_gcc_flags(self):
    """Return the bazel -march option for this platform plus a trailing space.

    A host gcc older than 4.9 gets the "corei7" arch name; any other host gets
    "nehalem".
    """
    arch_name = "corei7" if self.use_old_arch_names(4, 9) else "nehalem"
    return "{}{}{} ".format(self.BAZEL_PREFIX_, self.ARCH_PREFIX_, arch_name)
99
100
class SandyBridgePlatform(IntelPlatform):
  """Sandy Bridge platform; its minimum gcc version is 4.8."""

  def __init__(self):
    IntelPlatform.__init__(self, 4, 8)

  def get_bazel_gcc_flags(self):
    """Return the bazel -march option for this platform plus a trailing space.

    A host gcc older than 4.9 gets the "corei7-avx" arch name; any other host
    gets "sandybridge".
    """
    if self.use_old_arch_names(4, 9):
      arch_name = "corei7-avx"
    else:
      arch_name = "sandybridge"
    return "".join([self.BAZEL_PREFIX_, self.ARCH_PREFIX_, arch_name, " "])
115
116
class HaswellPlatform(IntelPlatform):
  """Haswell platform; its minimum gcc version is 4.8."""

  def __init__(self):
    IntelPlatform.__init__(self, 4, 8)

  def get_bazel_gcc_flags(self):
    """Return the bazel options for this platform, each followed by a space.

    A host gcc of 4.9 or newer gets the single "haswell" -march option. An
    older host gets "core-avx2" plus an explicit popcnt flag.
    """
    march_option = self.BAZEL_PREFIX_ + self.ARCH_PREFIX_
    if not self.use_old_arch_names(4, 9):
      return march_option + "haswell" + " "

    # "core-avx2" is only missing the POPCNT instruction, so add it by hand.
    options = [
        march_option + "core-avx2",
        self.BAZEL_PREFIX_ + self.FLAG_PREFIX_ + "popcnt",
    ]
    return " ".join(options) + " "
134
135
class SkylakePlatform(IntelPlatform):
  """Skylake platform; its minimum gcc version is 4.9."""

  def __init__(self):
    IntelPlatform.__init__(self, 4, 9)

  def get_bazel_gcc_flags(self):
    """Return the bazel options for this platform, each followed by a space.

    A host gcc of 6.1 or newer gets the single "skylake-avx512" -march option.
    An older host gets "broadwell" plus the avx512f and avx512cd flags.
    """
    if not self.use_old_arch_names(6, 1):
      return "{}{}{} ".format(self.BAZEL_PREFIX_, self.ARCH_PREFIX_,
                              "skylake-avx512")

    # the flags that broadwell is missing: pku, clflushopt, clwb, avx512vl,
    # avx512bw, avx512dq. xsavec and xsaves are available in gcc 5.x
    # but for now, just exclude them.
    fallback_flags = ("avx512f", "avx512cd")
    options = [self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + "broadwell"]
    options.extend(self.BAZEL_PREFIX_ + self.FLAG_PREFIX_ + name
                   for name in fallback_flags)
    return "".join(option + " " for option in options)
157
158
class CascadelakePlatform(IntelPlatform):
  """Cascade Lake platform; its minimum gcc version is 8.3."""

  def __init__(self):
    IntelPlatform.__init__(self, 8, 3)

  def get_bazel_gcc_flags(self):
    """Return the bazel options for this platform, each followed by a space.

    A host gcc of 9.1 or newer gets the single "cascadelake" -march option.
    An older host gets "skylake-avx512" plus an explicit avx512vnni flag.
    """
    CASCADELAKE_ARCH_OLD = "skylake-avx512"
    CASCADELAKE_ARCH_NEW = "cascadelake"
    # The fallback adds only the VNNI flag on top of the skylake-avx512
    # baseline.
    VNNI_FLAG = "avx512vnni"
    # Go through self (like the other platform classes) so that an overridden
    # use_old_arch_names is honoured.
    if self.use_old_arch_names(9, 1):
      ret_val = self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
                CASCADELAKE_ARCH_OLD + " "
      return ret_val + self.BAZEL_PREFIX_ + self.FLAG_PREFIX_ + \
             VNNI_FLAG + " "
    else:
      return self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
             CASCADELAKE_ARCH_NEW + " "
177
178
class IcelakeClientPlatform(IntelPlatform):
  """Ice Lake client platform; its minimum gcc version is 8.4."""

  def __init__(self):
    IntelPlatform.__init__(self, 8, 4)

  def get_bazel_gcc_flags(self):
    """Return the bazel options for this platform, each followed by a space.

    A host gcc of 8.4 or newer gets the single "icelake-client" -march option.
    Otherwise "skylake-avx512" plus the avx512f and avx512cd flags is used.
    """
    ICELAKE_ARCH_OLD = "skylake-avx512"
    ICELAKE_ARCH_NEW = "icelake-client"
    AVX512_FLAGS = ["avx512f", "avx512cd"]
    # NOTE: the threshold equals this platform's minimum gcc version, so the
    # fallback branch appears reachable only while no host gcc version has
    # been recorded (the recorded version defaults to 0.0).
    # Go through self (like the other platform classes) so that an overridden
    # use_old_arch_names is honoured.
    if self.use_old_arch_names(8, 4):
      ret_val = self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
                ICELAKE_ARCH_OLD + " "
      for flag in AVX512_FLAGS:
        ret_val += self.BAZEL_PREFIX_ + self.FLAG_PREFIX_ + flag + " "
      return ret_val
    else:
      return self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
             ICELAKE_ARCH_NEW + " "
197
198
class IcelakeServerPlatform(IntelPlatform):
  """Ice Lake server platform; its minimum gcc version is 8.4."""

  def __init__(self):
    IntelPlatform.__init__(self, 8, 4)

  def get_bazel_gcc_flags(self):
    """Return the bazel options for this platform, each followed by a space.

    A host gcc of 8.4 or newer gets the single "icelake-server" -march option.
    Otherwise "skylake-avx512" plus the avx512f and avx512cd flags is used.
    """
    ICELAKE_ARCH_OLD = "skylake-avx512"
    ICELAKE_ARCH_NEW = "icelake-server"
    AVX512_FLAGS = ["avx512f", "avx512cd"]
    # NOTE: the threshold equals this platform's minimum gcc version, so the
    # fallback branch appears reachable only while no host gcc version has
    # been recorded (the recorded version defaults to 0.0).
    # Go through self (like the other platform classes) so that an overridden
    # use_old_arch_names is honoured.
    if self.use_old_arch_names(8, 4):
      ret_val = self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
                ICELAKE_ARCH_OLD + " "
      for flag in AVX512_FLAGS:
        ret_val += self.BAZEL_PREFIX_ + self.FLAG_PREFIX_ + flag + " "
      return ret_val
    else:
      return self.BAZEL_PREFIX_ + self.ARCH_PREFIX_ + \
             ICELAKE_ARCH_NEW + " "
217
218
class BuildEnvSetter(object):
  """Prepares the proper environment settings for various Intel platforms."""
  # Platform used when -p/--platform is not given on the command line.
  default_platform_ = "haswell"

  # Maps each accepted --platform value to the object that knows its gcc
  # requirements and flags. The instances are created once, at class
  # definition time, and shared by every BuildEnvSetter.
  PLATFORMS_ = {
      "nehalem": NehalemPlatform(),
      "sandybridge": SandyBridgePlatform(),
      "haswell": HaswellPlatform(),
      "skylake": SkylakePlatform(),
      "cascadelake": CascadelakePlatform(),
      "icelake-client": IcelakeClientPlatform(),
      "icelake-server": IcelakeServerPlatform(),
  }

  def __init__(self):
    # Parsed command line namespace; filled in by parse_args().
    self.args = None
    # The line written to the bazelrc file; flags are appended to it.
    self.bazel_flags_ = "build "
    # Entry of PLATFORMS_ selected by go().
    self.target_platform_ = None

  @staticmethod
  def _to_text(value):
    """Return value as text, decoding it as UTF-8 when it is bytes."""
    # handle python2 vs 3 (bytes vs str type)
    if isinstance(value, bytes):
      return value.decode("utf-8")
    return value

  def get_gcc_version(self):
    """Return the host gcc version as a (major, minor) tuple of ints.

    Returns (0, 0), after printing the reason, when gcc cannot be located, is
    not executable, or reports a version that cannot be parsed. validate_args
    treats a major version of 0 as a failure.
    """
    gcc_path_cmd = "command -v gcc"
    try:
      # The command string is a constant, so nothing user-supplied reaches
      # the shell.
      gcc_path = self._to_text(
          subprocess.check_output(
              gcc_path_cmd, shell=True, stderr=subprocess.STDOUT).strip())
      print("gcc located here: {}".format(gcc_path))
      if not os.access(gcc_path, os.F_OK | os.X_OK):
        raise ValueError(
            "{} does not exist or is not executable.".format(gcc_path))

      gcc_output = self._to_text(
          subprocess.check_output(
              [gcc_path, "-dumpfullversion", "-dumpversion"],
              stderr=subprocess.STDOUT).strip())
      print("gcc version: {}".format(gcc_output))
      gcc_info = gcc_output.split(".")
      gcc_major_version = int(gcc_info[0])
      # Tolerate a version string that has no minor component.
      gcc_minor_version = int(gcc_info[1]) if len(gcc_info) > 1 else 0
    except (subprocess.CalledProcessError, OSError, ValueError) as e:
      # CalledProcessError: a command exited non-zero (e.g. gcc not on PATH).
      # OSError: the gcc binary could not be launched.
      # ValueError: gcc is not executable or its output is not numeric.
      print("Problem getting gcc info: {}".format(e))
      return 0, 0
    return gcc_major_version, gcc_minor_version

  def parse_args(self):
    """Set up argument parser, and parse CLI args."""
    arg_parser = argparse.ArgumentParser(
        description="Parse the arguments for the "
        "TensorFlow build environment "
        " setter")
    arg_parser.add_argument(
        "--disable-mkl",
        dest="disable_mkl",
        help="Turn off MKL. By default the compiler flag "
        "--config=mkl is enabled.",
        action="store_true")
    arg_parser.add_argument(
        "--disable-v2",
        dest="disable_v2",
        help="Build TensorFlow v1 rather than v2. By default the "
        " compiler flag --config=v2 is enabled.",
        action="store_true")
    arg_parser.add_argument(
        "--enable-bfloat16",
        dest="enable_bfloat16",
        help="Enable bfloat16 build. By default it is "
        " disabled if no parameter is passed.",
        action="store_true")
    arg_parser.add_argument(
        "--enable-dnnl1",
        dest="enable_dnnl1",
        help="Enable dnnl1 build. By default it is "
        " disabled if no parameter is passed.",
        action="store_true")
    arg_parser.add_argument(
        "-s",
        "--secure-build",
        dest="secure_build",
        help="Enable secure build flags.",
        action="store_true")
    arg_parser.add_argument(
        "-p",
        "--platform",
        choices=self.PLATFORMS_.keys(),
        help="The target platform.",
        dest="target_platform",
        default=self.default_platform_)
    arg_parser.add_argument(
        "-f",
        "--bazelrc-file",
        dest="bazelrc_file",
        help="The full path to the bazelrc file into which "
        "the build command will be written. The path "
        "will be relative to the container "
        " environment.",
        required=True)

    self.args = arg_parser.parse_args()

  def validate_args(self):
    """Check the bazelrc path and the host gcc against the target platform.

    Requires parse_args() to have run and target_platform_ to be set.
    Returns False when the bazelrc path is a directory, when the gcc version
    could not be determined, or when the target platform rejects it; returns
    True otherwise.
    """
    # Check the bazelrc file
    if os.path.exists(self.args.bazelrc_file):
      if os.path.isfile(self.args.bazelrc_file):
        self._debug("The file {} exists and will be deleted.".format(
            self.args.bazelrc_file))
      elif os.path.isdir(self.args.bazelrc_file):
        print("You can't write bazel config to \"{}\" "
              "because it is a directory".format(self.args.bazelrc_file))
        return False

    # Validate gcc with the requested platform
    gcc_major_version, gcc_minor_version = self.get_gcc_version()
    if gcc_major_version == 0 or \
       not self.target_platform_.set_host_gcc_version(
           gcc_major_version, gcc_minor_version):
      return False

    return True

  def set_build_args(self):
    """Generate Bazel build flags."""
    for flag in BASIC_BUILD_OPTS:
      self.bazel_flags_ += "{} ".format(flag)
    if self.args.secure_build:
      for flag in SECURE_BUILD_OPTS:
        self.bazel_flags_ += "{} ".format(flag)
    if not self.args.disable_mkl:
      self.bazel_flags_ += "--config=mkl "
    if self.args.disable_v2:
      self.bazel_flags_ += "--config=v1 "
    if self.args.enable_dnnl1:
      self.bazel_flags_ += "--define build_with_mkl_dnn_v1_only=true "
    if self.args.enable_bfloat16:
      self.bazel_flags_ += "--copt=-DENABLE_INTEL_MKL_BFLOAT16 "

    # Platform-specific -march/-m options come last.
    self.bazel_flags_ += self.target_platform_.get_bazel_gcc_flags()

  def write_build_args(self):
    """Write the accumulated flags, as one line, to the bazelrc file."""
    self._debug("Writing build flags: {}".format(self.bazel_flags_))
    # Mode "w" truncates any existing file at that path.
    with open(self.args.bazelrc_file, "w") as f:
      f.write(self.bazel_flags_ + "\n")

  def _debug(self, msg):
    """Print a progress message."""
    print(msg)

  def go(self):
    """Parse the command line, validate it and write the bazelrc file.

    Prints "Error." and writes nothing when validation fails.
    """
    self.parse_args()
    self.target_platform_ = self.PLATFORMS_.get(self.args.target_platform)
    if self.validate_args():
      self.set_build_args()
      self.write_build_args()
    else:
      print("Error.")
379
# Run the environment setter only when this file is executed as a script, so
# that merely importing the module does not parse sys.argv or write a file.
if __name__ == "__main__":
  env_setter = BuildEnvSetter()
  env_setter.go()
382