# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tensorflow::FunctionParameterCanonicalizer`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.util import _function_parameter_canonicalizer_binding_for_test


class FunctionParameterCanonicalizerTest(test.TestCase):
  """Drives the canonicalizer binding with a matmul-shaped parameter list."""

  def setUp(self):
    super(FunctionParameterCanonicalizerTest, self).setUp()
    # Parameter names in declaration order.
    arg_names = [
        'a', 'b', 'transpose_a', 'transpose_b', 'adjoint_a', 'adjoint_b',
        'a_is_sparse', 'b_is_sparse', 'name'
    ]
    # Defaults line up with the trailing seven parameters: the six flags get
    # False and `name` gets None, so `a` and `b` have no default.
    arg_defaults = (False,) * 6 + (None,)
    binding = _function_parameter_canonicalizer_binding_for_test
    self._matmul_func = binding.FunctionParameterCanonicalizer(
        arg_names, arg_defaults)

  def testPosOnly(self):
    """Only the two undefaulted positionals are given; the rest default."""
    canonical = self._matmul_func.canonicalize(2, 3)
    self.assertEqual(canonical, [2, 3] + [False] * 6 + [None])

  def testPosOnly2(self):
    """Leading flags passed positionally override their defaults."""
    canonical = self._matmul_func.canonicalize(2, 3, True, False, True)
    self.assertEqual(canonical, [2, 3, True, False, True] + [False] * 3 + [None])

  def testPosAndKwd(self):
    """Keyword arguments land in their declared slots."""
    canonical = self._matmul_func.canonicalize(
        2, 3, transpose_a=True, name='my_matmul')
    self.assertEqual(canonical, [2, 3, True] + [False] * 5 + ['my_matmul'])

  def testPosAndKwd2(self):
    """A required parameter may also be supplied by keyword."""
    canonical = self._matmul_func.canonicalize(2, b=3)
    self.assertEqual(canonical, [2, 3] + [False] * 6 + [None])

  def testMissingPos(self):
    """Leaving out `b`, which has no default, is rejected."""
    pattern = 'Missing required positional argument'
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(2)

  def testMissingPos2(self):
    """Passing only defaulted keywords leaves `a` and `b` unset."""
    pattern = 'Missing required positional argument'
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(
          transpose_a=True, transpose_b=True, adjoint_a=True)

  def testTooManyArgs(self):
    """Ten positionals against nine declared parameters is an error."""
    pattern = 'Too many arguments were given'
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(*range(1, 11))

  def testInvalidKwd(self):
    """A keyword that matches no declared parameter is rejected."""
    pattern = 'Got an unexpected keyword argument'
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(2, 3, hohoho=True)

  def testDuplicatedArg(self):
    """`b` given both positionally and by keyword is rejected."""
    pattern = "Got multiple values for argument 'b'"
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(2, 3, False, b=4)

  def testDuplicatedArg2(self):
    """`transpose_a` given both positionally and by keyword is rejected."""
    pattern = "Got multiple values for argument 'transpose_a'"
    with self.assertRaisesRegex(TypeError, pattern):
      self._matmul_func.canonicalize(2, 3, False, transpose_a=True)


if __name__ == '__main__':
  test.main()