• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Keras functions required by TensorFlow Lite.
17
18The functions defined in this library have been copied over from Keras in order
19to remove the dependency from TensorFlow Lite to Keras. The functions which
20could not be copied over are accessed using the dependency inversion principle.
21(for details, refer to tensorflow/python/util/keras_deps.py).
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import copy
29
30from tensorflow.python.eager import def_function
31from tensorflow.python.util import keras_deps
32from tensorflow.python.util import nest
33from tensorflow.python.util.compat import collections_abc
34
35
def _enforce_names_consistency(specs):
  """Makes spec naming all-or-nothing across a (possibly nested) structure.

  If some, but not all, of the flattened specs carry a non-None `name`, every
  leaf of `specs` is replaced by a deep copy. When that copy has a `name`
  attribute, its private `_name` attribute is set to None. When every spec is
  named, or none is, `specs` is returned untouched.

  Args:
    specs: A possibly-nested structure of spec objects (e.g. TensorSpecs).

  Returns:
    Either `specs` itself, or a structure of the same shape whose leaves are
    deep copies with their `_name` cleared.
  """

  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _without_name(spec):
    # Copy first so the caller's spec objects are never mutated.
    unnamed = copy.deepcopy(spec)
    if hasattr(unnamed, 'name'):
      unnamed._name = None  # pylint:disable=protected-access
    return unnamed

  flat_specs = nest.flatten(specs)
  named_count = sum(1 for s in flat_specs if _is_named(s))
  partially_named = 0 < named_count < len(flat_specs)
  if not partially_named:
    return specs
  return nest.map_structure(_without_name, specs)
56
57
def model_input_signature(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  The model's input signature is a list with a single (possibly-nested) object.
  This is due to the Keras-enforced restriction that tensor inputs must be
  passed in as the first argument.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  The specs come from `model.save_spec(...)` when the model defines it; only
  the spec of the first positional argument is used. Otherwise they come from
  the private `model._get_save_spec(...)`. In both cases spec names are made
  all-or-nothing via `_enforce_names_consistency`.

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean indicating whether we want to keep using
      the original batch size or set it to None. Default is `False`, which means
      that the batch dim of the returned input signature will always be set to
      `None`.

  Returns:
    A list containing either a single TensorSpec or an object with nested
    TensorSpecs. This list does not contain the `training` argument. If the
    specs are already a one-element sequence, that sequence itself is
    returned. Returns `None` when the model reports no save spec.
  """
  dynamic_batch = not keep_original_batch_size
  if hasattr(model, 'save_spec'):
    save_spec = model.save_spec(dynamic_batch=dynamic_batch)
    if save_spec is None:
      return None
    # `save_spec` is an (args, kwargs) pair; the first positional argument is
    # taken as the input spec.
    # TODO(b/188105669): Add support for multiple tensor arguments.
    specs = save_spec[0][0]
  else:
    specs = model._get_save_spec(  # pylint: disable=protected-access
        dynamic_batch=dynamic_batch)
    if specs is None:
      return None

  specs = _enforce_names_consistency(specs)
  # A one-element sequence already has the signature's shape. Anything else,
  # including a single-entry dict (dicts are not Sequences), is wrapped.
  already_wrapped = (
      isinstance(specs, collections_abc.Sequence) and len(specs) == 1)
  return specs if already_wrapped else [specs]
101
102
def raise_model_input_error(model):
  """Raises an error for a model whose input shapes are not yet known.

  Args:
    model: The model that cannot be saved; it is inserted into the message
      with `str.format`.

  Raises:
    ValueError: Always.
  """
  message = ('Model {} cannot be saved because the input shapes have not '
             'been set. Usually, input shapes are automatically determined '
             'from calling `.fit()` or `.predict()`. To manually set the '
             'shapes, call `model.build(input_shape)`.')
  raise ValueError(message.format(model))
109
110
def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Each flattened leaf gets a name built from its nesting path. Integer path
  components are shifted to be 1-based. Components are joined with `_`. The
  `prefix` is prepended when the path starts with an integer index, and a
  bare (unnested) leaf is named `prefix + '1'`.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def _one_based(component):
    # Number positions from 1 so names read "output_1" rather than "output_0".
    return component + 1 if isinstance(component, int) else component

  def _path_to_name(path):
    if not path:
      return prefix + '1'  # A single, unnested tensor.
    joined = '_'.join(str(component) for component in path)
    return prefix + joined if isinstance(path[0], int) else joined

  raw_paths = list(nest.yield_flat_paths(tensors))
  one_based_paths = nest.map_structure(_one_based, raw_paths)
  return [_path_to_name(path) for path in one_based_paths]
154
155
def create_pseudo_output_names(outputs):
  """Create pseudo output names for a subclassed Model.

  Args:
    outputs: The model's (possibly nested) outputs.

  Returns:
    A flat list of names; positional entries use the `output_` prefix, e.g.
    `['output_1', 'output_2']` for a two-element list of outputs.
  """
  output_prefix = 'output_'
  return _create_pseudo_names(outputs, output_prefix)
159
160
def trace_model_call(model, input_signature=None):
  """Trace the model call to create a tf.function for exporting a Keras model.

  The input signature is resolved in this order: the explicit
  `input_signature` argument; `model.call.input_signature` when `model.call`
  is already a tf.function; and finally `model_input_signature(model)`.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if (input_signature is None and
      isinstance(model.call, def_function.Function)):
    input_signature = model.call.input_signature
  if input_signature is None:
    input_signature = model_input_signature(model)
  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature, autograph=False)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # With a single input, Keras calls the model on the tensor itself rather
    # than on a one-element list, so unwrap it here.
    if len(input_signature) == 1:
      inputs = args[0]
    else:
      inputs = list(args)

    # Run the call in inference mode inside a call context marked as saving.
    call_context = keras_deps.get_call_context_function()()
    with call_context.enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      return model(inputs, training=False)

  return _wrapped_model
199