1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for exporter.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os.path 22 23from tensorflow.contrib.session_bundle import constants 24from tensorflow.contrib.session_bundle import exporter 25from tensorflow.contrib.session_bundle import gc 26from tensorflow.contrib.session_bundle import manifest_pb2 27from tensorflow.core.framework import graph_pb2 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.client import session 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import ops 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import state_ops 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import flags 38from tensorflow.python.platform import gfile 39from tensorflow.python.platform import test 40from tensorflow.python.training import saver 41 42FLAGS = flags.FLAGS 43 44GLOBAL_STEP = 222 45 46 47def tearDownModule(): 48 gfile.DeleteRecursively(test.get_temp_dir()) 49 50 51class SaveRestoreShardedTest(test.TestCase): 52 53 def doBasicsOneExportPath(self, 54 export_path, 55 clear_devices=False, 56 global_step=GLOBAL_STEP, 57 sharded=True, 58 export_count=1): 59 # Build a graph with 2 parameter nodes on different devices. 60 ops.reset_default_graph() 61 with session.Session( 62 target="", 63 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 64 # v2 is an unsaved variable derived from v0 and v1. It is used to 65 # exercise the ability to run an init op when restoring a graph. 66 with sess.graph.device("/cpu:0"): 67 v0 = variables.VariableV1(10, name="v0") 68 with sess.graph.device("/cpu:1"): 69 v1 = variables.VariableV1(20, name="v1") 70 v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[]) 71 assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1)) 72 init_op = control_flow_ops.group(assign_v2, name="init_op") 73 74 ops.add_to_collection("v", v0) 75 ops.add_to_collection("v", v1) 76 ops.add_to_collection("v", v2) 77 78 named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1} 79 signatures = { 80 "foo": 81 exporter.regression_signature( 82 input_tensor=v0, output_tensor=v1), 83 "generic": 84 exporter.generic_signature(named_tensor_bindings) 85 } 86 87 asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt") 88 asset_file = constant_op.constant(asset_filepath_orig, name="filename42") 89 ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file) 90 91 with gfile.FastGFile(asset_filepath_orig, "w") as f: 92 f.write("your data here") 93 assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 94 95 ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt") 96 with gfile.FastGFile(ignored_asset, "w") as f: 97 f.write("additional data here") 98 99 variables.global_variables_initializer().run() 100 101 # Run an export. 102 save = saver.Saver( 103 { 104 "v0": v0, 105 "v1": v1 106 }, 107 restore_sequentially=True, 108 sharded=sharded, 109 write_version=saver_pb2.SaverDef.V1) 110 export = exporter.Exporter(save) 111 compare_def = ops.get_default_graph().as_graph_def() 112 export.init( 113 compare_def, 114 init_op=init_op, 115 clear_devices=clear_devices, 116 default_graph_signature=exporter.classification_signature( 117 input_tensor=v0), 118 named_graph_signatures=signatures, 119 assets_collection=assets_collection) 120 121 for x in range(export_count): 122 export.export( 123 export_path, 124 constant_op.constant(global_step + x), 125 sess, 126 exports_to_keep=gc.largest_export_versions(2)) 127 # Set global_step to the last exported version, as the rest of the test 128 # uses it to construct model export path, loads model from it, and does 129 # verifications. We want to make sure to always use the last exported 130 # version, as old ones may have be garbage-collected. 131 global_step += export_count - 1 132 133 # Restore graph. 134 ops.reset_default_graph() 135 with session.Session( 136 target="", 137 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 138 save = saver.import_meta_graph( 139 os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER % 140 global_step, constants.META_GRAPH_DEF_FILENAME)) 141 self.assertIsNotNone(save) 142 meta_graph_def = save.export_meta_graph() 143 collection_def = meta_graph_def.collection_def 144 145 # Validate custom graph_def. 146 graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value 147 self.assertEquals(len(graph_def_any), 1) 148 graph_def = graph_pb2.GraphDef() 149 graph_def_any[0].Unpack(graph_def) 150 if clear_devices: 151 for node in compare_def.node: 152 node.device = "" 153 self.assertProtoEquals(compare_def, graph_def) 154 155 # Validate init_op. 156 init_ops = collection_def[constants.INIT_OP_KEY].node_list.value 157 self.assertEquals(len(init_ops), 1) 158 self.assertEquals(init_ops[0], "init_op") 159 160 # Validate signatures. 161 signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value 162 self.assertEquals(len(signatures_any), 1) 163 signatures = manifest_pb2.Signatures() 164 signatures_any[0].Unpack(signatures) 165 default_signature = signatures.default_signature 166 self.assertEqual( 167 default_signature.classification_signature.input.tensor_name, "v0:0") 168 bindings = signatures.named_signatures["generic"].generic_signature.map 169 self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0") 170 self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0") 171 read_foo_signature = ( 172 signatures.named_signatures["foo"].regression_signature) 173 self.assertEquals(read_foo_signature.input.tensor_name, "v0:0") 174 self.assertEquals(read_foo_signature.output.tensor_name, "v1:0") 175 176 # Validate the assets. 177 assets_any = collection_def[constants.ASSETS_KEY].any_list.value 178 self.assertEquals(len(assets_any), 1) 179 asset = manifest_pb2.AssetFile() 180 assets_any[0].Unpack(asset) 181 assets_path = os.path.join(export_path, 182 constants.VERSION_FORMAT_SPECIFIER % 183 global_step, constants.ASSETS_DIRECTORY, 184 "hello42.txt") 185 asset_contents = gfile.GFile(assets_path).read() 186 self.assertEqual(asset_contents, "your data here") 187 self.assertEquals("hello42.txt", asset.filename) 188 self.assertEquals("filename42:0", asset.tensor_binding.tensor_name) 189 ignored_asset_path = os.path.join(export_path, 190 constants.VERSION_FORMAT_SPECIFIER % 191 global_step, constants.ASSETS_DIRECTORY, 192 "ignored.txt") 193 self.assertFalse(gfile.Exists(ignored_asset_path)) 194 195 # Validate graph restoration. 196 if sharded: 197 save.restore(sess, 198 os.path.join(export_path, 199 constants.VERSION_FORMAT_SPECIFIER % 200 global_step, 201 constants.VARIABLES_FILENAME_PATTERN)) 202 else: 203 save.restore(sess, 204 os.path.join(export_path, 205 constants.VERSION_FORMAT_SPECIFIER % 206 global_step, constants.VARIABLES_FILENAME)) 207 self.assertEqual(10, ops.get_collection("v")[0].eval()) 208 self.assertEqual(20, ops.get_collection("v")[1].eval()) 209 ops.get_collection(constants.INIT_OP_KEY)[0].run() 210 self.assertEqual(30, ops.get_collection("v")[2].eval()) 211 212 def testDuplicateExportRaisesError(self): 213 export_path = os.path.join(test.get_temp_dir(), "export_duplicates") 214 self.doBasicsOneExportPath(export_path) 215 self.assertRaises(RuntimeError, self.doBasicsOneExportPath, export_path) 216 217 def testBasics(self): 218 export_path = os.path.join(test.get_temp_dir(), "export") 219 self.doBasicsOneExportPath(export_path) 220 221 def testBasicsNoShard(self): 222 export_path = os.path.join(test.get_temp_dir(), "export_no_shard") 223 self.doBasicsOneExportPath(export_path, sharded=False) 224 225 def testClearDevice(self): 226 export_path = os.path.join(test.get_temp_dir(), "export_clear_device") 227 self.doBasicsOneExportPath(export_path, clear_devices=True) 228 229 def testGC(self): 230 export_path = os.path.join(test.get_temp_dir(), "gc") 231 self.doBasicsOneExportPath(export_path, global_step=100) 232 self.assertEquals(gfile.ListDirectory(export_path), ["00000100"]) 233 self.doBasicsOneExportPath(export_path, global_step=101) 234 self.assertEquals( 235 sorted(gfile.ListDirectory(export_path)), ["00000100", "00000101"]) 236 self.doBasicsOneExportPath(export_path, global_step=102) 237 self.assertEquals( 238 sorted(gfile.ListDirectory(export_path)), ["00000101", "00000102"]) 239 240 def testExportMultipleTimes(self): 241 export_path = os.path.join(test.get_temp_dir(), "export_multiple_times") 242 self.doBasicsOneExportPath(export_path, export_count=10) 243 244 245if __name__ == "__main__": 246 test.main() 247