1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python module for Session ops, vars, and functions exported by pybind11.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import 22from tensorflow.python import pywrap_tensorflow 23from tensorflow.python._pywrap_tf_session import * 24from tensorflow.python._pywrap_tf_session import _TF_SetTarget 25from tensorflow.python._pywrap_tf_session import _TF_SetConfig 26from tensorflow.python._pywrap_tf_session import _TF_NewSessionOptions 27 28# Convert versions to strings for Python2 and keep api_compatibility_test green. 29# We can remove this hack once we remove Python2 presubmits. pybind11 can only 30# return unicode for Python2 even with py::str. 31# https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html#returning-c-strings-to-python 32# pylint: disable=undefined-variable 33__version__ = str(get_version()) 34__git_version__ = str(get_git_version()) 35__compiler_version__ = str(get_compiler_version()) 36__cxx11_abi_flag__ = get_cxx11_abi_flag() 37__monolithic_build__ = get_monolithic_build() 38 39# User getters to hold attributes rather than pybind11's m.attr due to 40# b/145559202. 
# Module-level constants, read once from the pybind11 getters at import time.
GRAPH_DEF_VERSION = get_graph_def_version()
GRAPH_DEF_VERSION_MIN_CONSUMER = get_graph_def_version_min_consumer()
GRAPH_DEF_VERSION_MIN_PRODUCER = get_graph_def_version_min_producer()
TENSOR_HANDLE_KEY = get_tensor_handle_key()

# pylint: enable=undefined-variable


# Disable pylint invalid name warnings for legacy functions.
# pylint: disable=invalid-name
def TF_NewSessionOptions(target=None, config=None):
  """Creates a native session-options handle populated from Python values.

  Args:
    target: Optional execution target. When not None it is forwarded to
      `_TF_SetTarget`.
    config: Optional config message exposing `SerializeToString()`. When not
      None its serialized bytes are forwarded to `_TF_SetConfig`.

  Returns:
    The handle created by `_TF_NewSessionOptions()`. Callers are expected to
    release it with `TF_DeleteSessionOptions` (as `TF_Reset` below does).

  Raises:
    Whatever `_TF_SetTarget`, `config.SerializeToString()` or `_TF_SetConfig`
    raise. In that case the partially-populated handle is released before the
    exception propagates.
  """
  # NOTE: target and config are validated in the session constructor.
  opts = _TF_NewSessionOptions()
  try:
    if target is not None:
      _TF_SetTarget(opts, target)
    if config is not None:
      config_str = config.SerializeToString()
      _TF_SetConfig(opts, config_str)
  except BaseException:
    # No caller holds `opts` yet, so free it here instead of leaking it.
    TF_DeleteSessionOptions(opts)  # pylint: disable=undefined-variable
    raise
  return opts


# Disable pylint undefined-variable as these names are exported in the shared
# object via pybind11.
# pylint: disable=undefined-variable
def TF_Reset(target, containers=None, config=None):
  """Calls `TF_Reset_wrapper` with temporary options for `target`/`config`.

  Args:
    target: Execution target, forwarded to `TF_NewSessionOptions`.
    containers: Optional value passed unchanged to `TF_Reset_wrapper`
      (presumably a list of resource-container names -- confirm against the
      pybind11 binding).
    config: Optional config message, forwarded to `TF_NewSessionOptions`.
  """
  opts = TF_NewSessionOptions(target=target, config=config)
  try:
    TF_Reset_wrapper(opts, containers)
  finally:
    # Release the native options whether or not the reset succeeded.
    TF_DeleteSessionOptions(opts)