# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Iterator` using distributed sessions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import lookup as lookup_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class IteratorClusterTest(test.TestCase):
  """Tests iterator resource access across devices/processes in a cluster."""

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorWithoutRemoteCallFail(self):
    """Directly reading an iterator placed on another device should fail.

    The iterator resource lives on cpu:1; a plain `get_next()` issued from
    cpu:0 (without routing through `remote_call`) cannot find the resource
    and is expected to raise `InvalidArgumentError`.
    """
    # One worker with two CPU devices so the iterator and the reader can be
    # placed on distinct devices within the same process.
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    with ops.device("/job:worker/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      # A string handle identifying the iterator resource on cpu:1.
      iterator_3_handle = iterator_3.string_handle()

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      # Rebuild an iterator from the handle on cpu:0, where the underlying
      # resource does NOT exist.
      remote_it = iterator_ops.Iterator.from_string_handle(
          iterator_3_handle, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      get_next_op = remote_it.get_next()

    with session.Session(worker[0].target) as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next_op)

  def _testRemoteIteratorHelper(self, device0, device1, target):
    """Drives an iterator on `device1` from `device0` via `remote_call`.

    Args:
      device0: Device string for the caller side (where the resource is NOT).
      device1: Device string where the iterator resource is created.
      target: Session target (worker gRPC address) to run against.
    """
    with ops.device(device1):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Runs on the device chosen at call time via `remote_call`'s `target`;
      # resolves the handle back into an iterator and pulls one element.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device(device0):
      # The execution target is fed at run time so the same op can be
      # exercised against both the correct and the wrong device.
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with session.Session(target) as sess:
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [1])
      # Fails when target is cpu:0 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(remote_op, feed_dict={target_placeholder: device0})
      # Direct reads and remote reads interleave on the same resource: the
      # direct `get_next()` consumes element 2, the remote call sees 3.
      elem = sess.run(iterator_3.get_next())
      self.assertEqual(elem, [2])
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(remote_op, feed_dict={target_placeholder: device1})

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorUsingRemoteCallOp(self):
    """Remote-call access across two CPU devices within one worker."""
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:0/cpu:1",
                                   worker[0].target)

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
    """Remote-call access across two worker tasks (separate processes)."""
    workers, _ = test_util.create_local_cluster(2, 1)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:1/cpu:0",
                                   workers[0].target)

  @test_util.run_v1_only("b/120545219")
  def testCaptureHashTableInSharedIterator(self):
    """A shared iterator capturing a hash table survives across sessions.

    The iterator and table use `shared_name`, so the second session picks up
    the same server-side state and continues from where the first left off.
    """
    worker, _ = test_util.create_local_cluster(1, 1)

    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        shared_name="shared_table")

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    iterator = (
        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
            table.lookup)
        .make_initializable_iterator(shared_name="shared_iterator"))
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(table.initializer)
      sess.run(init_op)
      # "tank" is not in the table, so it maps to the default value -1.
      self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

    with session.Session(worker[0].target) as sess:
      # A fresh session resumes the SAME iterator: no re-initialization, and
      # the next element continues the sequence.
      self.assertAllEqual([2, 0], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @test_util.run_v1_only("b/120545219")
  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(None) -> PrefetchDataset(100).
    worker, _ = test_util.create_local_cluster(1, 1)

    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      # Element-wise square of each component tuple member.
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(None).prefetch(10000))

    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(init_op)
      # Deliberately read only a few elements; the prefetching pipeline must
      # tear down cleanly when the session exits without exhausting it.
      for _ in range(3):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()