1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module that encodes (decodes) nested structures into (from) protos. 16 17The intended use is to serialize everything needed to restore a `Function` that 18was saved into a SavedModel. This may include concrete function inputs and 19outputs, signatures, function specs, etc. 20 21Example use: 22coder = nested_structure_coder.StructureCoder() 23# Encode into proto. 24signature_proto = coder.encode_structure(function.input_signature) 25# Decode into a Python object. 26restored_signature = coder.decode_proto(signature_proto) 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import collections 34import functools 35 36import six 37 38from tensorflow.core.protobuf import struct_pb2 39from tensorflow.python.data.ops import dataset_ops 40from tensorflow.python.data.ops import iterator_ops 41from tensorflow.python.data.ops import optional_ops 42from tensorflow.python.distribute import values 43from tensorflow.python.framework import dtypes 44from tensorflow.python.framework import indexed_slices 45from tensorflow.python.framework import sparse_tensor 46from tensorflow.python.framework import tensor_shape 47from tensorflow.python.framework import tensor_spec 48from tensorflow.python.framework import tensor_util 49from tensorflow.python.ops import resource_variable_ops 50from tensorflow.python.ops import tensor_array_ops 51from tensorflow.python.ops.ragged import ragged_tensor 52from tensorflow.python.ops.ragged import row_partition 53from tensorflow.python.util import compat 54from tensorflow.python.util.compat import collections_abc 55 56 57class NotEncodableError(Exception): 58 """Error raised when a coder cannot encode an object.""" 59 60 61class StructureCoder(object): 62 """Encoder and decoder for nested structures into protos.""" 63 64 _codecs = [] 65 66 @classmethod 67 def register_codec(cls, x): 68 cls._codecs.append(x) 69 70 @classmethod 71 def _get_encoders(cls): 72 return [(c.can_encode, c.do_encode) for c in cls._codecs] 73 74 @classmethod 75 def _get_decoders(cls): 76 return [(c.can_decode, c.do_decode) for c in cls._codecs] 77 78 def _map_structure(self, pyobj, coders): 79 for can, do in coders: 80 if can(pyobj): 81 recursion_fn = functools.partial(self._map_structure, coders=coders) 82 return do(pyobj, recursion_fn) 83 raise NotEncodableError( 84 "No encoder for object [%s] of type [%s]." % (str(pyobj), type(pyobj))) 85 86 def encode_structure(self, nested_structure): 87 """Encodes nested structures composed of encodable types into a proto. 88 89 Args: 90 nested_structure: Structure to encode. 91 92 Returns: 93 Encoded proto. 94 95 Raises: 96 NotEncodableError: For values for which there are no encoders. 97 """ 98 return self._map_structure(nested_structure, self._get_encoders()) 99 100 def can_encode(self, nested_structure): 101 """Determines whether a nested structure can be encoded into a proto. 102 103 Args: 104 nested_structure: Structure to encode. 105 106 Returns: 107 True if the nested structured can be encoded. 108 """ 109 try: 110 self.encode_structure(nested_structure) 111 except NotEncodableError: 112 return False 113 return True 114 115 def decode_proto(self, proto): 116 """Decodes proto representing a nested structure. 117 118 Args: 119 proto: Proto to decode. 120 121 Returns: 122 Decoded structure. 123 124 Raises: 125 NotEncodableError: For values for which there are no encoders. 126 """ 127 return self._map_structure(proto, self._get_decoders()) 128 129 130class _ListCodec(object): 131 """Codec for lists.""" 132 133 def can_encode(self, pyobj): 134 return isinstance(pyobj, list) 135 136 def do_encode(self, list_value, encode_fn): 137 encoded_list = struct_pb2.StructuredValue() 138 encoded_list.list_value.CopyFrom(struct_pb2.ListValue()) 139 for element in list_value: 140 encoded_list.list_value.values.add().CopyFrom(encode_fn(element)) 141 return encoded_list 142 143 def can_decode(self, value): 144 return value.HasField("list_value") 145 146 def do_decode(self, value, decode_fn): 147 return [decode_fn(element) for element in value.list_value.values] 148 149 150StructureCoder.register_codec(_ListCodec()) 151 152 153def _is_tuple(obj): 154 return not _is_named_tuple(obj) and isinstance(obj, tuple) 155 156 157def _is_named_tuple(instance): 158 """Returns True iff `instance` is a `namedtuple`. 159 160 Args: 161 instance: An instance of a Python object. 162 163 Returns: 164 True if `instance` is a `namedtuple`. 165 """ 166 if not isinstance(instance, tuple): 167 return False 168 return (hasattr(instance, "_fields") and 169 isinstance(instance._fields, collections_abc.Sequence) and 170 all(isinstance(f, six.string_types) for f in instance._fields)) 171 172 173class _TupleCodec(object): 174 """Codec for tuples.""" 175 176 def can_encode(self, pyobj): 177 return _is_tuple(pyobj) 178 179 def do_encode(self, tuple_value, encode_fn): 180 encoded_tuple = struct_pb2.StructuredValue() 181 encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue()) 182 for element in tuple_value: 183 encoded_tuple.tuple_value.values.add().CopyFrom(encode_fn(element)) 184 return encoded_tuple 185 186 def can_decode(self, value): 187 return value.HasField("tuple_value") 188 189 def do_decode(self, value, decode_fn): 190 return tuple(decode_fn(element) for element in value.tuple_value.values) 191 192 193StructureCoder.register_codec(_TupleCodec()) 194 195 196class _DictCodec(object): 197 """Codec for dicts.""" 198 199 def can_encode(self, pyobj): 200 return isinstance(pyobj, dict) 201 202 def do_encode(self, dict_value, encode_fn): 203 encoded_dict = struct_pb2.StructuredValue() 204 encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue()) 205 for key, value in dict_value.items(): 206 encoded_dict.dict_value.fields[key].CopyFrom(encode_fn(value)) 207 return encoded_dict 208 209 def can_decode(self, value): 210 return value.HasField("dict_value") 211 212 def do_decode(self, value, decode_fn): 213 return {key: decode_fn(val) for key, val in value.dict_value.fields.items()} 214 215 216StructureCoder.register_codec(_DictCodec()) 217 218 219class _NamedTupleCodec(object): 220 """Codec for namedtuples. 221 222 Encoding and decoding a namedtuple reconstructs a namedtuple with a different 223 actual Python type, but with the same `typename` and `fields`. 224 """ 225 226 def can_encode(self, pyobj): 227 return _is_named_tuple(pyobj) 228 229 def do_encode(self, named_tuple_value, encode_fn): 230 encoded_named_tuple = struct_pb2.StructuredValue() 231 encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue()) 232 encoded_named_tuple.named_tuple_value.name = \ 233 named_tuple_value.__class__.__name__ 234 for key in named_tuple_value._fields: 235 pair = encoded_named_tuple.named_tuple_value.values.add() 236 pair.key = key 237 pair.value.CopyFrom(encode_fn(named_tuple_value._asdict()[key])) 238 return encoded_named_tuple 239 240 def can_decode(self, value): 241 return value.HasField("named_tuple_value") 242 243 def do_decode(self, value, decode_fn): 244 key_value_pairs = value.named_tuple_value.values 245 items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs] 246 named_tuple_type = collections.namedtuple(value.named_tuple_value.name, 247 [item[0] for item in items]) 248 return named_tuple_type(**dict(items)) 249 250 251StructureCoder.register_codec(_NamedTupleCodec()) 252 253 254class _Float64Codec(object): 255 """Codec for floats.""" 256 257 def can_encode(self, pyobj): 258 return isinstance(pyobj, float) 259 260 def do_encode(self, float64_value, encode_fn): 261 del encode_fn 262 value = struct_pb2.StructuredValue() 263 value.float64_value = float64_value 264 return value 265 266 def can_decode(self, value): 267 return value.HasField("float64_value") 268 269 def do_decode(self, value, decode_fn): 270 del decode_fn 271 return value.float64_value 272 273 274StructureCoder.register_codec(_Float64Codec()) 275 276 277class _Int64Codec(object): 278 """Codec for Python integers (limited to 64 bit values).""" 279 280 def can_encode(self, pyobj): 281 return not isinstance(pyobj, bool) and isinstance(pyobj, int) 282 283 def do_encode(self, int_value, encode_fn): 284 del encode_fn 285 value = struct_pb2.StructuredValue() 286 value.int64_value = int_value 287 return value 288 289 def can_decode(self, value): 290 return value.HasField("int64_value") 291 292 def do_decode(self, value, decode_fn): 293 del decode_fn 294 return int(value.int64_value) 295 296 297StructureCoder.register_codec(_Int64Codec()) 298 299 300class _StringCodec(object): 301 """Codec for strings. 302 303 See StructuredValue.string_value in proto/struct.proto for more detailed 304 explanation. 305 """ 306 307 def can_encode(self, pyobj): 308 return isinstance(pyobj, str) 309 310 def do_encode(self, string_value, encode_fn): 311 del encode_fn 312 value = struct_pb2.StructuredValue() 313 value.string_value = string_value 314 return value 315 316 def can_decode(self, value): 317 return value.HasField("string_value") 318 319 def do_decode(self, value, decode_fn): 320 del decode_fn 321 return compat.as_str(value.string_value) 322 323 324StructureCoder.register_codec(_StringCodec()) 325 326 327class _NoneCodec(object): 328 """Codec for None.""" 329 330 def can_encode(self, pyobj): 331 return pyobj is None 332 333 def do_encode(self, none_value, encode_fn): 334 del encode_fn, none_value 335 value = struct_pb2.StructuredValue() 336 value.none_value.CopyFrom(struct_pb2.NoneValue()) 337 return value 338 339 def can_decode(self, value): 340 return value.HasField("none_value") 341 342 def do_decode(self, value, decode_fn): 343 del decode_fn, value 344 return None 345 346 347StructureCoder.register_codec(_NoneCodec()) 348 349 350class _BoolCodec(object): 351 """Codec for booleans.""" 352 353 def can_encode(self, pyobj): 354 return isinstance(pyobj, bool) 355 356 def do_encode(self, bool_value, encode_fn): 357 del encode_fn 358 value = struct_pb2.StructuredValue() 359 value.bool_value = bool_value 360 return value 361 362 def can_decode(self, value): 363 return value.HasField("bool_value") 364 365 def do_decode(self, value, decode_fn): 366 del decode_fn 367 return value.bool_value 368 369 370StructureCoder.register_codec(_BoolCodec()) 371 372 373class _TensorShapeCodec(object): 374 """Codec for `TensorShape`.""" 375 376 def can_encode(self, pyobj): 377 return isinstance(pyobj, tensor_shape.TensorShape) 378 379 def do_encode(self, tensor_shape_value, encode_fn): 380 del encode_fn 381 encoded_tensor_shape = struct_pb2.StructuredValue() 382 encoded_tensor_shape.tensor_shape_value.CopyFrom( 383 tensor_shape_value.as_proto()) 384 return encoded_tensor_shape 385 386 def can_decode(self, value): 387 return value.HasField("tensor_shape_value") 388 389 def do_decode(self, value, decode_fn): 390 del decode_fn 391 return tensor_shape.TensorShape(value.tensor_shape_value) 392 393 394StructureCoder.register_codec(_TensorShapeCodec()) 395 396 397class _TensorTypeCodec(object): 398 """Codec for `TensorType`.""" 399 400 def can_encode(self, pyobj): 401 return isinstance(pyobj, dtypes.DType) 402 403 def do_encode(self, tensor_dtype_value, encode_fn): 404 del encode_fn 405 encoded_tensor_type = struct_pb2.StructuredValue() 406 encoded_tensor_type.tensor_dtype_value = tensor_dtype_value.as_datatype_enum 407 return encoded_tensor_type 408 409 def can_decode(self, value): 410 return value.HasField("tensor_dtype_value") 411 412 def do_decode(self, value, decode_fn): 413 del decode_fn 414 return dtypes.DType(value.tensor_dtype_value) 415 416 417StructureCoder.register_codec(_TensorTypeCodec()) 418 419 420class _TensorSpecCodec(object): 421 """Codec for `TensorSpec`.""" 422 423 def can_encode(self, pyobj): 424 # BoundedTensorSpec has its own decoder. 425 return (isinstance(pyobj, tensor_spec.TensorSpec) and 426 not isinstance(pyobj, tensor_spec.BoundedTensorSpec)) 427 428 def do_encode(self, tensor_spec_value, encode_fn): 429 encoded_tensor_spec = struct_pb2.StructuredValue() 430 encoded_tensor_spec.tensor_spec_value.CopyFrom( 431 struct_pb2.TensorSpecProto( 432 shape=encode_fn(tensor_spec_value.shape).tensor_shape_value, 433 dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value, 434 name=tensor_spec_value.name)) 435 return encoded_tensor_spec 436 437 def can_decode(self, value): 438 return value.HasField("tensor_spec_value") 439 440 def do_decode(self, value, decode_fn): 441 name = value.tensor_spec_value.name 442 return tensor_spec.TensorSpec( 443 shape=decode_fn( 444 struct_pb2.StructuredValue( 445 tensor_shape_value=value.tensor_spec_value.shape)), 446 dtype=decode_fn( 447 struct_pb2.StructuredValue( 448 tensor_dtype_value=value.tensor_spec_value.dtype)), 449 name=(name if name else None)) 450 451 452StructureCoder.register_codec(_TensorSpecCodec()) 453 454 455class _BoundedTensorSpecCodec(object): 456 """Codec for `BoundedTensorSpec`.""" 457 458 def can_encode(self, pyobj): 459 return isinstance(pyobj, tensor_spec.BoundedTensorSpec) 460 461 def do_encode(self, bounded_tensor_spec_value, encode_fn): 462 """Returns an encoded proto for the given `tf.BoundedTensorSpec`.""" 463 encoded_bounded_tensor_spec = struct_pb2.StructuredValue() 464 encoded_bounded_tensor_spec.bounded_tensor_spec_value.CopyFrom( 465 struct_pb2.BoundedTensorSpecProto( 466 shape=encode_fn(bounded_tensor_spec_value.shape).tensor_shape_value, 467 dtype=encode_fn(bounded_tensor_spec_value.dtype).tensor_dtype_value, 468 name=bounded_tensor_spec_value.name, 469 minimum=tensor_util.make_tensor_proto( 470 bounded_tensor_spec_value.minimum), 471 maximum=tensor_util.make_tensor_proto( 472 bounded_tensor_spec_value.maximum))) 473 return encoded_bounded_tensor_spec 474 475 def can_decode(self, value): 476 return value.HasField("bounded_tensor_spec_value") 477 478 def do_decode(self, value, decode_fn): 479 btsv = value.bounded_tensor_spec_value 480 name = btsv.name 481 return tensor_spec.BoundedTensorSpec( 482 shape=decode_fn( 483 struct_pb2.StructuredValue(tensor_shape_value=btsv.shape)), 484 dtype=decode_fn( 485 struct_pb2.StructuredValue(tensor_dtype_value=btsv.dtype)), 486 minimum=tensor_util.MakeNdarray(btsv.minimum), 487 maximum=tensor_util.MakeNdarray(btsv.maximum), 488 name=(name if name else None)) 489 490 491StructureCoder.register_codec(_BoundedTensorSpecCodec()) 492 493 494class _TypeSpecCodec(object): 495 """Codec for `tf.TypeSpec`.""" 496 497 # Mapping from enum value to type (TypeSpec subclass). 498 TYPE_SPEC_CLASS_FROM_PROTO = { 499 struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC: 500 sparse_tensor.SparseTensorSpec, 501 struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC: 502 indexed_slices.IndexedSlicesSpec, 503 struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC: 504 ragged_tensor.RaggedTensorSpec, 505 struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC: 506 tensor_array_ops.TensorArraySpec, 507 struct_pb2.TypeSpecProto.DATA_DATASET_SPEC: 508 dataset_ops.DatasetSpec, 509 struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC: 510 iterator_ops.IteratorSpec, 511 struct_pb2.TypeSpecProto.OPTIONAL_SPEC: 512 optional_ops.OptionalSpec, 513 struct_pb2.TypeSpecProto.PER_REPLICA_SPEC: 514 values.PerReplicaSpec, 515 struct_pb2.TypeSpecProto.VARIABLE_SPEC: 516 resource_variable_ops.VariableSpec, 517 struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC: 518 row_partition.RowPartitionSpec, 519 } 520 521 # Mapping from type (TypeSpec subclass) to enum value. 522 TYPE_SPEC_CLASS_TO_PROTO = dict( 523 (cls, enum) for (enum, cls) in TYPE_SPEC_CLASS_FROM_PROTO.items()) 524 525 def can_encode(self, pyobj): 526 # pylint: disable=unidiomatic-typecheck 527 return type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO 528 529 def do_encode(self, type_spec_value, encode_fn): 530 """Returns an encoded proto for the given `tf.TypeSpec`.""" 531 type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO[type(type_spec_value)] 532 type_state = type_spec_value._serialize() # pylint: disable=protected-access 533 encoded_type_spec = struct_pb2.StructuredValue() 534 encoded_type_spec.type_spec_value.CopyFrom( 535 struct_pb2.TypeSpecProto( 536 type_spec_class=type_spec_class, 537 type_state=encode_fn(type_state), 538 type_spec_class_name=type(type_spec_value).__name__)) 539 return encoded_type_spec 540 541 def can_decode(self, value): 542 return value.HasField("type_spec_value") 543 544 def do_decode(self, value, decode_fn): 545 """Returns the `tf.TypeSpec` encoded by the proto `value`.""" 546 type_spec_proto = value.type_spec_value 547 type_spec_class_enum = type_spec_proto.type_spec_class 548 if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO: 549 raise ValueError( 550 "The type '%s' is not supported by this version of TensorFlow. " 551 "(The object you are loading must have been created with a newer " 552 "version of TensorFlow.)" % type_spec_proto.type_spec_class_name) 553 554 type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum] 555 # pylint: disable=protected-access 556 return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state)) 557 558 559StructureCoder.register_codec(_TypeSpecCodec()) 560