• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Utility classes for testing checkpointing."""
16
17from tensorflow.python.eager import context
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops as ops_lib
20from tensorflow.python.ops import gen_lookup_ops
21from tensorflow.python.training import saver as saver_module
22
23
class CheckpointedOp:
  """Test helper wrapping a mutable hash table with a custom saveable.

  The wrapped table maps `dtypes.string` keys to `dtypes.float32` values and
  is driven directly through the `gen_lookup_ops` kernels. Its contents are
  checkpointed through the nested `CustomSaveable`, which saves the keys
  under "<name>-keys" and the values under "<name>-values".
  """

  # pylint: disable=protected-access
  def __init__(self, name, table_ref=None):
    """Creates (or wraps) the table and, in graph mode, registers a saveable.

    Args:
      name: Name given to the table op when one is created here; also the
        base name of the saveable and its save specs.
      table_ref: Optional existing table handle. When None, a new
        `mutable_hash_table_v2` (string -> float32) is created.
    """
    self.table_ref = (
        table_ref if table_ref is not None else
        gen_lookup_ops.mutable_hash_table_v2(
            key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name))
    self._name = name
    if context.executing_eagerly():
      # Eager callers obtain a freshly built saveable via `saveable`.
      return
    # Graph mode: build the saveable once and register it in the
    # SAVEABLE_OBJECTS graph collection.
    self._saveable = CheckpointedOp.CustomSaveable(self, name)
    ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
                              self._saveable)

  @property
  def name(self):
    """The name this object was constructed with."""
    return self._name

  @property
  def saveable(self):
    """Returns the `CustomSaveable` for this table.

    In graph mode this is the instance built and registered in `__init__`.
    When executing eagerly a new instance is built on every access; its
    constructor calls `_export()`, so each access exports the table again.
    """
    if not context.executing_eagerly():
      return self._saveable
    return CheckpointedOp.CustomSaveable(self, self.name)

  def insert(self, keys, values):
    """Inserts `keys` -> `values`; returns the `lookup_table_insert_v2` result."""
    table = self.table_ref
    return gen_lookup_ops.lookup_table_insert_v2(table, keys, values)

  def lookup(self, keys, default):
    """Looks up `keys`, using `default` for missing entries."""
    table = self.table_ref
    return gen_lookup_ops.lookup_table_find_v2(table, keys, default)

  def keys(self):
    """Returns the keys output of exporting the table."""
    exported = self._export()
    return exported[0]

  def values(self):
    """Returns the values output of exporting the table."""
    exported = self._export()
    return exported[1]

  def _export(self):
    """Runs `lookup_table_export_v2` with string keys and float32 values."""
    key_dtype, value_dtype = dtypes.string, dtypes.float32
    return gen_lookup_ops.lookup_table_export_v2(self.table_ref, key_dtype,
                                                 value_dtype)

  class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
    """Saves and restores a `CheckpointedOp` table as two tensors.

    Keys are saved under "<name>-keys" and values under "<name>-values".
    """

    def __init__(self, table, name):
      """Builds the save specs from an export of `table`.

      Args:
        table: The `CheckpointedOp` to save; it is passed to the base class
          as the saveable's `op` (read back in `restore`).
        name: Base name for the saveable and its save specs.
      """
      exported = table._export()
      save_spec = saver_module.BaseSaverBuilder.SaveSpec
      suffixes = ("-keys", "-values")
      specs = [
          save_spec(exported[index], "", name + suffix)
          for index, suffix in enumerate(suffixes)
      ]
      super().__init__(table, specs, name)

    def restore(self, restore_tensors, shapes):
      """Imports the saved keys/values back into the wrapped table.

      Args:
        restore_tensors: Restored tensors, keys first and values second.
        shapes: Unused here.

      Returns:
        The result of `lookup_table_import_v2`.
      """
      saved_keys, saved_values = restore_tensors[0], restore_tensors[1]
      return gen_lookup_ops.lookup_table_import_v2(
          self.op.table_ref, saved_keys, saved_values)
  # pylint: enable=protected-access
88