1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for conversion module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import imp 22import sys 23import types 24import weakref 25 26import six 27 28from tensorflow.python.autograph import utils 29from tensorflow.python.autograph.core import config 30from tensorflow.python.autograph.core import converter 31from tensorflow.python.autograph.impl import api 32from tensorflow.python.autograph.impl import conversion 33from tensorflow.python.autograph.impl.testing import pybind_for_testing 34from tensorflow.python.eager import function 35from tensorflow.python.framework import constant_op 36from tensorflow.python.platform import test 37 38 39class ConversionTest(test.TestCase): 40 41 def _simple_program_ctx(self): 42 return converter.ProgramContext( 43 options=converter.ConversionOptions(recursive=True), 44 autograph_module=api) 45 46 def test_is_allowlisted(self): 47 48 def test_fn(): 49 return constant_op.constant(1) 50 51 self.assertFalse(conversion.is_allowlisted(test_fn)) 52 self.assertTrue(conversion.is_allowlisted(utils)) 53 self.assertTrue(conversion.is_allowlisted(constant_op.constant)) 54 55 def test_is_allowlisted_tensorflow_like(self): 56 57 tf_like = imp.new_module('tensorflow_foo') 58 def test_fn(): 59 pass 60 tf_like.test_fn = test_fn 61 test_fn.__module__ = tf_like 62 63 self.assertFalse(conversion.is_allowlisted(tf_like.test_fn)) 64 65 def test_is_allowlisted_callable_allowlisted_call(self): 66 67 allowlisted_mod = imp.new_module('test_allowlisted_call') 68 sys.modules['test_allowlisted_call'] = allowlisted_mod 69 config.CONVERSION_RULES = ((config.DoNotConvert('test_allowlisted_call'),) + 70 config.CONVERSION_RULES) 71 72 class TestClass(object): 73 74 def __call__(self): 75 pass 76 77 def allowlisted_method(self): 78 pass 79 80 TestClass.__module__ = 'test_allowlisted_call' 81 if six.PY2: 82 TestClass.__call__.__func__.__module__ = 'test_allowlisted_call' 83 else: 84 TestClass.__call__.__module__ = 'test_allowlisted_call' 85 86 class Subclass(TestClass): 87 88 def converted_method(self): 89 pass 90 91 tc = Subclass() 92 93 self.assertTrue(conversion.is_allowlisted(TestClass.__call__)) 94 self.assertTrue(conversion.is_allowlisted(tc)) 95 self.assertTrue(conversion.is_allowlisted(tc.__call__)) 96 self.assertTrue(conversion.is_allowlisted(tc.allowlisted_method)) 97 self.assertFalse(conversion.is_allowlisted(Subclass)) 98 self.assertFalse(conversion.is_allowlisted(tc.converted_method)) 99 100 def test_is_allowlisted_tfmethodwrapper(self): 101 102 class TestClass(object): 103 104 def member_function(self): 105 pass 106 107 TestClass.__module__ = 'test_allowlisted_call' 108 test_obj = TestClass() 109 110 def test_fn(self): 111 del self 112 113 bound_method = types.MethodType( 114 test_fn, 115 function.TfMethodTarget( 116 weakref.ref(test_obj), test_obj.member_function)) 117 118 self.assertTrue(conversion.is_allowlisted(bound_method)) 119 120 def test_is_allowlisted_pybind(self): 121 test_object = pybind_for_testing.TestClassDef() 122 with test.mock.patch.object(config, 'CONVERSION_RULES', ()): 123 # TODO(mdan): This should return True for functions and methods. 124 # Note: currently, native bindings are allowlisted by a separate check. 125 self.assertFalse(conversion.is_allowlisted(test_object.method)) 126 127 128if __name__ == '__main__': 129 test.main() 130