# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for TPU InfeedQueue methods."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_feed


class InfeedTest(test.TestCase):
  """Tests construction, modification and freezing of InfeedQueue."""

  def testConstructor(self):
    """Tests that the constructor can be called with different arguments."""
    # The number of tuple elements may be given explicitly...
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertIsNone(i.tuple_types)
    self.assertIsNone(i.tuple_shapes)
    self.assertIsNone(i.number_of_shards)

    # ...or taken from the length of tuple_types...
    i = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertEqual(i.tuple_types,
                     [dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertIsNone(i.tuple_shapes)
    self.assertIsNone(i.number_of_shards)

    # ...or from the length of tuple_shapes...
    i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertIsNone(i.tuple_types)
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    self.assertIsNone(i.number_of_shards)

    # ...or from the length of shard_dimensions, which are stored on the
    # per-element sharding policies.
    i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertIsNone(i.tuple_types)
    self.assertIsNone(i.tuple_shapes)
    # Consistent with the cases above: no shard count has been configured.
    self.assertIsNone(i.number_of_shards)
    self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                     [1, 0, 7])

    # With no arguments at all the number of tuple elements is unknown.
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue()
    # Arguments whose lengths disagree with each other are rejected.
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(
          number_of_tuple_elements=2, tuple_types=[dtypes.float32])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])

  def testModification(self):
    """Tests modification of the queue post-construction."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)

    # Tuple types can be replaced, but only with a list of matching length.
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    i.set_tuple_types([dtypes.float32, dtypes.float32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.float32])

    # Same for tuple shapes.
    i.set_tuple_shapes([[1], [2, 3]])
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    i.set_tuple_shapes([[1, 2], [3, 4]])
    self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[1, 2]])

    i.set_number_of_shards(2)
    self.assertEqual(i.number_of_shards, 2)
    i.set_number_of_shards(3)
    self.assertEqual(i.number_of_shards, 3)

    # Shapes and types can be read directly off unsharded input tensors.
    t1 = constant_op.constant(1, dtypes.int32, shape=[6])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
    i.set_configuration_from_input_tensors([t1, t2])
    self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
    self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])

    # With per-shard inputs, the shard count is the number of shard lists and
    # the reported shapes are the per-shard shapes combined along dimension 0:
    # [3, 18] x 2 -> [6, 18] and [6] x 2 -> [12].
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])

    # Shard element 0 along dimension 1 (size 18) and element 1 along
    # dimension 0 (size 12).
    i.set_shard_dimensions([1, 0])
    self.assertEqual(i.shard_dimensions, [1, 0])
    i.set_number_of_shards(3)
    self.assertEqual(i.number_of_shards, 3)
    # NOTE(review): presumably rejected because 18 is not divisible by 4;
    # confirm against InfeedQueue's validation logic.
    with self.assertRaises(ValueError):
      i.set_number_of_shards(4)

  def testFreezing(self):
    """Tests freezing the queue."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    t1 = constant_op.constant(1, dtypes.int32, shape=[2])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])

    i.freeze()

    # After freezing, re-setting the current configuration is accepted...
    i.set_number_of_shards(2)
    i.set_tuple_shapes([[4, 4], [4]])
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    i.set_shard_dimensions([0, 0])

    # ...but any attempt to change it raises.
    with self.assertRaises(ValueError):
      i.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[8, 8], [8]])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.int32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_shard_dimensions([1, 0])

    # The rejected changes must not have altered the frozen configuration.
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])


if __name__ == '__main__':
  test.main()