1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for prefetching_ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import itertools 21import threading 22 23from tensorflow.contrib.data.python.ops import prefetching_ops 24from tensorflow.core.protobuf import config_pb2 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.ops import iterator_ops 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import function 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import test_util 32from tensorflow.python.ops import resource_variable_ops 33from tensorflow.python.platform import test 34 35 36class StagingAreaOpsTest(test.TestCase): 37 38 def setUp(self): 39 self._event = threading.Event() 40 41 def _prefetch_fn_helper(self, buffer_name, device0, device1): 42 worker_config = config_pb2.ConfigProto() 43 worker_config.device_count["CPU"] = 2 44 45 def gen(): 46 for i in itertools.count(start=1, step=1): 47 yield [i + 0.0] 48 if i == 6: 49 self._event.set() 50 51 with ops.device(device0): 52 dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) 53 iterator_3 = dataset_3.make_one_shot_iterator() 54 iterator_3_handle = iterator_3.string_handle() 55 56 @function.Defun(dtypes.string) 57 def _remote_fn(h): 58 remote_iterator = iterator_ops.Iterator.from_string_handle( 59 h, dataset_3.output_types, dataset_3.output_shapes) 60 return remote_iterator.get_next() 61 62 target = constant_op.constant(device0) 63 with ops.device(device1): 64 buffer_resource_handle = prefetching_ops.function_buffering_resource( 65 f=_remote_fn, 66 target_device=target, 67 string_arg=iterator_3_handle, 68 buffer_size=3, 69 thread_pool_size=2, 70 shared_name=buffer_name) 71 72 with ops.device(device1): 73 prefetch_op = prefetching_ops.function_buffering_resource_get_next( 74 function_buffer_resource=buffer_resource_handle, 75 output_types=[dtypes.float32]) 76 77 with self.test_session(config=worker_config) as sess: 78 elem = sess.run(prefetch_op) 79 self.assertEqual(elem, [1.0]) 80 elem = sess.run(prefetch_op) 81 self.assertEqual(elem, [2.0]) 82 elem = sess.run(prefetch_op) 83 self.assertEqual(elem, [3.0]) 84 elem = sess.run(prefetch_op) 85 self.assertEqual(elem, [4.0]) 86 self._event.wait() 87 elem = sess.run(prefetch_op) 88 self.assertEqual(elem, [5.0]) 89 sess.run( 90 resource_variable_ops.destroy_resource_op( 91 buffer_resource_handle, ignore_lookup_error=True)) 92 93 def testSameDeviceCPU(self): 94 self._prefetch_fn_helper("same_device_cpu", 95 "/job:localhost/replica:0/task:0/cpu:0", 96 "/job:localhost/replica:0/task:0/cpu:0") 97 98 def testDifferentDeviceCPU(self): 99 self._prefetch_fn_helper("diff_device_cpu", 100 "/job:localhost/replica:0/task:0/cpu:0", 101 "/job:localhost/replica:0/task:0/cpu:1") 102 103 def testDifferentDeviceCPUGPU(self): 104 if not test_util.is_gpu_available(): 105 self.skipTest("No GPU available") 106 107 self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", 108 "/job:localhost/replica:0/task:0/gpu:0") 109 110 111if __name__ == "__main__": 112 test.main() 113