• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module that encodes (decodes) nested structures into (from) protos.
16
17The intended use is to serialize everything needed to restore a `Function` that
18was saved into a SavedModel. This may include concrete function inputs and
19outputs, signatures, function specs, etc.
20
21Example use:
22coder = nested_structure_coder.StructureCoder()
23# Encode into proto.
24signature_proto = coder.encode_structure(function.input_signature)
25# Decode into a Python object.
26restored_signature = coder.decode_proto(signature_proto)
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import collections
34import functools
35
36import warnings
37import six
38
39from tensorflow.core.protobuf import struct_pb2
40from tensorflow.python.data.ops import dataset_ops
41from tensorflow.python.data.ops import iterator_ops
42from tensorflow.python.data.ops import optional_ops
43from tensorflow.python.distribute import values
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import extension_type
46from tensorflow.python.framework import indexed_slices
47from tensorflow.python.framework import sparse_tensor
48from tensorflow.python.framework import tensor_shape
49from tensorflow.python.framework import tensor_spec
50from tensorflow.python.framework import tensor_util
51from tensorflow.python.framework import type_spec
52from tensorflow.python.ops import resource_variable_ops
53from tensorflow.python.ops import tensor_array_ops
54from tensorflow.python.ops.ragged import ragged_tensor
55from tensorflow.python.ops.ragged import row_partition
56from tensorflow.python.util import compat
57from tensorflow.python.util.compat import collections_abc
58from tensorflow.python.util.tf_export import tf_export
59
60
class NotEncodableError(Exception):
  """Signals that no registered codec accepts the object being converted."""
63
64
@tf_export("__internal__.saved_model.StructureCoder", v1=[])
class StructureCoder(object):
  """Converts nested Python structures to protos and back again.

  The actual work is delegated to codec objects added via `register_codec`.
  Each codec supplies `can_encode`/`do_encode` and `can_decode`/`do_decode`;
  codecs are tried in registration order and the first one whose predicate
  accepts the value handles it.
  """

  # Registry shared by the class and all of its instances.
  _codecs = []

  @classmethod
  def register_codec(cls, x):
    """Appends codec `x` to the shared codec registry."""
    cls._codecs.append(x)

  @classmethod
  def _get_encoders(cls):
    """Returns `(can_encode, do_encode)` pairs in registration order."""
    return [(codec.can_encode, codec.do_encode) for codec in cls._codecs]

  @classmethod
  def _get_decoders(cls):
    """Returns `(can_decode, do_decode)` pairs in registration order."""
    return [(codec.can_decode, codec.do_decode) for codec in cls._codecs]

  def _map_structure(self, pyobj, coders):
    """Converts `pyobj` with the first coder pair that accepts it.

    Args:
      pyobj: Object (or proto, when decoding) to convert.
      coders: Sequence of `(predicate, converter)` pairs to try in order.

    Returns:
      Whatever the selected converter returns.

    Raises:
      NotEncodableError: If no predicate in `coders` accepts `pyobj`.
    """
    # Converters receive this callback so they can convert nested values with
    # the same set of coders.
    recurse = functools.partial(self._map_structure, coders=coders)
    for accepts, convert in coders:
      if accepts(pyobj):
        return convert(pyobj, recurse)
    raise NotEncodableError(
        f"No encoder for object {str(pyobj)} of type {type(pyobj)}.")

  def encode_structure(self, nested_structure):
    """Encodes a nested structure of encodable values into a proto.

    Args:
      nested_structure: Structure to encode.

    Returns:
      The encoded proto.

    Raises:
      NotEncodableError: If some value has no matching encoder.
    """
    return self._map_structure(nested_structure, self._get_encoders())

  def can_encode(self, nested_structure):
    """Reports whether `nested_structure` can be encoded into a proto.

    The check is performed by actually encoding the structure and discarding
    the result.

    Args:
      nested_structure: Structure to test.

    Returns:
      True if encoding succeeds, False if it raises `NotEncodableError`.
    """
    try:
      self.encode_structure(nested_structure)
    except NotEncodableError:
      return False
    else:
      return True

  def decode_proto(self, proto):
    """Decodes a proto that represents a nested structure.

    Args:
      proto: Proto to decode.

    Returns:
      The decoded structure.

    Raises:
      NotEncodableError: If some value has no matching decoder.
    """
    return self._map_structure(proto, self._get_decoders())
