# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_sharding helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding


class ShardingTest(test.TestCase):
  """Unit tests for tpu_sharding.ShardingPolicy."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # A policy frozen with nothing set picks up the module defaults.
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    # Explicitly set values survive freezing unchanged.
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    # Setting only the number of shards still reports the policy as unset;
    # the full form appears once the shard dimension is set as well.
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging a fully-set policy into an unset one copies both values.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    # Merging a policy that sets only the shard dimension overrides that
    # dimension and leaves the number of shards untouched.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # Merging matching values into a frozen policy leaves it unchanged...
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # ...but merging a conflicting number of shards into it raises.
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    # Likewise for a conflicting shard dimension.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    # Dimension 1 (size 9) is split 3 ways.
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # shard_index must lie in [0, number_of_shards).
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    # A shape of unknown rank cannot be sharded.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    # Dimension 1 (size 10) is not divisible by 3 shards. No shard_index is
    # passed so the error comes from the divisibility check rather than from
    # the (already tested) out-of-range shard_index validation.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    # Two [4, 3] shards concatenate along dimension 1 into [4, 6].
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # The number of shapes must equal number_of_shards.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # All shard shapes must agree.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    # A policy frozen with defaults must round-trip a rank-0 shape.
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])


if __name__ == "__main__":
  test.main()