# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Interface that provides access to Keras dependencies.

This library is a common interface that contains Keras functions needed by
TensorFlow and TensorFlow Lite and is required as per the dependency inversion
principle (https://en.wikipedia.org/wiki/Dependency_inversion_principle). As per
this principle, high-level modules (e.g., TensorFlow and TensorFlow Lite) should
not depend on low-level modules (e.g., Keras) and instead both should depend on
a common interface such as this file.
"""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.tf_export import tf_export

# Registration slots for Keras-provided callables. Each slot starts out as
# None and is rebound by the matching register_*_function in this module;
# the matching get_*_function hands the stored value back to callers.
_KERAS_CALL_CONTEXT_FUNCTION = _KERAS_CLEAR_SESSION_FUNCTION = None
_KERAS_GET_SESSION_FUNCTION = _KERAS_LOAD_MODEL_FUNCTION = None

# TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras.
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF


# Register functions
@tf_export('__internal__.register_call_context_function', v1=[])
def register_call_context_function(func):
  """Stores `func` as the Keras call-context function.

  The stored value is what `get_call_context_function` returns afterwards.

  Args:
    func: The callable to store. Replaces any previously registered value.
  """
  # `global` is required here because this rebinds a module-level name.
  global _KERAS_CALL_CONTEXT_FUNCTION
  _KERAS_CALL_CONTEXT_FUNCTION = func


@tf_export('__internal__.register_clear_session_function', v1=[])
def register_clear_session_function(func):
  """Stores `func` as the Keras clear-session function.

  The stored value is what `get_clear_session_function` returns afterwards.

  Args:
    func: The callable to store. Replaces any previously registered value.
  """
  # `global` is required here because this rebinds a module-level name.
  global _KERAS_CLEAR_SESSION_FUNCTION
  _KERAS_CLEAR_SESSION_FUNCTION = func


@tf_export('__internal__.register_get_session_function', v1=[])
def register_get_session_function(func):
  """Stores `func` as the Keras get-session function.

  The stored value is what `get_get_session_function` returns afterwards.

  Args:
    func: The callable to store. Replaces any previously registered value.
  """
  # `global` is required here because this rebinds a module-level name.
  global _KERAS_GET_SESSION_FUNCTION
  _KERAS_GET_SESSION_FUNCTION = func


@tf_export('__internal__.register_load_model_function', v1=[])
def register_load_model_function(func):
  """Stores `func` as the Keras load-model function.

  The stored value is what `get_load_model_function` returns afterwards.

  Args:
    func: The callable to store. Replaces any previously registered value.
  """
  # `global` is required here because this rebinds a module-level name.
  global _KERAS_LOAD_MODEL_FUNCTION
  _KERAS_LOAD_MODEL_FUNCTION = func


# Get functions
# These only read the module-level slots, so no `global` declaration is
# needed (reading a module global never requires one).
def get_call_context_function():
  """Returns the function stored by `register_call_context_function`.

  Returns:
    The most recently registered callable, or None if nothing has been
    registered yet.
  """
  return _KERAS_CALL_CONTEXT_FUNCTION


def get_clear_session_function():
  """Returns the function stored by `register_clear_session_function`.

  Returns:
    The most recently registered callable, or None if nothing has been
    registered yet.
  """
  return _KERAS_CLEAR_SESSION_FUNCTION


def get_get_session_function():
  """Returns the function stored by `register_get_session_function`.

  Returns:
    The most recently registered callable, or None if nothing has been
    registered yet.
  """
  return _KERAS_GET_SESSION_FUNCTION


def get_load_model_function():
  """Returns the function stored by `register_load_model_function`.

  Returns:
    The most recently registered callable, or None if nothing has been
    registered yet.
  """
  return _KERAS_LOAD_MODEL_FUNCTION