# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporter.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver

FLAGS = flags.FLAGS

GLOBAL_STEP = 222


def tearDownModule():
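  """Removes the shared temporary test directory once the module finishes."""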
  gfile.DeleteRecursively(test.get_temp_dir())


class SaveRestoreShardedTest(test.TestCase):

  def doBasicsOneExportPath(self,
                            export_path,
                            clear_devices=False,
                            global_step=GLOBAL_STEP,
                            sharded=True,
                            export_count=1):
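    """Exports a small two-variable graph and verifies the resulting bundle.

    The graph places v0 and v1 on separate CPU devices and derives an unsaved
    v2 via an init op. It is exported `export_count` times with increasing
    global steps; the meta graph of the last export is then re-imported and
    its graph def, init op, signatures, assets, and restored values checked.
    """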
    # Build a graph with 2 parameter nodes on different devices.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # v2 is an unsaved variable derived from v0 and v1.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
      v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
      assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
      init_op = control_flow_ops.group(assign_v2, name="init_op")

      ops.add_to_collection("v", v0)
      ops.add_to_collection("v", v1)
      ops.add_to_collection("v", v2)

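      # Two named signatures are exported alongside the default one: a
      # regression signature keyed "foo" and a generic signature that maps
      # logical input names onto the variables' output tensors.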
      named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
      signatures = {
          "foo":
              exporter.regression_signature(
                  input_tensor=v0, output_tensor=v1),
          "generic":
              exporter.generic_signature(named_tensor_bindings)
      }

      asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
      asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)

      with gfile.FastGFile(asset_filepath_orig, "w") as f:
        f.write("your data here")
      assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

      ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt")
      with gfile.FastGFile(ignored_asset, "w") as f:
        f.write("additional data here")

      variables.global_variables_initializer().run()

      # Run an export.
      save = saver.Saver(
          {
              "v0": v0,
              "v1": v1
          },
          restore_sequentially=True,
          sharded=sharded,
          write_version=saver_pb2.SaverDef.V1)
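      # The Exporter wraps the saver; init() captures the graph def, init op,
      # default and named signatures, and the asset files to copy into each
      # exported version directory.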
      export = exporter.Exporter(save)
      compare_def = ops.get_default_graph().as_graph_def()
      export.init(
          compare_def,
          init_op=init_op,
          clear_devices=clear_devices,
          default_graph_signature=exporter.classification_signature(
              input_tensor=v0),
          named_graph_signatures=signatures,
          assets_collection=assets_collection)

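      # Export export_count times, bumping the version (global step) each
      # time. gc.largest_export_versions(2) keeps only the two newest
      # versions on disk, so earlier exports may be garbage-collected.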
      for x in range(export_count):
        export.export(
            export_path,
            constant_op.constant(global_step + x),
            sess,
            exports_to_keep=gc.largest_export_versions(2))
      # Set global_step to the last exported version, since the rest of the
      # test uses it to construct the model export path, load the model from
      # it, and run verifications. Always use the last exported version, as
      # older ones may have been garbage-collected.
      global_step += export_count - 1

    # Restore graph.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      save = saver.import_meta_graph(
          os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER %
                       global_step, constants.META_GRAPH_DEF_FILENAME))
      self.assertIsNotNone(save)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      # Validate custom graph_def.
      graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
      self.assertEqual(len(graph_def_any), 1)
      graph_def = graph_pb2.GraphDef()
      graph_def_any[0].Unpack(graph_def)
      if clear_devices:
        for node in compare_def.node:
          node.device = ""
      self.assertProtoEquals(compare_def, graph_def)

      # Validate init_op.
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      self.assertEqual(len(init_ops), 1)
      self.assertEqual(init_ops[0], "init_op")

      # Validate signatures.
      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEqual(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      self.assertEqual(
          default_signature.classification_signature.input.tensor_name, "v0:0")
      bindings = signatures.named_signatures["generic"].generic_signature.map
      self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
      self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
      read_foo_signature = (
          signatures.named_signatures["foo"].regression_signature)
      self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
      self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

      # Validate the assets.
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      self.assertEqual(len(assets_any), 1)
      asset = manifest_pb2.AssetFile()
      assets_any[0].Unpack(asset)
      assets_path = os.path.join(export_path,
                                 constants.VERSION_FORMAT_SPECIFIER %
                                 global_step, constants.ASSETS_DIRECTORY,
                                 "hello42.txt")
      asset_contents = gfile.GFile(assets_path).read()
      self.assertEqual(asset_contents, "your data here")
      self.assertEqual("hello42.txt", asset.filename)
      self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
      ignored_asset_path = os.path.join(export_path,
                                        constants.VERSION_FORMAT_SPECIFIER %
                                        global_step, constants.ASSETS_DIRECTORY,
                                        "ignored.txt")
      self.assertFalse(gfile.Exists(ignored_asset_path))

      # Validate graph restoration.
      if sharded:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step,
                                  constants.VARIABLES_FILENAME_PATTERN))
      else:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step, constants.VARIABLES_FILENAME))
      self.assertEqual(10, ops.get_collection("v")[0].eval())
      self.assertEqual(20, ops.get_collection("v")[1].eval())
      ops.get_collection(constants.INIT_OP_KEY)[0].run()
      self.assertEqual(30, ops.get_collection("v")[2].eval())

  def testDuplicateExportRaisesError(self):
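    # Both calls use the default global step, so the second export targets an
    # already-existing version directory and the exporter raises RuntimeError.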
    export_path = os.path.join(test.get_temp_dir(), "export_duplicates")
    self.doBasicsOneExportPath(export_path)
    self.assertRaises(RuntimeError, self.doBasicsOneExportPath, export_path)

  def testBasics(self):
    export_path = os.path.join(test.get_temp_dir(), "export")
    self.doBasicsOneExportPath(export_path)

  def testBasicsNoShard(self):
    export_path = os.path.join(test.get_temp_dir(), "export_no_shard")
    self.doBasicsOneExportPath(export_path, sharded=False)

  def testClearDevice(self):
    export_path = os.path.join(test.get_temp_dir(), "export_clear_device")
    self.doBasicsOneExportPath(export_path, clear_devices=True)

  def testGC(self):
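    # With exports_to_keep=largest_export_versions(2) in the helper, only the
    # two most recent version directories should remain after each export.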
    export_path = os.path.join(test.get_temp_dir(), "gc")
    self.doBasicsOneExportPath(export_path, global_step=100)
    self.assertEqual(gfile.ListDirectory(export_path), ["00000100"])
    self.doBasicsOneExportPath(export_path, global_step=101)
    self.assertEqual(
        sorted(gfile.ListDirectory(export_path)), ["00000100", "00000101"])
    self.doBasicsOneExportPath(export_path, global_step=102)
    self.assertEqual(
        sorted(gfile.ListDirectory(export_path)), ["00000101", "00000102"])

  def testExportMultipleTimes(self):
    export_path = os.path.join(test.get_temp_dir(), "export_multiple_times")
    self.doBasicsOneExportPath(export_path, export_count=10)


if __name__ == "__main__":
  test.main()