# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration test for input pipeline serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib


class MultipleInputPipelinesTest(test.TestCase):
  """Checkpoints many iterators with one saver and resumes them in a new graph."""

  def _build_input_pipeline(self, name, num_outputs):
    """Creates one saveable input pipeline inside the name scope `name`.

    The pipeline is `Dataset.range(num_outputs)`, shuffled with
    `reshuffle_each_iteration=False` and then prefetched. A saveable object
    for its iterator is added to the `SAVEABLE_OBJECTS` graph collection.

    Args:
      name: Name scope under which the pipeline's ops are created.
      num_outputs: Upper bound passed to `Dataset.range`.

    Returns:
      A pair `(initializer, next_element)` taken from the iterator.
    """
    with ops.name_scope(name):
      source = dataset_ops.Dataset.range(num_outputs)
      shuffled = source.shuffle(10, reshuffle_each_iteration=False)
      prefetched = shuffled.prefetch(10)
      iterator = prefetched.make_initializable_iterator()
      ops.add_to_collection(
          ops.GraphKeys.SAVEABLE_OBJECTS,
          contrib_iterator_ops.make_saveable_from_iterator(iterator))
      initializer = iterator.initializer
      next_element = iterator.get_next()
      return initializer, next_element

  def _build_graph(self, num_pipelines, num_outputs):
    """Builds `num_pipelines` input pipelines plus a saver.

    Pipelines are named "input_pipeline_0", "input_pipeline_1", ... in
    creation order.

    Args:
      num_pipelines: Number of independent pipelines to create.
      num_outputs: Forwarded to `_build_input_pipeline` for each pipeline.

    Returns:
      A triple `(init_ops, get_next_ops, saver)`; the two lists hold one entry
      per pipeline, in creation order.
    """
    pipelines = [
        self._build_input_pipeline("input_pipeline_%d" % index, num_outputs)
        for index in range(num_pipelines)
    ]
    init_ops = [initializer for initializer, _ in pipelines]
    get_next_ops = [next_element for _, next_element in pipelines]
    # NOTE(review): the saver is constructed only after every pipeline has
    # registered its saveable; presumably it picks them up from the
    # SAVEABLE_OBJECTS collection at construction time -- keep this ordering.
    saver = saver_lib.Saver()
    return init_ops, get_next_ops, saver

  def _ckpt_path(self):
    """Returns the checkpoint path prefix inside the test's temp directory."""
    temp_dir = self.get_temp_dir()
    return os.path.join(temp_dir, "iterator")

  def _run_and_record(self, sess, get_next_ops, num_steps, all_outputs):
    """Runs all pipelines `num_steps` times, appending results per pipeline.

    Args:
      sess: Session in which `get_next_ops` are evaluated.
      get_next_ops: One fetch per pipeline, evaluated together each step.
      num_steps: Number of `sess.run` calls to make.
      all_outputs: One list per pipeline; the i-th fetched value of every step
        is appended to the i-th list.
    """
    for _ in range(num_steps):
      step_values = sess.run(get_next_ops)
      for recorded, value in zip(all_outputs, step_values):
        recorded.append(value)

  def testConcurrentSaves(self):
    """Saves 100 partially consumed iterators and finishes them after restore.

    The first graph consumes `break_point` elements from every pipeline and
    writes a checkpoint. A second, freshly built graph restores that checkpoint
    (without running the initializers) and consumes the remaining elements.
    Each pipeline's combined output, once sorted, must equal
    `range(num_outputs)`.
    """
    num_pipelines = 100
    num_outputs = 100
    break_point = 10
    all_outputs = [[] for _ in range(num_pipelines)]

    # Phase 1: initialize, consume the first `break_point` elements, save.
    with ops.Graph().as_default() as save_graph:
      init_ops, get_next_ops, saver = self._build_graph(
          num_pipelines, num_outputs)
      with self.test_session(graph=save_graph) as sess:
        sess.run(init_ops)
        self._run_and_record(sess, get_next_ops, break_point, all_outputs)
        saver.save(sess, self._ckpt_path())

    # Phase 2: rebuild the graph, restore instead of initializing, and drain
    # whatever is left in each pipeline.
    with ops.Graph().as_default() as restore_graph:
      init_ops, get_next_ops, saver = self._build_graph(
          num_pipelines, num_outputs)
      with self.test_session(graph=restore_graph) as sess:
        saver.restore(sess, self._ckpt_path())
        self._run_and_record(
            sess, get_next_ops, num_outputs - break_point, all_outputs)

    for pipeline_outputs in all_outputs:
      self.assertSequenceEqual(sorted(pipeline_outputs), range(num_outputs))


if __name__ == "__main__":
  test.main()