133
134
class _ListCodec(object):
  """Codec mapping Python lists to and from the `list_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `list` instance."""
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    """Encodes every item of `list_value` into a `StructuredValue` list."""
    proto = struct_pb2.StructuredValue()
    # Copying an empty ListValue marks the field as set even for `[]`.
    proto.list_value.CopyFrom(struct_pb2.ListValue())
    encoded_items = proto.list_value.values
    for item in list_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `list_value` field set."""
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    """Returns a list holding the decoded form of each stored element."""
    return list(map(decode_fn, value.list_value.values))
153
154
155StructureCoder.register_codec(_ListCodec())
156
157
def _is_tuple(obj):
  """Returns True iff `obj` is a tuple that is not a namedtuple."""
  if _is_named_tuple(obj):
    return False
  return isinstance(obj, tuple)
160
161
def _is_named_tuple(instance):
  """Returns True iff `instance` looks like a `namedtuple`.

  An object qualifies when it is a tuple that exposes a `_fields` attribute
  which is a sequence consisting solely of strings.

  Args:
    instance: An instance of a Python object.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  if not isinstance(instance, tuple) or not hasattr(instance, "_fields"):
    return False
  fields = instance._fields
  # `str` is used directly (as the string codec in this file does): the module
  # already relies on Python-3-only syntax, where `six.string_types` is
  # exactly `(str,)`.
  return (isinstance(fields, collections_abc.Sequence) and
          all(isinstance(name, str) for name in fields))
176
177
class _TupleCodec(object):
  """Codec mapping plain tuples to and from the `tuple_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `_is_tuple` accepts `pyobj`."""
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    """Encodes every item of `tuple_value` into a `StructuredValue` tuple."""
    proto = struct_pb2.StructuredValue()
    # Copying an empty TupleValue marks the field as set even for `()`.
    proto.tuple_value.CopyFrom(struct_pb2.TupleValue())
    encoded_items = proto.tuple_value.values
    for item in tuple_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `tuple_value` field set."""
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    """Returns a tuple holding the decoded form of each stored element."""
    return tuple(map(decode_fn, value.tuple_value.values))
196
197
198StructureCoder.register_codec(_TupleCodec())
199
200
class _DictCodec(object):
  """Codec mapping Python dicts to and from the `dict_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `dict` instance."""
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    """Encodes each value of `dict_value` under its original key."""
    proto = struct_pb2.StructuredValue()
    # Copying an empty DictValue marks the field as set even for `{}`.
    proto.dict_value.CopyFrom(struct_pb2.DictValue())
    encoded_fields = proto.dict_value.fields
    for name, item in dict_value.items():
      encoded_fields[name].CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `dict_value` field set."""
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    """Returns a dict mapping each stored key to its decoded value."""
    decoded = {}
    for name, encoded in value.dict_value.fields.items():
      decoded[name] = decode_fn(encoded)
    return decoded
219
220
221StructureCoder.register_codec(_DictCodec())
222
223
class _NamedTupleCodec(object):
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    """Returns True when `_is_named_tuple` accepts `pyobj`."""
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    """Encodes a namedtuple as its class name plus ordered key/value pairs.

    Args:
      named_tuple_value: The namedtuple instance to encode.
      encode_fn: Callback used to encode each field value.

    Returns:
      A `StructuredValue` with `named_tuple_value` populated.
    """
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    proto = encoded_named_tuple.named_tuple_value
    proto.name = named_tuple_value.__class__.__name__
    # Build the field mapping once; calling `_asdict()` for every key would
    # rebuild the whole dict on each iteration.
    field_values = named_tuple_value._asdict()
    for key in named_tuple_value._fields:
      pair = proto.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_values[key]))
    return encoded_named_tuple

  def can_decode(self, value):
    """Returns True when `value` has its `named_tuple_value` field set."""
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a namedtuple from the stored name and key/value pairs.

    Args:
      value: Proto whose `named_tuple_value` field is set.
      decode_fn: Callback used to decode each stored field value.

    Returns:
      An instance of a freshly created namedtuple type carrying the stored
      type name and field names.
    """
    key_value_pairs = value.named_tuple_value.values
    items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs]
    named_tuple_type = collections.namedtuple(value.named_tuple_value.name,
                                              [key for key, _ in items])
    return named_tuple_type(**dict(items))
254
255
256StructureCoder.register_codec(_NamedTupleCodec())
257
258
class _Float64Codec(object):
  """Codec storing Python floats in the `float64_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `float` instance."""
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    """Wraps `float64_value` in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    proto = struct_pb2.StructuredValue()
    proto.float64_value = float64_value
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `float64_value` field set."""
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    """Returns the float stored in `value`."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return value.float64_value
277
278
279StructureCoder.register_codec(_Float64Codec())
280
281
class _Int64Codec(object):
  """Codec for Python integers (limited to 64 bit values)."""

  def can_encode(self, pyobj):
    """Returns True for `int` instances that are not `bool`."""
    # `bool` is a subclass of `int`, so it has to be ruled out explicitly.
    return isinstance(pyobj, int) and not isinstance(pyobj, bool)

  def do_encode(self, int_value, encode_fn):
    """Wraps `int_value` in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    proto = struct_pb2.StructuredValue()
    proto.int64_value = int_value
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `int64_value` field set."""
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored integer converted with `int()`."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return int(value.int64_value)
300
301
302StructureCoder.register_codec(_Int64Codec())
303
304
class _StringCodec(object):
  """Codec storing Python `str` objects in the `string_value` field.

  See StructuredValue.string_value in proto/struct.proto for more detailed
  explanation.
  """

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `str` instance."""
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    """Wraps `string_value` in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    proto = struct_pb2.StructuredValue()
    proto.string_value = string_value
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `string_value` field set."""
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored string passed through `compat.as_str`."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return compat.as_str(value.string_value)
327
328
329StructureCoder.register_codec(_StringCodec())
330
331
class _NoneCodec(object):
  """Codec representing `None` through the `none_value` field."""

  def can_encode(self, pyobj):
    """Returns True only when `pyobj` is `None`."""
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    """Returns a `StructuredValue` whose `none_value` field is set."""
    del none_value, encode_fn  # Neither argument carries any information.
    proto = struct_pb2.StructuredValue()
    proto.none_value.CopyFrom(struct_pb2.NoneValue())
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `none_value` field set."""
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    """Always returns `None`."""
    del value, decode_fn  # Neither argument carries any information.
    return None
