1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module that encodes (decodes) nested structures into (from) protos. 16 17The intended use is to serialize everything needed to restore a `Function` that 18was saved into a SavedModel. This may include concrete function inputs and 19outputs, signatures, function specs, etc. 20 21Example use: 22coder = nested_structure_coder.StructureCoder() 23# Encode into proto. 24signature_proto = coder.encode_structure(function.input_signature) 25# Decode into a Python object. 26restored_signature = coder.decode_proto(signature_proto) 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import collections 34import functools 35 36import warnings 37import six 38 39from tensorflow.core.protobuf import struct_pb2 40from tensorflow.python.data.ops import dataset_ops 41from tensorflow.python.data.ops import iterator_ops 42from tensorflow.python.data.ops import optional_ops 43from tensorflow.python.distribute import values 44from tensorflow.python.framework import dtypes 45from tensorflow.python.framework import extension_type 46from tensorflow.python.framework import indexed_slices 47from tensorflow.python.framework import sparse_tensor 48from tensorflow.python.framework import tensor_shape 49from tensorflow.python.framework import tensor_spec 50from tensorflow.python.framework import tensor_util 51from tensorflow.python.framework import type_spec 52from tensorflow.python.ops import resource_variable_ops 53from tensorflow.python.ops import tensor_array_ops 54from tensorflow.python.ops.ragged import ragged_tensor 55from tensorflow.python.ops.ragged import row_partition 56from tensorflow.python.util import compat 57from tensorflow.python.util.compat import collections_abc 58from tensorflow.python.util.tf_export import tf_export 59 60 61class NotEncodableError(Exception): 62 """Error raised when a coder cannot encode an object.""" 63 64 65@tf_export("__internal__.saved_model.StructureCoder", v1=[]) 66class StructureCoder(object): 67 """Encoder and decoder for nested structures into protos.""" 68 69 _codecs = [] 70 71 @classmethod 72 def register_codec(cls, x): 73 cls._codecs.append(x) 74 75 @classmethod 76 def _get_encoders(cls): 77 return [(c.can_encode, c.do_encode) for c in cls._codecs] 78 79 @classmethod 80 def _get_decoders(cls): 81 return [(c.can_decode, c.do_decode) for c in cls._codecs] 82 83 def _map_structure(self, pyobj, coders): 84 for can, do in coders: 85 if can(pyobj): 86 recursion_fn = functools.partial(self._map_structure, coders=coders) 87 return do(pyobj, recursion_fn) 88 raise NotEncodableError( 89 f"No encoder for object {str(pyobj)} of type {type(pyobj)}.") 90 91 def encode_structure(self, nested_structure): 92 """Encodes nested structures composed of encodable types into a proto. 93 94 Args: 95 nested_structure: Structure to encode. 96 97 Returns: 98 Encoded proto. 99 100 Raises: 101 NotEncodableError: For values for which there are no encoders. 102 """ 103 return self._map_structure(nested_structure, self._get_encoders()) 104 105 def can_encode(self, nested_structure): 106 """Determines whether a nested structure can be encoded into a proto. 107 108 Args: 109 nested_structure: Structure to encode. 110 111 Returns: 112 True if the nested structured can be encoded. 113 """ 114 try: 115 self.encode_structure(nested_structure) 116 except NotEncodableError: 117 return False 118 return True 119 120 def decode_proto(self, proto): 121 """Decodes proto representing a nested structure. 122 123 Args: 124 proto: Proto to decode. 125 126 Returns: 127 Decoded structure. 128 129 Raises: 130 NotEncodableError: For values for which there are no encoders. 131 """ 132 return self._map_structure(proto, self._get_decoders()) 133 134 135class _ListCodec(object): 136 """Codec for lists.""" 137 138 def can_encode(self, pyobj): 139 return isinstance(pyobj, list) 140 141 def do_encode(self, list_value, encode_fn): 142 encoded_list = struct_pb2.StructuredValue() 143 encoded_list.list_value.CopyFrom(struct_pb2.ListValue()) 144 for element in list_value: 145 encoded_list.list_value.values.add().CopyFrom(encode_fn(element)) 146 return encoded_list 147 148 def can_decode(self, value): 149 return value.HasField("list_value") 150 151 def do_decode(self, value, decode_fn): 152 return [decode_fn(element) for element in value.list_value.values] 153 154 155StructureCoder.register_codec(_ListCodec()) 156 157 158def _is_tuple(obj): 159 return not _is_named_tuple(obj) and isinstance(obj, tuple) 160 161 162def _is_named_tuple(instance): 163 """Returns True iff `instance` is a `namedtuple`. 164 165 Args: 166 instance: An instance of a Python object. 167 168 Returns: 169 True if `instance` is a `namedtuple`. 170 """ 171 if not isinstance(instance, tuple): 172 return False 173 return (hasattr(instance, "_fields") and 174 isinstance(instance._fields, collections_abc.Sequence) and 175 all(isinstance(f, six.string_types) for f in instance._fields)) 176 177 178class _TupleCodec(object): 179 """Codec for tuples.""" 180 181 def can_encode(self, pyobj): 182 return _is_tuple(pyobj) 183 184 def do_encode(self, tuple_value, encode_fn): 185 encoded_tuple = struct_pb2.StructuredValue() 186 encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue()) 187 for element in tuple_value: 188 encoded_tuple.tuple_value.values.add().CopyFrom(encode_fn(element)) 189 return encoded_tuple 190 191 def can_decode(self, value): 192 return value.HasField("tuple_value") 193 194 def do_decode(self, value, decode_fn): 195 return tuple(decode_fn(element) for element in value.tuple_value.values) 196 197 198StructureCoder.register_codec(_TupleCodec()) 199 200 201class _DictCodec(object): 202 """Codec for dicts.""" 203 204 def can_encode(self, pyobj): 205 return isinstance(pyobj, dict) 206 207 def do_encode(self, dict_value, encode_fn): 208 encoded_dict = struct_pb2.StructuredValue() 209 encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue()) 210 for key, value in dict_value.items(): 211 encoded_dict.dict_value.fields[key].CopyFrom(encode_fn(value)) 212 return encoded_dict 213 214 def can_decode(self, value): 215 return value.HasField("dict_value") 216 217 def do_decode(self, value, decode_fn): 218 return {key: decode_fn(val) for key, val in value.dict_value.fields.items()} 219 220 221StructureCoder.register_codec(_DictCodec()) 222 223 224class _NamedTupleCodec(object): 225 """Codec for namedtuples. 226 227 Encoding and decoding a namedtuple reconstructs a namedtuple with a different 228 actual Python type, but with the same `typename` and `fields`. 229 """ 230 231 def can_encode(self, pyobj): 232 return _is_named_tuple(pyobj) 233 234 def do_encode(self, named_tuple_value, encode_fn): 235 encoded_named_tuple = struct_pb2.StructuredValue() 236 encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue()) 237 encoded_named_tuple.named_tuple_value.name = \ 238 named_tuple_value.__class__.__name__ 239 for key in named_tuple_value._fields: 240 pair = encoded_named_tuple.named_tuple_value.values.add() 241 pair.key = key 242 pair.value.CopyFrom(encode_fn(named_tuple_value._asdict()[key])) 243 return encoded_named_tuple 244 245 def can_decode(self, value): 246 return value.HasField("named_tuple_value") 247 248 def do_decode(self, value, decode_fn): 249 key_value_pairs = value.named_tuple_value.values 250 items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs] 251 named_tuple_type = collections.namedtuple(value.named_tuple_value.name, 252 [item[0] for item in items]) 253 return named_tuple_type(**dict(items)) 254 255 256StructureCoder.register_codec(_NamedTupleCodec()) 257 258 259class _Float64Codec(object): 260 """Codec for floats.""" 261 262 def can_encode(self, pyobj): 263 return isinstance(pyobj, float) 264 265 def do_encode(self, float64_value, encode_fn): 266 del encode_fn 267 value = struct_pb2.StructuredValue() 268 value.float64_value = float64_value 269 return value 270 271 def can_decode(self, value): 272 return value.HasField("float64_value") 273 274 def do_decode(self, value, decode_fn): 275 del decode_fn 276 return value.float64_value 277 278 279StructureCoder.register_codec(_Float64Codec()) 280 281 282class _Int64Codec(object): 283 """Codec for Python integers (limited to 64 bit values).""" 284 285 def can_encode(self, pyobj): 286 return not isinstance(pyobj, bool) and isinstance(pyobj, int) 287 288 def do_encode(self, int_value, encode_fn): 289 del encode_fn 290 value = struct_pb2.StructuredValue() 291 value.int64_value = int_value 292 return value 293 294 def can_decode(self, value): 295 return value.HasField("int64_value") 296 297 def do_decode(self, value, decode_fn): 298 del decode_fn 299 return int(value.int64_value) 300 301 302StructureCoder.register_codec(_Int64Codec()) 303 304 305class _StringCodec(object): 306 """Codec for strings. 307 308 See StructuredValue.string_value in proto/struct.proto for more detailed 309 explanation. 310 """ 311 312 def can_encode(self, pyobj): 313 return isinstance(pyobj, str) 314 315 def do_encode(self, string_value, encode_fn): 316 del encode_fn 317 value = struct_pb2.StructuredValue() 318 value.string_value = string_value 319 return value 320 321 def can_decode(self, value): 322 return value.HasField("string_value") 323 324 def do_decode(self, value, decode_fn): 325 del decode_fn 326 return compat.as_str(value.string_value) 327 328 329StructureCoder.register_codec(_StringCodec()) 330 331 332class _NoneCodec(object): 333 """Codec for None.""" 334 335 def can_encode(self, pyobj): 336 return pyobj is None 337 338 def do_encode(self, none_value, encode_fn): 339 del encode_fn, none_value 340 value = struct_pb2.StructuredValue() 341 value.none_value.CopyFrom(struct_pb2.NoneValue()) 342 return value 343 344 def can_decode(self, value): 345 return value.HasField("none_value") 346 347 def do_decode(self, value, decode_fn): 348 del decode_fn, value 349 return None 350 351 352StructureCoder.register_codec(_NoneCodec()) 353 354 355class _BoolCodec(object): 356 """Codec for booleans.""" 357 358 def can_encode(self, pyobj): 359 return isinstance(pyobj, bool) 360 361 def do_encode(self, bool_value, encode_fn): 362 del encode_fn 363 value = struct_pb2.StructuredValue() 364 value.bool_value = bool_value 365 return value 366 367 def can_decode(self, value): 368 return value.HasField("bool_value") 369 370 def do_decode(self, value, decode_fn): 371 del decode_fn 372 return value.bool_value 373 374 375StructureCoder.register_codec(_BoolCodec()) 376 377 378class _TensorShapeCodec(object): 379 """Codec for `TensorShape`.""" 380 381 def can_encode(self, pyobj): 382 return isinstance(pyobj, tensor_shape.TensorShape) 383 384 def do_encode(self, tensor_shape_value, encode_fn): 385 del encode_fn 386 encoded_tensor_shape = struct_pb2.StructuredValue() 387 encoded_tensor_shape.tensor_shape_value.CopyFrom( 388 tensor_shape_value.as_proto()) 389 return encoded_tensor_shape 390 391 def can_decode(self, value): 392 return value.HasField("tensor_shape_value") 393 394 def do_decode(self, value, decode_fn): 395 del decode_fn 396 return tensor_shape.TensorShape(value.tensor_shape_value) 397 398 399StructureCoder.register_codec(_TensorShapeCodec()) 400 401 402class _TensorTypeCodec(object): 403 """Codec for `TensorType`.""" 404 405 def can_encode(self, pyobj): 406 return isinstance(pyobj, dtypes.DType) 407 408 def do_encode(self, tensor_dtype_value, encode_fn): 409 del encode_fn 410 encoded_tensor_type = struct_pb2.StructuredValue() 411 encoded_tensor_type.tensor_dtype_value = tensor_dtype_value.as_datatype_enum 412 return encoded_tensor_type 413 414 def can_decode(self, value): 415 return value.HasField("tensor_dtype_value") 416 417 def do_decode(self, value, decode_fn): 418 del decode_fn 419 return dtypes.DType(value.tensor_dtype_value) 420 421 422StructureCoder.register_codec(_TensorTypeCodec()) 423 424 425class _TensorSpecCodec(object): 426 """Codec for `TensorSpec`.""" 427 428 def can_encode(self, pyobj): 429 # BoundedTensorSpec has its own decoder. 430 return (isinstance(pyobj, tensor_spec.TensorSpec) and 431 not isinstance(pyobj, tensor_spec.BoundedTensorSpec)) 432 433 def do_encode(self, tensor_spec_value, encode_fn): 434 encoded_tensor_spec = struct_pb2.StructuredValue() 435 encoded_tensor_spec.tensor_spec_value.CopyFrom( 436 struct_pb2.TensorSpecProto( 437 shape=encode_fn(tensor_spec_value.shape).tensor_shape_value, 438 dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value, 439 name=tensor_spec_value.name)) 440 return encoded_tensor_spec 441 442 def can_decode(self, value): 443 return value.HasField("tensor_spec_value") 444 445 def do_decode(self, value, decode_fn): 446 name = value.tensor_spec_value.name 447 return tensor_spec.TensorSpec( 448 shape=decode_fn( 449 struct_pb2.StructuredValue( 450 tensor_shape_value=value.tensor_spec_value.shape)), 451 dtype=decode_fn( 452 struct_pb2.StructuredValue( 453 tensor_dtype_value=value.tensor_spec_value.dtype)), 454 name=(name if name else None)) 455 456 457StructureCoder.register_codec(_TensorSpecCodec()) 458 459 460class _BoundedTensorSpecCodec(object): 461 """Codec for `BoundedTensorSpec`.""" 462 463 def can_encode(self, pyobj): 464 return isinstance(pyobj, tensor_spec.BoundedTensorSpec) 465 466 def do_encode(self, bounded_tensor_spec_value, encode_fn): 467 """Returns an encoded proto for the given `tf.BoundedTensorSpec`.""" 468 encoded_bounded_tensor_spec = struct_pb2.StructuredValue() 469 encoded_bounded_tensor_spec.bounded_tensor_spec_value.CopyFrom( 470 struct_pb2.BoundedTensorSpecProto( 471 shape=encode_fn(bounded_tensor_spec_value.shape).tensor_shape_value, 472 dtype=encode_fn(bounded_tensor_spec_value.dtype).tensor_dtype_value, 473 name=bounded_tensor_spec_value.name, 474 minimum=tensor_util.make_tensor_proto( 475 bounded_tensor_spec_value.minimum), 476 maximum=tensor_util.make_tensor_proto( 477 bounded_tensor_spec_value.maximum))) 478 return encoded_bounded_tensor_spec 479 480 def can_decode(self, value): 481 return value.HasField("bounded_tensor_spec_value") 482 483 def do_decode(self, value, decode_fn): 484 btsv = value.bounded_tensor_spec_value 485 name = btsv.name 486 return tensor_spec.BoundedTensorSpec( 487 shape=decode_fn( 488 struct_pb2.StructuredValue(tensor_shape_value=btsv.shape)), 489 dtype=decode_fn( 490 struct_pb2.StructuredValue(tensor_dtype_value=btsv.dtype)), 491 minimum=tensor_util.MakeNdarray(btsv.minimum), 492 maximum=tensor_util.MakeNdarray(btsv.maximum), 493 name=(name if name else None)) 494 495 496StructureCoder.register_codec(_BoundedTensorSpecCodec()) 497 498 499class _TypeSpecCodec(object): 500 """Codec for `tf.TypeSpec`.""" 501 502 # Mapping from enum value to type (TypeSpec subclass). 503 TYPE_SPEC_CLASS_FROM_PROTO = { 504 struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC: 505 sparse_tensor.SparseTensorSpec, 506 struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC: 507 indexed_slices.IndexedSlicesSpec, 508 struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC: 509 ragged_tensor.RaggedTensorSpec, 510 struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC: 511 tensor_array_ops.TensorArraySpec, 512 struct_pb2.TypeSpecProto.DATA_DATASET_SPEC: 513 dataset_ops.DatasetSpec, 514 struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC: 515 iterator_ops.IteratorSpec, 516 struct_pb2.TypeSpecProto.OPTIONAL_SPEC: 517 optional_ops.OptionalSpec, 518 struct_pb2.TypeSpecProto.PER_REPLICA_SPEC: 519 values.PerReplicaSpec, 520 struct_pb2.TypeSpecProto.VARIABLE_SPEC: 521 resource_variable_ops.VariableSpec, 522 struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC: 523 row_partition.RowPartitionSpec, 524 } 525 526 # Mapping from type (TypeSpec subclass) to enum value. 527 TYPE_SPEC_CLASS_TO_PROTO = dict( 528 (cls, enum) for (enum, cls) in TYPE_SPEC_CLASS_FROM_PROTO.items()) 529 530 def can_encode(self, pyobj): 531 """Returns true if `pyboj` can be encoded as a TypeSpec.""" 532 if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO: # pylint: disable=unidiomatic-typecheck 533 return True 534 535 # Check if it's a registered type. 536 if isinstance(pyobj, type_spec.TypeSpec): 537 try: 538 type_spec.get_name(type(pyobj)) 539 return True 540 except ValueError: 541 return False 542 543 return False 544 545 def do_encode(self, type_spec_value, encode_fn): 546 """Returns an encoded proto for the given `tf.TypeSpec`.""" 547 type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(type(type_spec_value)) 548 type_spec_class_name = type(type_spec_value).__name__ 549 550 if type_spec_class is None: 551 type_spec_class_name = type_spec.get_name(type(type_spec_value)) 552 if isinstance(type_spec_value, extension_type.ExtensionTypeSpec): 553 type_spec_class = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC 554 else: 555 type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC 556 # Support for saving registered TypeSpecs is currently experimental. 557 # Issue a warning to indicate the limitations. 558 warnings.warn("Encoding a StructuredValue with type %s; loading this " 559 "StructuredValue will require that this type be " 560 "imported and registered." % type_spec_class_name) 561 562 type_state = type_spec_value._serialize() # pylint: disable=protected-access 563 encoded_type_spec = struct_pb2.StructuredValue() 564 encoded_type_spec.type_spec_value.CopyFrom( 565 struct_pb2.TypeSpecProto( 566 type_spec_class=type_spec_class, 567 type_state=encode_fn(type_state), 568 type_spec_class_name=type_spec_class_name)) 569 return encoded_type_spec 570 571 def can_decode(self, value): 572 return value.HasField("type_spec_value") 573 574 def do_decode(self, value, decode_fn): 575 """Returns the `tf.TypeSpec` encoded by the proto `value`.""" 576 type_spec_proto = value.type_spec_value 577 type_spec_class_enum = type_spec_proto.type_spec_class 578 class_name = type_spec_proto.type_spec_class_name 579 580 if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC: 581 try: 582 type_spec_class = type_spec.lookup(class_name) 583 except ValueError as e: 584 raise ValueError( 585 f"The type '{class_name}' has not been registered. It must be " 586 "registered before you load this object (typically by importing " 587 "its module).") from e 588 elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC: 589 try: 590 type_spec_class = type_spec.lookup(class_name) 591 except ValueError: 592 type_spec_class = extension_type.AnonymousExtensionTypeSpec 593 warnings.warn("The type %r has not been registered. Falling back to " 594 "using AnonymousExtensionTypeSpec instead.") 595 else: 596 if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO: 597 raise ValueError( 598 f"The type '{class_name}' is not supported by this version of " 599 "TensorFlow. (The object you are loading must have been created " 600 "with a newer version of TensorFlow.)") 601 type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum] 602 603 # pylint: disable=protected-access 604 return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state)) 605 606 607StructureCoder.register_codec(_TypeSpecCodec()) 608