• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Integration test for input pipeline serialization."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21
22from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.framework import ops
25from tensorflow.python.platform import test
26from tensorflow.python.training import saver as saver_lib
27
28
class MultipleInputPipelinesTest(test.TestCase):
  """Checkpoints many independent input pipelines with one argument-less Saver.

  Each pipeline is a shuffled, prefetched `Dataset.range` whose iterator is
  wrapped in a saveable and added to the SAVEABLE_OBJECTS collection, so the
  `saver_lib.Saver()` built afterwards can checkpoint and restore all of the
  iterators together.
  """

  def _build_input_pipeline(self, name, num_outputs):
    """Creates one saveable input pipeline under the name scope `name`.

    Args:
      name: name scope used for every op of this pipeline.
      num_outputs: number of elements produced by `Dataset.range`.

    Returns:
      A pair `(initializer_op, next_element)` for the pipeline's iterator.
    """
    with ops.name_scope(name):
      dataset = dataset_ops.Dataset.range(num_outputs)
      dataset = dataset.shuffle(10, reshuffle_each_iteration=False)
      dataset = dataset.prefetch(10)
      it = dataset.make_initializable_iterator()
      # Register the iterator state so the argument-less Saver created in
      # _build_graph picks it up alongside the other pipelines.
      ops.add_to_collection(
          ops.GraphKeys.SAVEABLE_OBJECTS,
          contrib_iterator_ops.make_saveable_from_iterator(it))
      init_op = it.initializer
      next_element = it.get_next()
    return init_op, next_element

  def _build_graph(self, num_pipelines, num_outputs):
    """Builds `num_pipelines` pipelines plus a Saver covering all of them.

    Args:
      num_pipelines: how many independent pipelines to create.
      num_outputs: elements per pipeline.

    Returns:
      `(init_ops, get_next_ops, saver)`; the two lists are indexed by
      pipeline number.
    """
    pipelines = [
        self._build_input_pipeline("input_pipeline_%d" % idx, num_outputs)
        for idx in range(num_pipelines)
    ]
    init_ops = [initializer for initializer, _ in pipelines]
    get_next_ops = [next_op for _, next_op in pipelines]
    # Constructed last, i.e. after every pipeline registered its saveable.
    saver = saver_lib.Saver()
    return init_ops, get_next_ops, saver

  def _ckpt_path(self):
    """Returns the checkpoint prefix inside the test's temp directory."""
    temp_dir = self.get_temp_dir()
    return os.path.join(temp_dir, "iterator")

  def testConcurrentSaves(self):
    """Saves all pipelines mid-stream, then resumes them in a fresh graph.

    Across both sessions, every pipeline must yield each value of
    `range(num_outputs)` exactly once.
    """
    num_pipelines = 100
    num_outputs = 100
    break_point = 10
    collected = [[] for _ in range(num_pipelines)]

    def _record(session, next_ops, steps):
      # Runs `steps` batches; each batch yields one element per pipeline.
      for _ in range(steps):
        values = session.run(next_ops)
        for pipeline_idx, value in enumerate(values):
          collected[pipeline_idx].append(value)

    # First session: initialize, consume `break_point` elements, checkpoint.
    with ops.Graph().as_default() as graph:
      init_ops, get_next_ops, saver = self._build_graph(
          num_pipelines, num_outputs)
      with self.test_session(graph=graph) as sess:
        sess.run(init_ops)
        _record(sess, get_next_ops, break_point)
        saver.save(sess, self._ckpt_path())

    # Second session: rebuild the same graph and restore iterator state from
    # the checkpoint instead of running the initializers.
    with ops.Graph().as_default() as graph:
      _, get_next_ops, saver = self._build_graph(num_pipelines, num_outputs)
      with self.test_session(graph=graph) as sess:
        saver.restore(sess, self._ckpt_path())
        _record(sess, get_next_ops, num_outputs - break_point)

    expected = range(num_outputs)
    for pipeline_outputs in collected:
      self.assertSequenceEqual(sorted(pipeline_outputs), expected)
82
83
# Run the tests in this module when executed directly as a script.
if __name__ == "__main__":
  test.main()
86