1"""Utilities for including Python state in TensorFlow checkpoints.""" 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import functools 22 23import six 24 25from tensorflow.python.training.tracking import base 26from tensorflow.python.util.tf_export import tf_export 27 28 29@tf_export("train.experimental.PythonState") 30@six.add_metaclass(abc.ABCMeta) 31class PythonState(base.Trackable): 32 """A mixin for putting Python state in an object-based checkpoint. 33 34 This is an abstract class which allows extensions to TensorFlow's object-based 35 checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy 36 arrays: 37 38 ```python 39 import io 40 import numpy 41 42 class NumpyWrapper(tf.train.experimental.PythonState): 43 44 def __init__(self, array): 45 self.array = array 46 47 def serialize(self): 48 string_file = io.BytesIO() 49 try: 50 numpy.save(string_file, self.array, allow_pickle=False) 51 serialized = string_file.getvalue() 52 finally: 53 string_file.close() 54 return serialized 55 56 def deserialize(self, string_value): 57 string_file = io.BytesIO(string_value) 58 try: 59 self.array = numpy.load(string_file, allow_pickle=False) 60 finally: 61 string_file.close() 62 ``` 63 64 Instances of `NumpyWrapper` are checkpointable objects, and will be saved and 65 restored from checkpoints along with TensorFlow state like variables. 66 67 ```python 68 root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.]))) 69 save_path = root.save(prefix) 70 root.numpy.array *= 2. 71 assert [2.] == root.numpy.array 72 root.restore(save_path) 73 assert [1.] == root.numpy.array 74 ``` 75 """ 76 77 @abc.abstractmethod 78 def serialize(self): 79 """Callback to serialize the object. Returns a string.""" 80 81 @abc.abstractmethod 82 def deserialize(self, string_value): 83 """Callback to deserialize the object.""" 84 85 def _gather_saveables_for_checkpoint(self): 86 """Specify callbacks for saving and restoring `array`.""" 87 return { 88 "py_state": functools.partial( 89 base.PythonStringStateSaveable, 90 state_callback=self.serialize, 91 restore_callback=self.deserialize) 92 } 93