• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module that encodes (decodes) nested structures into (from) protos.
16
17The intended use is to serialize everything needed to restore a `Function` that
18was saved into a SavedModel. This may include concrete function inputs and
19outputs, signatures, function specs, etc.
20
21Example use:
22coder = nested_structure_coder.StructureCoder()
23# Encode into proto.
24signature_proto = coder.encode_structure(function.input_signature)
25# Decode into a Python object.
26restored_signature = coder.decode_proto(signature_proto)
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import collections
34import functools
35
36import six
37
38from tensorflow.core.protobuf import struct_pb2
39from tensorflow.python.data.ops import dataset_ops
40from tensorflow.python.data.ops import iterator_ops
41from tensorflow.python.data.ops import optional_ops
42from tensorflow.python.distribute import values
43from tensorflow.python.framework import dtypes
44from tensorflow.python.framework import indexed_slices
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_shape
47from tensorflow.python.framework import tensor_spec
48from tensorflow.python.framework import tensor_util
49from tensorflow.python.ops import resource_variable_ops
50from tensorflow.python.ops import tensor_array_ops
51from tensorflow.python.ops.ragged import ragged_tensor
52from tensorflow.python.ops.ragged import row_partition
53from tensorflow.python.util import compat
54from tensorflow.python.util.compat import collections_abc
55
56
class NotEncodableError(Exception):
  """Signals that none of the available coders accepts a given object."""
59
60
class StructureCoder(object):
  """Converts nested structures to protos and back again.

  The conversion itself is delegated to codec objects added through
  `register_codec`. Each codec supplies `can_encode`/`do_encode` and
  `can_decode`/`do_decode`. Codecs are tried in registration order and the
  first one that accepts a value handles it.
  """

  # Registry stored on the class, so it is shared by all instances.
  _codecs = []

  @classmethod
  def register_codec(cls, x):
    """Appends the codec `x` to the class-level registry."""
    cls._codecs.append(x)

  @classmethod
  def _get_encoders(cls):
    """Returns a list of `(can_encode, do_encode)` pairs, one per codec."""
    pairs = []
    for codec in cls._codecs:
      pairs.append((codec.can_encode, codec.do_encode))
    return pairs

  @classmethod
  def _get_decoders(cls):
    """Returns a list of `(can_decode, do_decode)` pairs, one per codec."""
    pairs = []
    for codec in cls._codecs:
      pairs.append((codec.can_decode, codec.do_decode))
    return pairs

  def _map_structure(self, pyobj, coders):
    """Converts `pyobj` with the first coder that accepts it.

    Args:
      pyobj: Object (or proto) to convert.
      coders: Sequence of `(can, do)` callable pairs. `do` receives the object
        and a callable that converts nested values with the same coders.

    Returns:
      Whatever the selected `do` callable returns.

    Raises:
      NotEncodableError: If no coder accepts `pyobj`. The message says
        "encoder" regardless of which kind of coders was supplied.
    """
    recurse = functools.partial(self._map_structure, coders=coders)
    for accepts, convert in coders:
      if not accepts(pyobj):
        continue
      return convert(pyobj, recurse)
    raise NotEncodableError(
        "No encoder for object [%s] of type [%s]." % (str(pyobj), type(pyobj)))

  def encode_structure(self, nested_structure):
    """Encodes a nested structure of encodable types into a proto.

    Args:
      nested_structure: Structure to encode.

    Returns:
      The encoded proto.

    Raises:
      NotEncodableError: For values that no registered codec can encode.
    """
    encoders = self._get_encoders()
    return self._map_structure(nested_structure, encoders)

  def can_encode(self, nested_structure):
    """Reports whether a nested structure can be encoded into a proto.

    The check is performed by actually encoding the structure and discarding
    the result.

    Args:
      nested_structure: Structure to test.

    Returns:
      True if encoding succeeds, False if it raises `NotEncodableError`.
    """
    try:
      self.encode_structure(nested_structure)
    except NotEncodableError:
      return False
    else:
      return True

  def decode_proto(self, proto):
    """Decodes a proto that represents a nested structure.

    Args:
      proto: Proto to decode.

    Returns:
      The decoded structure.

    Raises:
      NotEncodableError: For values that no registered codec can decode.
    """
    decoders = self._get_decoders()
    return self._map_structure(proto, decoders)
