# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Keras functions required by TensorFlow Lite.

The functions defined in this library have been copied over from Keras in order
to remove the dependency from TensorFlow Lite to Keras. The functions which
could not be copied over are accessed using the dependency inversion principle.
(for details, refer to tensorflow/python/util/keras_deps.py).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.eager import def_function
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc


def _enforce_names_consistency(specs):
  """Makes spec naming uniform: either every spec is named or none is.

  Args:
    specs: A (possibly nested) structure of spec objects.

  Returns:
    `specs` unchanged when all leaves are named or all are unnamed. When only
    some leaves are named, a structure of deep copies whose `_name` has been
    reset to `None`.
  """

  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _strip_name(spec):
    # Work on a copy so the spec held by the caller is left untouched.
    unnamed = copy.deepcopy(spec)
    if hasattr(unnamed, 'name'):
      unnamed._name = None  # pylint:disable=protected-access
    return unnamed

  named_flags = [_is_named(leaf) for leaf in nest.flatten(specs)]
  if any(named_flags) and not all(named_flags):
    # Mixed naming: drop every name rather than keep a partial set.
    return nest.map_structure(_strip_name, specs)
  return specs


def model_input_signature(model, keep_original_batch_size=False):
  """Derives the input signature of a Keras model.

  The signature is a list holding one (possibly nested) object, because Keras
  requires tensor inputs to be passed as the first argument.

  For example, a model whose input is
  {'feature1': <Tensor>, 'feature2': <Tensor>} has the input signature
  [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object.
    keep_original_batch_size: Boolean. When `False` (the default) the save spec
      is requested with `dynamic_batch=True`; when `True` it is requested with
      `dynamic_batch=False`.

  Returns:
    `None` if the model reports no save spec. Otherwise either the spec itself,
    when it is already a single-element sequence, or a one-element list
    wrapping it. The `training` argument is not part of the result.
  """
  dynamic_batch = not keep_original_batch_size

  if hasattr(model, 'save_spec'):
    save_spec = model.save_spec(dynamic_batch=dynamic_batch)
    if save_spec is None:
      return None
    # `save_spec` yields (args, kwargs); only the first positional argument is
    # used as the input spec.
    # TODO(b/188105669): Add support for multiple tensor arguments.
    input_specs = save_spec[0][0]
  else:
    input_specs = model._get_save_spec(  # pylint: disable=protected-access
        dynamic_batch=dynamic_batch)
    if input_specs is None:
      return None

  input_specs = _enforce_names_consistency(input_specs)

  # The Sequence check deliberately excludes single-element dictionaries:
  # those still have to be wrapped in a one-element list.
  already_wrapped = (
      isinstance(input_specs, collections_abc.Sequence) and
      len(input_specs) == 1)
  return input_specs if already_wrapped else [input_specs]


def raise_model_input_error(model):
  """Raises the error reported when a model's input shapes are not set.

  Args:
    model: The model to mention in the error message.

  Raises:
    ValueError: always.
  """
  message_template = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message_template.format(model))


def _create_pseudo_names(tensors, prefix):
  """Builds placeholder {input | output} names for subclassed Models.

  Warning: only use this to define default names for `Metrics` and
  `SavedModel`. Nothing else should depend on a `Model`'s input or output
  names.

  Dict example:

  `{'a': [x1, x2], 'b': x3}` yields `['a_1', 'a_2', 'b']`

  List example:

  `[x, y]` yields `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def _to_one_based(component):
    # Integer path components become 1-based ("output_1", not "output_0").
    return component + 1 if isinstance(component, int) else component

  def _name_for(path):
    if not path:
      return prefix + '1'  # A lone tensor has an empty path.
    joined = '_'.join(str(component) for component in path)
    # Only paths that start with a positional index receive the prefix.
    return prefix + joined if isinstance(path[0], int) else joined

  one_based_paths = nest.map_structure(
      _to_one_based, list(nest.yield_flat_paths(tensors)))
  return [_name_for(path) for path in one_based_paths]


def create_pseudo_output_names(outputs):
  """Returns pseudo names, using the 'output_' prefix, for model outputs."""
  return _create_pseudo_names(outputs, prefix='output_')


def trace_model_call(model, input_signature=None):
  """Wraps a Keras model's call in a tf.function suitable for exporting.

  When `input_signature` is not supplied it is taken from `model.call` if that
  is already a tf.function, and otherwise inferred from the model itself.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None and isinstance(model.call,
                                            def_function.Function):
    input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)

  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature, autograph=False)
  def _wrapped_model(*args):
    """Invokes the model in inference mode inside a saving call context."""
    # With a single input, Keras models are called on the tensor itself rather
    # than on a one-element list holding it.
    if len(input_signature) == 1:
      inputs = args[0]
    else:
      inputs = list(args)

    call_context = keras_deps.get_call_context_function()()
    with call_context.enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      return model(inputs, training=False)

  return _wrapped_model