• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for prefetching_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import itertools
21import threading
22
23from tensorflow.contrib.data.python.ops import prefetching_ops
24from tensorflow.core.protobuf import config_pb2
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import function
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.platform import test
34
35
class StagingAreaOpsTest(test.TestCase):
  """Tests for `function_buffering_resource` and its `get_next` op.

  Each test builds a one-shot iterator over an infinite generator on one
  device, wraps it in a function buffering resource placed on another (or the
  same) device, and checks that elements arrive in order and that the
  resource keeps fetching ahead of the consumer.
  """

  # Upper bound on how long to wait for background prefetching to drive the
  # generator past element 6. A regression then fails the test instead of
  # hanging the test runner indefinitely.
  _PREFETCH_WAIT_TIMEOUT_SECS = 30.0

  def setUp(self):
    # Chain to the base-class setUp so the test framework's per-test
    # initialisation still runs; overriding without it silently skips it.
    super(StagingAreaOpsTest, self).setUp()
    # Set by the generator in `_prefetch_fn_helper` when it is resumed after
    # having yielded element 6 (i.e. when element 7 is requested).
    self._event = threading.Event()

  def _prefetch_fn_helper(self, buffer_name, device0, device1):
    """Runs the prefetching scenario across the given pair of devices.

    Args:
      buffer_name: `shared_name` for the buffering resource; each test passes
        a distinct value.
      device0: Device hosting the source dataset/iterator. Its name is also
        passed as the buffering resource's `target_device`.
      device1: Device on which the buffering resource and its `get_next` op
        are placed.
    """
    worker_config = config_pb2.ConfigProto()
    # Expose two CPU devices so ".../cpu:1" exists for the cross-CPU test.
    worker_config.device_count["CPU"] = 2

    def gen():
      for i in itertools.count(start=1, step=1):
        yield [float(i)]
        # Only reached when the consumer asks for the element after `i`.
        if i == 6:
          self._event.set()

    with ops.device(device0):
      dataset_3 = dataset_ops.Dataset.from_generator(gen, dtypes.float32)
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Rebuild the iterator from its string handle inside the function so
      # it can be invoked against `target_device`.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          target_device=target,
          string_arg=iterator_3_handle,
          buffer_size=3,
          thread_pool_size=2,
          shared_name=buffer_name)
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.float32])

    with self.test_session(config=worker_config) as sess:
      try:
        # Elements must come out of the buffer in generator order.
        for expected in (1.0, 2.0, 3.0, 4.0):
          self.assertEqual(sess.run(prefetch_op), [expected])
        # No get_next call is made between the fourth fetch and this wait.
        # The generator can therefore only be driven past element 6 by the
        # resource filling its buffer in the background.
        self.assertTrue(
            self._event.wait(timeout=self._PREFETCH_WAIT_TIMEOUT_SECS),
            "Timed out waiting for the buffering resource to prefetch.")
        self.assertEqual(sess.run(prefetch_op), [5.0])
      finally:
        # Release the resource even when an assertion above fails.
        sess.run(
            resource_variable_ops.destroy_resource_op(
                buffer_resource_handle, ignore_lookup_error=True))

  def testSameDeviceCPU(self):
    self._prefetch_fn_helper("same_device_cpu",
                             "/job:localhost/replica:0/task:0/cpu:0",
                             "/job:localhost/replica:0/task:0/cpu:0")

  def testDifferentDeviceCPU(self):
    self._prefetch_fn_helper("diff_device_cpu",
                             "/job:localhost/replica:0/task:0/cpu:0",
                             "/job:localhost/replica:0/task:0/cpu:1")

  def testDifferentDeviceCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0",
                             "/job:localhost/replica:0/task:0/gpu:0")
109
110
if __name__ == "__main__":
  # Script entry point: run this module's tests through TensorFlow's test
  # runner.
  test.main()
113