128
129
class _ListCodec(object):
  """Codec that maps Python lists to and from `list_value` protos."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is an instance of `list`."""
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    """Encodes each element with `encode_fn` into a `StructuredValue`.

    Args:
      list_value: The list to encode.
      encode_fn: Callable that encodes a single element.

    Returns:
      A `StructuredValue` whose `list_value` holds the encoded elements in
      iteration order.
    """
    proto = struct_pb2.StructuredValue()
    # NOTE(review): presumably this explicit copy is what makes `HasField`
    # report the field for an empty list too -- confirm before removing.
    proto.list_value.CopyFrom(struct_pb2.ListValue())
    slots = proto.list_value.values
    for item in list_value:
      slots.add().CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `list_value` field set."""
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    """Returns a list with every entry of `list_value` run through `decode_fn`."""
    return list(map(decode_fn, value.list_value.values))
148
149
150StructureCoder.register_codec(_ListCodec())
151
152
def _is_tuple(obj):
  """Returns True for tuples that `_is_named_tuple` does not recognize."""
  return isinstance(obj, tuple) and not _is_named_tuple(obj)
155
156
def _is_named_tuple(instance):
  """Checks whether `instance` looks like a `namedtuple`.

  The test is structural: the object must be a tuple that exposes a `_fields`
  sequence consisting only of strings.

  Args:
    instance: An instance of a Python object.

  Returns:
    True if `instance` is a `namedtuple`, False otherwise.
  """
  if not isinstance(instance, tuple):
    return False
  if not hasattr(instance, "_fields"):
    return False
  if not isinstance(instance._fields, collections_abc.Sequence):
    return False
  return all(
      isinstance(name, six.string_types) for name in instance._fields)
171
172
class _TupleCodec(object):
  """Codec that maps plain tuples to and from `tuple_value` protos."""

  def can_encode(self, pyobj):
    """Returns True if `_is_tuple` accepts `pyobj`."""
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    """Encodes each element with `encode_fn` into a `StructuredValue`.

    Args:
      tuple_value: The tuple to encode.
      encode_fn: Callable that encodes a single element.

    Returns:
      A `StructuredValue` whose `tuple_value` holds the encoded elements in
      iteration order.
    """
    proto = struct_pb2.StructuredValue()
    # NOTE(review): presumably this explicit copy is what makes `HasField`
    # report the field for an empty tuple too -- confirm before removing.
    proto.tuple_value.CopyFrom(struct_pb2.TupleValue())
    slots = proto.tuple_value.values
    for item in tuple_value:
      slots.add().CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `tuple_value` field set."""
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    """Returns a tuple with every entry of `tuple_value` run through `decode_fn`."""
    return tuple(map(decode_fn, value.tuple_value.values))
