# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility classes for testing checkpointing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.training import saver as saver_module


class CheckpointedOp(object):
  """Op with a custom checkpointing implementation.

  Defined as part of the test because the MutableHashTable Python code is
  currently in contrib.

  Wraps a mutable hash table resource handle and exposes insert / lookup /
  export helpers, plus a `SaveableObject` (`CustomSaveable`) that writes the
  table contents as two checkpoint entries, `<name>-keys` and
  `<name>-values`.

  In graph mode the saveable is built once in `__init__` and added to the
  `SAVEABLE_OBJECTS` collection. In eager mode a fresh saveable is built on
  every access to the `saveable` property.
  """

  # pylint: disable=protected-access
  def __init__(self, name, table_ref=None, key_dtype=dtypes.string,
               value_dtype=dtypes.float32):
    """Creates (or wraps) the table and, in graph mode, its saveable.

    Args:
      name: Name passed to the created table op. Also used as the prefix of
        the checkpoint entry names.
      table_ref: Optional existing table handle to wrap. When None, a new
        table is created with `gen_lookup_ops.mutable_hash_table_v2`.
      key_dtype: dtype of the table keys. Defaults to `dtypes.string`. This
        value is passed to the export op, so when `table_ref` is supplied it
        should describe that table's keys.
      value_dtype: dtype of the table values. Defaults to `dtypes.float32`.
        This value is passed to the export op, so when `table_ref` is
        supplied it should describe that table's values.
    """
    # Record the dtypes first. `_export` reads them, and `_export` runs
    # below when the graph-mode saveable is constructed. Keeping them here
    # means the creation op and the export op cannot disagree.
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    if table_ref is None:
      self.table_ref = gen_lookup_ops.mutable_hash_table_v2(
          key_dtype=key_dtype, value_dtype=value_dtype, name=name)
    else:
      self.table_ref = table_ref
    self._name = name
    if not context.executing_eagerly():
      # Graph mode: build the saveable once and register it so that a
      # collection-driven Saver can discover it.
      self._saveable = CheckpointedOp.CustomSaveable(self, name)
      ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
                                self._saveable)

  @property
  def name(self):
    """The name given at construction time."""
    return self._name

  @property
  def saveable(self):
    """Returns a `CustomSaveable` for this table.

    In eager mode a new saveable (and therefore a new export) is built on
    each access; presumably this is so that it reflects the table's current
    contents. In graph mode the saveable created in `__init__` is returned.
    """
    if context.executing_eagerly():
      return CheckpointedOp.CustomSaveable(self, self.name)
    else:
      return self._saveable

  def insert(self, keys, values):
    """Inserts `keys` -> `values` into the table; returns the insert op."""
    return gen_lookup_ops.lookup_table_insert_v2(self.table_ref, keys, values)

  def lookup(self, keys, default):
    """Looks up `keys`, using `default` for missing keys; returns the result."""
    return gen_lookup_ops.lookup_table_find_v2(self.table_ref, keys, default)

  def keys(self):
    """Returns the exported keys tensor (first output of `_export`)."""
    return self._export()[0]

  def values(self):
    """Returns the exported values tensor (second output of `_export`)."""
    return self._export()[1]

  def _export(self):
    """Exports the table contents as a (keys, values) pair of tensors."""
    return gen_lookup_ops.lookup_table_export_v2(
        self.table_ref, self._key_dtype, self._value_dtype)

  class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
    """A custom saveable for CheckpointedOp.

    Saves the table as two checkpoint entries, `<name>-keys` and
    `<name>-values`. On restore it imports them back into the wrapped table.
    """

    def __init__(self, table, name):
      """Builds the save specs from the table's exported keys and values.

      Args:
        table: The `CheckpointedOp` to save. It becomes this saveable's `op`,
          which `restore` uses to reach `table_ref`.
        name: Prefix for the checkpoint entry names.
      """
      tensors = table._export()
      specs = [
          saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "",
                                                 name + "-keys"),
          saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "",
                                                 name + "-values")
      ]
      super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)

    def restore(self, restore_tensors, shapes):
      """Imports restored keys/values into the wrapped table.

      Args:
        restore_tensors: Restored tensors in spec order: keys first, then
          values.
        shapes: Unused.

      Returns:
        The table import op.
      """
      return gen_lookup_ops.lookup_table_import_v2(
          self.op.table_ref, restore_tensors[0], restore_tensors[1])
  # pylint: enable=protected-access