1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Iterator` using distributed sessions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib import lookup as lookup_ops
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.client import session
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import function
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import functional_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import string_ops
37from tensorflow.python.platform import test
38
39
class IteratorClusterTest(test.TestCase):
  """Tests for `tf.data.Iterator` behavior under distributed sessions.

  Each test spins up an in-process local cluster via
  `test_util.create_local_cluster` and exercises iterators whose resources
  live on a different device (or task) than the ops consuming them.
  """

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorWithoutRemoteCallFail(self):
    """Reading a remote iterator without `remote_call` raises an error.

    The iterator resource is created on cpu:1 while `get_next` is built on
    cpu:0; fetching it directly (i.e. without routing through
    `functional_ops.remote_call`) should fail with `InvalidArgumentError`.
    """
    worker_config = config_pb2.ConfigProto()
    # Give the single worker two CPU devices so the iterator and its
    # consumer can be placed on different devices within one process.
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    # The iterator resource lives on cpu:1.
    with ops.device("/job:worker/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    # The consumer is built on cpu:0, reconstructing the iterator from its
    # string handle.
    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_it = iterator_ops.Iterator.from_string_handle(
          iterator_3_handle, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      get_next_op = remote_it.get_next()

    with session.Session(worker[0].target) as sess:
      # The resource lookup happens on the wrong device, so the fetch fails.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next_op)

  def _testRemoteIteratorHelper(self, device0, device1, target):
    """Runs the shared remote-iterator scenario.

    Builds an iterator on `device1`, consumes it from `device0` through
    `functional_ops.remote_call`, and checks that:
      * elements are produced in order across remote and direct reads,
      * calling into the wrong target device fails, and
      * exhaustion raises `OutOfRangeError`.

    Args:
      device0: Device string on which the `remote_call` op is placed.
      device1: Device string that owns the iterator resource.
      target: Session target (gRPC URL) to connect to.
    """
    with ops.device(device1):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Executed (via remote_call) on the device that owns the iterator;
      # reconstructs the iterator from its string handle and advances it.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device(device0):
      # The call target is fed at run time so the same op can be pointed at
      # either device.
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with session.Session(target) as sess:
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [1])
      # Fails when target is cpu:0 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(remote_op, feed_dict={target_placeholder: device0})
      # The failed call must not have advanced the iterator: a direct read
      # yields the second element.
      elem = sess.run(iterator_3.get_next())
      self.assertEqual(elem, [2])
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [3])
      # The iterator is now exhausted.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(remote_op, feed_dict={target_placeholder: device1})

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorUsingRemoteCallOp(self):
    """Remote-call iteration between two devices of a single worker."""
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:0/cpu:1",
                                   worker[0].target)

  @test_util.run_v1_only("b/120545219")
  def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
    """Remote-call iteration across two worker tasks (separate processes)."""
    workers, _ = test_util.create_local_cluster(2, 1)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:1/cpu:0",
                                   workers[0].target)

  @test_util.run_v1_only("b/120545219")
  def testCaptureHashTableInSharedIterator(self):
    """A shared iterator capturing a hash table survives across sessions.

    The iterator is created with `shared_name="shared_iterator"`, so a
    second session against the same worker resumes consumption where the
    first session stopped.
    """
    worker, _ = test_util.create_local_cluster(1, 1)

    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        shared_name="shared_table")

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    # Tokenize each sentence and map tokens through the shared table;
    # unknown tokens (e.g. "tank") map to `default_val`.
    iterator = (
        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
            table.lookup)
        .make_initializable_iterator(shared_name="shared_iterator"))
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(table.initializer)
      sess.run(init_op)
      self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

    with session.Session(worker[0].target) as sess:
      # Because the iterator is shared, this new session continues with the
      # second sentence rather than restarting.
      self.assertAllEqual([2, 0], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @test_util.run_v1_only("b/120545219")
  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(None) -> PrefetchDataset(100).
    worker, _ = test_util.create_local_cluster(1, 1)

    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      # Element-wise square of each component.
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    # repeat(None) makes the pipeline infinite, so it can never be exhausted;
    # disposal must happen implicitly when the session/graph goes away.
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(None).prefetch(10000))

    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(init_op)
      # Pull only a few elements; the prefetch buffer keeps filling in the
      # background and must be torn down cleanly on session exit.
      for _ in range(3):
        sess.run(get_next)
179
180
if __name__ == "__main__":
  # Run the test suite when this file is executed as a script.
  test.main()
183