191
192
193StructureCoder.register_codec(_TupleCodec())
194
195
class _DictCodec(object):
  """Codec that maps dicts to and from `dict_value` protos."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is an instance of `dict`."""
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    """Encodes every value of `dict_value` under its original key.

    Args:
      dict_value: The dict to encode.
      encode_fn: Callable that encodes a single value.

    Returns:
      A `StructuredValue` whose `dict_value.fields` maps each key to the
      encoded form of its value.
    """
    proto = struct_pb2.StructuredValue()
    proto.dict_value.CopyFrom(struct_pb2.DictValue())
    entries = proto.dict_value.fields
    for name, item in dict_value.items():
      entries[name].CopyFrom(encode_fn(item))
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `dict_value` field set."""
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    """Returns a dict with every entry of `dict_value.fields` decoded."""
    decoded = {}
    for name, encoded in value.dict_value.fields.items():
      decoded[name] = decode_fn(encoded)
    return decoded
214
215
216StructureCoder.register_codec(_DictCodec())
217
218
class _NamedTupleCodec(object):
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    """Returns True if `_is_named_tuple` accepts `pyobj`."""
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    """Encodes the type name and every field of a namedtuple.

    Args:
      named_tuple_value: The namedtuple instance to encode.
      encode_fn: Callable that encodes a single field value.

    Returns:
      A `StructuredValue` whose `named_tuple_value` carries the class name and
      one key/value pair per entry of `_fields`, in field order.
    """
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    proto = encoded_named_tuple.named_tuple_value
    proto.name = named_tuple_value.__class__.__name__
    # Build the field mapping once; calling `_asdict()` inside the loop would
    # recreate the whole mapping for every single field.
    field_values = named_tuple_value._asdict()
    for key in named_tuple_value._fields:
      pair = proto.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_values[key]))
    return encoded_named_tuple

  def can_decode(self, value):
    """Returns True if `value` has its `named_tuple_value` field set."""
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a namedtuple from the encoded name and key/value pairs.

    Args:
      value: Proto whose `named_tuple_value` field is set.
      decode_fn: Callable that decodes a single field value.

    Returns:
      An instance of a freshly created namedtuple type that has the encoded
      type name and field names.
    """
    proto = value.named_tuple_value
    items = [(pair.key, decode_fn(pair.value)) for pair in proto.values]
    named_tuple_type = collections.namedtuple(
        proto.name, [key for key, _ in items])
    return named_tuple_type(**dict(items))
249
250
251StructureCoder.register_codec(_NamedTupleCodec())
252
253
class _Float64Codec(object):
  """Codec that stores Python floats in the `float64_value` field."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is an instance of `float`."""
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    """Returns a `StructuredValue` with `float64_value` set to the input."""
    del encode_fn  # Floats are leaves; there is nothing to recurse into.
    return struct_pb2.StructuredValue(float64_value=float64_value)

  def can_decode(self, value):
    """Returns True if `value` has its `float64_value` field set."""
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    """Returns the content of the `float64_value` field."""
    del decode_fn  # Floats are leaves; there is nothing to recurse into.
    return value.float64_value
272
273
274StructureCoder.register_codec(_Float64Codec())
275
276
class _Int64Codec(object):
  """Codec for Python integers (limited to 64 bit values)."""

  # Inclusive bounds of a signed 64-bit integer.
  _MIN_INT64 = -(1 << 63)
  _MAX_INT64 = (1 << 63) - 1

  def can_encode(self, pyobj):
    """Returns True for non-bool integers that fit in a signed 64-bit field.

    Integers outside that range are rejected here so that the coder reports
    them as not encodable instead of failing later when the value is assigned
    to `int64_value`.

    Args:
      pyobj: Object to test.

    Returns:
      True if `pyobj` is an `int` (but not a `bool`) within the int64 range.
    """
    # `bool` is a subclass of `int`, so it has to be excluded explicitly.
    if isinstance(pyobj, bool) or not isinstance(pyobj, int):
      return False
    return self._MIN_INT64 <= pyobj <= self._MAX_INT64

  def do_encode(self, int_value, encode_fn):
    """Returns a `StructuredValue` with `int64_value` set to `int_value`."""
    del encode_fn  # Integers are leaves; there is nothing to recurse into.
    value = struct_pb2.StructuredValue()
    value.int64_value = int_value
    return value

  def can_decode(self, value):
    """Returns True if `value` has its `int64_value` field set."""
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    """Returns the content of the `int64_value` field converted with `int`."""
    del decode_fn  # Integers are leaves; there is nothing to recurse into.
    return int(value.int64_value)
