• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A Python wrapper that loads _pywrap_tensorflow_internal.so.

Importing this module runs TensorFlow's pre-load sanity checks, then imports
the native extension module. Where the platform allows it, the interpreter's
dlopen() flags are adjusted around that load. A failure to load the extension
is re-raised as an ImportError that points at the TensorFlow install-errors
page.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ctypes
import sys
import traceback

from tensorflow.python.platform import self_check

# Perform pre-load sanity checks in order to produce a more actionable error.
# This runs before any attempt to load the native extension module.
self_check.preload_check()

# pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long

# `_use_dlopen_global_flags` records whether the optional helper module
# `pywrap_dlopen_global_flags` could be imported. When it is True, that
# helper's set/reset functions manage the dlopen flags around the extension
# load instead of the plain sys.setdlopenflags() handling.
try:
  # This import is expected to fail if there is an explicit shared object
  # dependency (with_framework_lib=true), since we do not need RTLD_GLOBAL.
  from tensorflow.python import pywrap_dlopen_global_flags
  _use_dlopen_global_flags = True
except ImportError:
  _use_dlopen_global_flags = False

# On UNIX-based platforms, pywrap_tensorflow is a python library that
# dynamically loads _pywrap_tensorflow_internal.so.
#
# sys.getdlopenflags()/sys.setdlopenflags() may not exist on every platform
# (hence the hasattr checks). The flags are only adjusted when both are
# available. In that case the current flags are saved here so they can be
# restored once the extension module has been loaded.
_can_set_rtld_local = (
    hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'))
if _can_set_rtld_local:
  _default_dlopen_flags = sys.getdlopenflags()

try:
  # Choose the dlopen() flags the extension module is loaded with. These are
  # interpreter-wide settings, so the `finally` clause below restores them
  # whether or not the import succeeds.
  if _use_dlopen_global_flags:
    # NOTE(review): set_dlopen_flags() lives in pywrap_dlopen_global_flags;
    # presumably it enables RTLD_GLOBAL -- confirm in that module.
    pywrap_dlopen_global_flags.set_dlopen_flags()
  elif _can_set_rtld_local:
    # Ensure RTLD_LOCAL behavior for platforms where it isn't the default
    # (macOS). On Linux RTLD_LOCAL is 0, so this does nothing (and would not
    # override an RTLD_GLOBAL in _default_dlopen_flags).
    sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL)

  # Python2.7 does not have a ModuleNotFoundError. With this fallback alias,
  # the `except ModuleNotFoundError` clause below catches (and ignores) every
  # ImportError raised by the extension import on Python 2.
  try:
    ModuleNotFoundError
  except NameError:
    ModuleNotFoundError = ImportError

  # pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable
  try:
    from tensorflow.python._pywrap_tensorflow_internal import *
  # This try catch logic is because there is no bazel equivalent for
  # py_extension. Externally in opensource we must enable exceptions to load
  # the shared object by exposing the PyInit symbols with pybind. This error
  # will only be caught internally or if someone changes the name of the target
  # _pywrap_tensorflow_internal. This logic is used in other internal projects
  # using py_extension.
  except ModuleNotFoundError:
    pass
  finally:
    # Restore the dlopen flags even when the import raises. Otherwise a failed
    # load would leave the modified flags in effect for every later extension
    # module loaded by this process (e.g. when an importer catches the
    # ImportError raised below and carries on).
    if _use_dlopen_global_flags:
      pywrap_dlopen_global_flags.reset_dlopen_flags()
    elif _can_set_rtld_local:
      sys.setdlopenflags(_default_dlopen_flags)
except ImportError:
  # Re-raise with the original traceback text embedded, followed by a pointer
  # to the troubleshooting page.
  msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
See https://www.tensorflow.org/install/errors\n
for some common reasons and solutions.  Include the entire stack trace
above this error message when asking for help.""" % traceback.format_exc()
  raise ImportError(msg)

# pylint: enable=wildcard-import,g-import-not-at-top,unused-import,line-too-long
