# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tpu_sharding helpers."""

from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding


class ShardingTest(test.TestCase):
  """Unit tests for `tpu_sharding.ShardingPolicy`."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # An untouched policy picks up the module defaults when frozen.
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension,
                     tpu_sharding._DEFAULT_SHARD_DIMENSION)

    # Explicitly set values survive freezing.
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    # Setting only the number of shards still reports the policy as unset.
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging into an unset policy copies both fields.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

    # A source that only sets shard_dimension changes just that field.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)

    # Merging the same source again after freezing leaves the values as is.
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)

    # A different number of shards cannot be merged into the frozen policy.
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)

    # A matching number of shards merges; a different shard dimension raises.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # shard_index values outside [0, 3).
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    # 10 is not a multiple of the 3 shards. No shard_index is passed so that
    # the ValueError cannot come from the out-of-range index covered above.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10])

  def testGetUnpartitionedShape(self):
    """Tests getting an unpartitioned shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    p.set_number_of_partitions(4)
    # Dimension 1 is scaled by the 4 partitions: 5 * 4 == 20.
    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
    p.freeze()
    # An unknown size in the shard dimension is rejected.
    with self.assertRaises(ValueError):
      _ = p.get_unpartitioned_shape([3, None])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # One shape, then three shapes, for a two-shard policy.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # Per-shard shapes that differ from each other.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    # An entry that is not convertible to a shape.
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    # An entry that is None.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    # A rank-1 entry, which has no dimension 1 to unshard along.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])


if __name__ == "__main__":
  test.main()