295
296
297StructureCoder.register_codec(_Int64Codec())
298
299
class _StringCodec(object):
  """Codec that stores `str` objects in the `string_value` field.

  See StructuredValue.string_value in proto/struct.proto for more detailed
  explanation.
  """

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is an instance of `str`."""
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    """Returns a `StructuredValue` with `string_value` set to the input."""
    del encode_fn  # Strings are leaves; there is nothing to recurse into.
    return struct_pb2.StructuredValue(string_value=string_value)

  def can_decode(self, value):
    """Returns True if `value` has its `string_value` field set."""
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    """Returns `string_value` passed through `compat.as_str`."""
    del decode_fn  # Strings are leaves; there is nothing to recurse into.
    stored = value.string_value
    return compat.as_str(stored)
322
323
324StructureCoder.register_codec(_StringCodec())
325
326
class _NoneCodec(object):
  """Codec that represents `None` through the `none_value` field."""

  def can_encode(self, pyobj):
    """Returns True only for the `None` singleton."""
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    """Returns a `StructuredValue` whose `none_value` field is populated."""
    del none_value, encode_fn  # The result does not depend on the inputs.
    proto = struct_pb2.StructuredValue()
    empty = struct_pb2.NoneValue()
    proto.none_value.CopyFrom(empty)
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `none_value` field set."""
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    """Returns `None`; the proto carries no further payload that is read."""
    del value, decode_fn  # The result does not depend on the inputs.
    return None
345
346
347StructureCoder.register_codec(_NoneCodec())
348
349
class _BoolCodec(object):
  """Codec that stores Python booleans in the `bool_value` field."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is an instance of `bool`."""
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    """Returns a `StructuredValue` with `bool_value` set to the input."""
    del encode_fn  # Booleans are leaves; there is nothing to recurse into.
    return struct_pb2.StructuredValue(bool_value=bool_value)

  def can_decode(self, value):
    """Returns True if `value` has its `bool_value` field set."""
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    """Returns the content of the `bool_value` field."""
    del decode_fn  # Booleans are leaves; there is nothing to recurse into.
    return value.bool_value