350
351
352StructureCoder.register_codec(_NoneCodec())
353
354
class _BoolCodec(object):
  """Codec storing Python booleans in the `bool_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `bool` instance."""
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    """Wraps `bool_value` in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    proto = struct_pb2.StructuredValue()
    proto.bool_value = bool_value
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `bool_value` field set."""
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    """Returns the boolean stored in `value`."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return value.bool_value
373
374
375StructureCoder.register_codec(_BoolCodec())
376
377
class _TensorShapeCodec(object):
  """Codec mapping `TensorShape` to and from `tensor_shape_value`."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `tensor_shape.TensorShape`."""
    return isinstance(pyobj, tensor_shape.TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    """Stores the shape's `as_proto()` form in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    shape_proto = tensor_shape_value.as_proto()
    proto = struct_pb2.StructuredValue()
    proto.tensor_shape_value.CopyFrom(shape_proto)
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `tensor_shape_value` field set."""
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    """Constructs a `TensorShape` from the stored shape proto."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return tensor_shape.TensorShape(value.tensor_shape_value)
397
398
399StructureCoder.register_codec(_TensorShapeCodec())
400
401
class _TensorTypeCodec(object):
  """Codec mapping `DType` to and from the `tensor_dtype_value` field."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `dtypes.DType`."""
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    """Stores the dtype's `as_datatype_enum` value in a `StructuredValue`."""
    del encode_fn  # Leaf value: nothing nested to encode.
    proto = struct_pb2.StructuredValue()
    proto.tensor_dtype_value = tensor_dtype_value.as_datatype_enum
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `tensor_dtype_value` field set."""
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    """Constructs a `DType` from the stored enum value."""
    del decode_fn  # Leaf value: nothing nested to decode.
    return dtypes.DType(value.tensor_dtype_value)
420
421
422StructureCoder.register_codec(_TensorTypeCodec())
423
424
class _TensorSpecCodec(object):
  """Codec mapping `TensorSpec` to and from the `tensor_spec_value` field."""

  def can_encode(self, pyobj):
    """Returns True for `TensorSpec` objects that are not bounded specs."""
    # BoundedTensorSpec is excluded here because it has its own codec.
    if isinstance(pyobj, tensor_spec.BoundedTensorSpec):
      return False
    return isinstance(pyobj, tensor_spec.TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    """Encodes the spec's shape, dtype and name into a `TensorSpecProto`."""
    proto = struct_pb2.StructuredValue()
    # Shape and dtype go through `encode_fn`; only the relevant field of each
    # resulting StructuredValue is kept.
    spec_proto = struct_pb2.TensorSpecProto(
        shape=encode_fn(tensor_spec_value.shape).tensor_shape_value,
        dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value,
        name=tensor_spec_value.name)
    proto.tensor_spec_value.CopyFrom(spec_proto)
    return proto

  def can_decode(self, value):
    """Returns True when `value` has its `tensor_spec_value` field set."""
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorSpec`; an empty stored name becomes `None`."""
    spec_proto = value.tensor_spec_value
    stored_name = spec_proto.name
    # Shape and dtype are re-wrapped so that `decode_fn` can decode them.
    wrapped_shape = struct_pb2.StructuredValue(
        tensor_shape_value=spec_proto.shape)
    shape = decode_fn(wrapped_shape)
    wrapped_dtype = struct_pb2.StructuredValue(
        tensor_dtype_value=spec_proto.dtype)
    dtype = decode_fn(wrapped_dtype)
    return tensor_spec.TensorSpec(
        shape=shape, dtype=dtype, name=stored_name or None)
455
456
457StructureCoder.register_codec(_TensorSpecCodec())
458
459
class _BoundedTensorSpecCodec(object):
  """Codec mapping `BoundedTensorSpec` to `bounded_tensor_spec_value`."""

  def can_encode(self, pyobj):
    """Returns True when `pyobj` is a `tensor_spec.BoundedTensorSpec`."""
    return isinstance(pyobj, tensor_spec.BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Encodes shape, dtype, name and both bounds of the given spec."""
    spec = bounded_tensor_spec_value
    proto = struct_pb2.StructuredValue()
    # Shape and dtype go through `encode_fn`; the bounds are converted to
    # tensor protos with `tensor_util.make_tensor_proto`.
    spec_proto = struct_pb2.BoundedTensorSpecProto(
        shape=encode_fn(spec.shape).tensor_shape_value,
        dtype=encode_fn(spec.dtype).tensor_dtype_value,
        name=spec.name,
        minimum=tensor_util.make_tensor_proto(spec.minimum),
        maximum=tensor_util.make_tensor_proto(spec.maximum))
    proto.bounded_tensor_spec_value.CopyFrom(spec_proto)
    return proto

  def can_decode(self, value):
    """Returns True when `value` has `bounded_tensor_spec_value` set."""
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `BoundedTensorSpec`; an empty stored name becomes `None`."""
    spec_proto = value.bounded_tensor_spec_value
    stored_name = spec_proto.name
    # Shape and dtype are re-wrapped so that `decode_fn` can decode them.
    shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
    dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
    lower = tensor_util.MakeNdarray(spec_proto.minimum)
    upper = tensor_util.MakeNdarray(spec_proto.maximum)
    return tensor_spec.BoundedTensorSpec(
        shape=shape,
        dtype=dtype,
        minimum=lower,
        maximum=upper,
        name=stored_name or None)
494
495
496StructureCoder.register_codec(_BoundedTensorSpecCodec())
497
498
class _TypeSpecCodec(object):
  """Codec for `tf.TypeSpec`.

  Specs whose exact class appears in `TYPE_SPEC_CLASS_FROM_PROTO` are stored
  with a dedicated enum value. Other specs are stored by the name returned
  from `type_spec.get_name` and tagged either `EXTENSION_TYPE_SPEC` or
  `REGISTERED_TYPE_SPEC`.
  """

  # Mapping from enum value to type (TypeSpec subclass).
  TYPE_SPEC_CLASS_FROM_PROTO = {
      struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC:
          sparse_tensor.SparseTensorSpec,
      struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC:
          indexed_slices.IndexedSlicesSpec,
      struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC:
          ragged_tensor.RaggedTensorSpec,
      struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC:
          tensor_array_ops.TensorArraySpec,
      struct_pb2.TypeSpecProto.DATA_DATASET_SPEC:
          dataset_ops.DatasetSpec,
      struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC:
          iterator_ops.IteratorSpec,
      struct_pb2.TypeSpecProto.OPTIONAL_SPEC:
          optional_ops.OptionalSpec,
      struct_pb2.TypeSpecProto.PER_REPLICA_SPEC:
          values.PerReplicaSpec,
      struct_pb2.TypeSpecProto.VARIABLE_SPEC:
          resource_variable_ops.VariableSpec,
      struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
          row_partition.RowPartitionSpec,
  }

  # Mapping from type (TypeSpec subclass) to enum value.
  TYPE_SPEC_CLASS_TO_PROTO = {
      cls: enum for enum, cls in TYPE_SPEC_CLASS_FROM_PROTO.items()
  }

  def can_encode(self, pyobj):
    """Returns true if `pyobj` can be encoded as a TypeSpec."""
    # Exact type match on purpose: subclasses of the built-in specs are not
    # covered by the enum table.
    if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO:  # pylint: disable=unidiomatic-typecheck
      return True

    # Check if it's a registered type.
    if isinstance(pyobj, type_spec.TypeSpec):
      try:
        type_spec.get_name(type(pyobj))
        return True
      except ValueError:
        return False

    return False

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.TypeSpec`.

    Args:
      type_spec_value: The TypeSpec instance to encode.
      encode_fn: Callback used to encode the spec's serialized state.

    Returns:
      A `StructuredValue` with `type_spec_value` populated.
    """
    type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(type(type_spec_value))
    type_spec_class_name = type(type_spec_value).__name__

    if type_spec_class is None:
      # Not one of the built-in specs: store it under its registered name.
      type_spec_class_name = type_spec.get_name(type(type_spec_value))
      if isinstance(type_spec_value, extension_type.ExtensionTypeSpec):
        type_spec_class = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC
      else:
        type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
        # Support for saving registered TypeSpecs is currently experimental.
        # Issue a warning to indicate the limitations.
        warnings.warn("Encoding a StructuredValue with type %s; loading this "
                      "StructuredValue will require that this type be "
                      "imported and registered." % type_spec_class_name)

    type_state = type_spec_value._serialize()  # pylint: disable=protected-access
    encoded_type_spec = struct_pb2.StructuredValue()
    encoded_type_spec.type_spec_value.CopyFrom(
        struct_pb2.TypeSpecProto(
            type_spec_class=type_spec_class,
            type_state=encode_fn(type_state),
            type_spec_class_name=type_spec_class_name))
    return encoded_type_spec

  def can_decode(self, value):
    """Returns True when `value` has its `type_spec_value` field set."""
    return value.HasField("type_spec_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`.

    Args:
      value: Proto whose `type_spec_value` field is set.
      decode_fn: Callback used to decode the stored type state.

    Returns:
      The TypeSpec produced by the resolved class's `_deserialize`.

    Raises:
      ValueError: If a `REGISTERED_TYPE_SPEC` name cannot be looked up, or if
        the stored enum value is not in `TYPE_SPEC_CLASS_FROM_PROTO`.
    """
    type_spec_proto = value.type_spec_value
    type_spec_class_enum = type_spec_proto.type_spec_class
    class_name = type_spec_proto.type_spec_class_name

    if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError as e:
        raise ValueError(
            f"The type '{class_name}' has not been registered.  It must be "
            "registered before you load this object (typically by importing "
            "its module).") from e
    elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError:
        # An unregistered extension type is not fatal: fall back to the
        # anonymous spec, and name the missing type in the warning.
        type_spec_class = extension_type.AnonymousExtensionTypeSpec
        warnings.warn("The type %r has not been registered.  Falling back to "
                      "using AnonymousExtensionTypeSpec instead." % class_name)
    else:
      if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
        raise ValueError(
            f"The type '{class_name}' is not supported by this version of "
            "TensorFlow. (The object you are loading must have been created "
            "with a newer version of TensorFlow.)")
      type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]

    # pylint: disable=protected-access
    return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
605
606
607StructureCoder.register_codec(_TypeSpecCodec())
608