368
369
370StructureCoder.register_codec(_BoolCodec())
371
372
class _TensorShapeCodec(object):
  """Codec that stores `TensorShape` objects in `tensor_shape_value`."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is a `tensor_shape.TensorShape`."""
    return isinstance(pyobj, tensor_shape.TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    """Returns a `StructuredValue` holding the shape's `as_proto()` result."""
    del encode_fn  # Shapes are leaves; there is nothing to recurse into.
    shape_proto = tensor_shape_value.as_proto()
    proto = struct_pb2.StructuredValue()
    proto.tensor_shape_value.CopyFrom(shape_proto)
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `tensor_shape_value` field set."""
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    """Returns a `TensorShape` built from the `tensor_shape_value` field."""
    del decode_fn  # Shapes are leaves; there is nothing to recurse into.
    shape_proto = value.tensor_shape_value
    return tensor_shape.TensorShape(shape_proto)
392
393
394StructureCoder.register_codec(_TensorShapeCodec())
395
396
class _TensorTypeCodec(object):
  """Codec that stores `dtypes.DType` objects in `tensor_dtype_value`."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is a `dtypes.DType`."""
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    """Returns a `StructuredValue` holding the dtype's `as_datatype_enum`."""
    del encode_fn  # Dtypes are leaves; there is nothing to recurse into.
    return struct_pb2.StructuredValue(
        tensor_dtype_value=tensor_dtype_value.as_datatype_enum)

  def can_decode(self, value):
    """Returns True if `value` has its `tensor_dtype_value` field set."""
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    """Returns a `DType` built from the `tensor_dtype_value` field."""
    del decode_fn  # Dtypes are leaves; there is nothing to recurse into.
    type_enum = value.tensor_dtype_value
    return dtypes.DType(type_enum)
415
416
417StructureCoder.register_codec(_TensorTypeCodec())
418
419
class _TensorSpecCodec(object):
  """Codec that stores `TensorSpec` objects in `tensor_spec_value`."""

  def can_encode(self, pyobj):
    """Returns True for `TensorSpec` instances that are not bounded specs."""
    # BoundedTensorSpec is excluded here because it has its own codec.
    if isinstance(pyobj, tensor_spec.BoundedTensorSpec):
      return False
    return isinstance(pyobj, tensor_spec.TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    """Encodes the shape, dtype and name of a `TensorSpec`.

    Args:
      tensor_spec_value: The spec to encode.
      encode_fn: Callable used to encode the spec's shape and dtype.

    Returns:
      A `StructuredValue` whose `tensor_spec_value` holds a `TensorSpecProto`.
    """
    spec_proto = struct_pb2.TensorSpecProto(
        shape=encode_fn(tensor_spec_value.shape).tensor_shape_value,
        dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value,
        name=tensor_spec_value.name)
    proto = struct_pb2.StructuredValue()
    proto.tensor_spec_value.CopyFrom(spec_proto)
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `tensor_spec_value` field set."""
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorSpec` from the `tensor_spec_value` field.

    Args:
      value: Proto whose `tensor_spec_value` field is set.
      decode_fn: Callable used to decode the stored shape and dtype.

    Returns:
      A `TensorSpec`; a falsy stored name is passed on as `None`.
    """
    spec_proto = value.tensor_spec_value
    name = spec_proto.name
    # Shape and dtype are wrapped in fresh protos so `decode_fn` can handle
    # them like any other encoded value.
    shape_holder = struct_pb2.StructuredValue(
        tensor_shape_value=spec_proto.shape)
    dtype_holder = struct_pb2.StructuredValue(
        tensor_dtype_value=spec_proto.dtype)
    return tensor_spec.TensorSpec(
        shape=decode_fn(shape_holder),
        dtype=decode_fn(dtype_holder),
        name=(name or None))
450
451
452StructureCoder.register_codec(_TensorSpecCodec())
453
454
class _BoundedTensorSpecCodec(object):
  """Codec that stores `BoundedTensorSpec` objects in a dedicated field."""

  def can_encode(self, pyobj):
    """Returns True if `pyobj` is a `tensor_spec.BoundedTensorSpec`."""
    return isinstance(pyobj, tensor_spec.BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.BoundedTensorSpec`.

    Args:
      bounded_tensor_spec_value: The spec to encode.
      encode_fn: Callable used to encode the spec's shape and dtype.

    Returns:
      A `StructuredValue` whose `bounded_tensor_spec_value` holds the shape,
      dtype and name, plus the minimum and maximum converted with
      `tensor_util.make_tensor_proto`.
    """
    spec = bounded_tensor_spec_value
    spec_proto = struct_pb2.BoundedTensorSpecProto(
        shape=encode_fn(spec.shape).tensor_shape_value,
        dtype=encode_fn(spec.dtype).tensor_dtype_value,
        name=spec.name,
        minimum=tensor_util.make_tensor_proto(spec.minimum),
        maximum=tensor_util.make_tensor_proto(spec.maximum))
    proto = struct_pb2.StructuredValue()
    proto.bounded_tensor_spec_value.CopyFrom(spec_proto)
    return proto

  def can_decode(self, value):
    """Returns True if `value` has `bounded_tensor_spec_value` set."""
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `BoundedTensorSpec` from `bounded_tensor_spec_value`.

    Args:
      value: Proto whose `bounded_tensor_spec_value` field is set.
      decode_fn: Callable used to decode the stored shape and dtype.

    Returns:
      A `BoundedTensorSpec`; bounds are restored with
      `tensor_util.MakeNdarray` and a falsy stored name becomes `None`.
    """
    spec_proto = value.bounded_tensor_spec_value
    name = spec_proto.name
    shape_holder = struct_pb2.StructuredValue(
        tensor_shape_value=spec_proto.shape)
    dtype_holder = struct_pb2.StructuredValue(
        tensor_dtype_value=spec_proto.dtype)
    return tensor_spec.BoundedTensorSpec(
        shape=decode_fn(shape_holder),
        dtype=decode_fn(dtype_holder),
        minimum=tensor_util.MakeNdarray(spec_proto.minimum),
        maximum=tensor_util.MakeNdarray(spec_proto.maximum),
        name=(name or None))
489
490
491StructureCoder.register_codec(_BoundedTensorSpecCodec())
492
493
class _TypeSpecCodec(object):
  """Codec for the `tf.TypeSpec` subclasses listed in the tables below."""

  # Mapping from enum value to type (TypeSpec subclass).
  TYPE_SPEC_CLASS_FROM_PROTO = {
      struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC:
          sparse_tensor.SparseTensorSpec,
      struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC:
          indexed_slices.IndexedSlicesSpec,
      struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC:
          ragged_tensor.RaggedTensorSpec,
      struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC:
          tensor_array_ops.TensorArraySpec,
      struct_pb2.TypeSpecProto.DATA_DATASET_SPEC:
          dataset_ops.DatasetSpec,
      struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC:
          iterator_ops.IteratorSpec,
      struct_pb2.TypeSpecProto.OPTIONAL_SPEC:
          optional_ops.OptionalSpec,
      struct_pb2.TypeSpecProto.PER_REPLICA_SPEC:
          values.PerReplicaSpec,
      struct_pb2.TypeSpecProto.VARIABLE_SPEC:
          resource_variable_ops.VariableSpec,
      struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
          row_partition.RowPartitionSpec,
  }

  # Mapping from type (TypeSpec subclass) to enum value; the inverse of the
  # table above.
  TYPE_SPEC_CLASS_TO_PROTO = {
      spec_cls: enum
      for enum, spec_cls in TYPE_SPEC_CLASS_FROM_PROTO.items()
  }

  def can_encode(self, pyobj):
    """Returns True if the exact type of `pyobj` is in the encode table."""
    # Exact type match on purpose: subclasses of a listed spec are rejected.
    # pylint: disable=unidiomatic-typecheck
    return type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.TypeSpec`.

    Args:
      type_spec_value: Spec whose exact type is a key of
        `TYPE_SPEC_CLASS_TO_PROTO`.
      encode_fn: Callable used to encode the spec's serialized state.

    Returns:
      A `StructuredValue` whose `type_spec_value` holds the class enum, the
      encoded result of `_serialize()` and the class name.
    """
    spec_cls = type(type_spec_value)
    class_enum = self.TYPE_SPEC_CLASS_TO_PROTO[spec_cls]
    state = type_spec_value._serialize()  # pylint: disable=protected-access
    proto = struct_pb2.StructuredValue()
    proto.type_spec_value.CopyFrom(
        struct_pb2.TypeSpecProto(
            type_spec_class=class_enum,
            type_state=encode_fn(state),
            type_spec_class_name=spec_cls.__name__))
    return proto

  def can_decode(self, value):
    """Returns True if `value` has its `type_spec_value` field set."""
    return value.HasField("type_spec_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`.

    Args:
      value: Proto whose `type_spec_value` field is set.
      decode_fn: Callable used to decode the stored type state.

    Returns:
      The result of `_deserialize` on the class selected by the stored enum.

    Raises:
      ValueError: If the stored enum is not a key of
        `TYPE_SPEC_CLASS_FROM_PROTO`.
    """
    type_spec_proto = value.type_spec_value
    spec_cls = self.TYPE_SPEC_CLASS_FROM_PROTO.get(
        type_spec_proto.type_spec_class)
    if spec_cls is None:
      raise ValueError(
          "The type '%s' is not supported by this version of TensorFlow. "
          "(The object you are loading must have been created with a newer "
          "version of TensorFlow.)" % type_spec_proto.type_spec_class_name)

    # pylint: disable=protected-access
    return spec_cls._deserialize(decode_fn(type_spec_proto.type_state))
557
558
559StructureCoder.register_codec(_TypeSpecCodec())
560