• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for conversion module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import imp
22import sys
23import types
24import weakref
25
26import six
27
28from tensorflow.python.autograph import utils
29from tensorflow.python.autograph.core import config
30from tensorflow.python.autograph.core import converter
31from tensorflow.python.autograph.impl import api
32from tensorflow.python.autograph.impl import conversion
33from tensorflow.python.autograph.impl.testing import pybind_for_testing
34from tensorflow.python.eager import function
35from tensorflow.python.framework import constant_op
36from tensorflow.python.platform import test
37
38
class ConversionTest(test.TestCase):
  """Tests for `conversion.is_allowlisted`."""

  def _simple_program_ctx(self):
    """Returns a ProgramContext with recursive conversion backed by `api`."""
    return converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=api)

  def _register_module(self, name):
    """Creates an empty module and registers it in `sys.modules`.

    The registration is removed when the current test finishes, so fake
    modules do not leak into other tests. `name` should not clash with a real
    module: an entry that existed before the call is not restored.

    Args:
      name: str, the name under which to register the module.

    Returns:
      The newly created `types.ModuleType` instance.
    """
    module = types.ModuleType(name)
    sys.modules[name] = module
    self.addCleanup(sys.modules.pop, name, None)
    return module

  def _allowlist_module(self, name):
    """Registers module `name` and prepends a DoNotConvert rule for it.

    `config.CONVERSION_RULES` is patched rather than reassigned, so the
    original rules are restored when the current test finishes.

    Args:
      name: str, the module name to register and allowlist.

    Returns:
      The newly created `types.ModuleType` instance.
    """
    module = self._register_module(name)
    rules = (config.DoNotConvert(name),) + config.CONVERSION_RULES
    patcher = test.mock.patch.object(config, 'CONVERSION_RULES', rules)
    patcher.start()
    self.addCleanup(patcher.stop)
    return module

  def test_is_allowlisted(self):

    def test_fn():
      return constant_op.constant(1)

    self.assertFalse(conversion.is_allowlisted(test_fn))
    self.assertTrue(conversion.is_allowlisted(utils))
    self.assertTrue(conversion.is_allowlisted(constant_op.constant))

  def test_is_allowlisted_tensorflow_like(self):
    # A module whose name merely starts with "tensorflow" must not be treated
    # like the real `tensorflow` package.
    tf_like = self._register_module('tensorflow_foo')

    def test_fn():
      pass

    tf_like.test_fn = test_fn
    # `__module__` must hold the module *name*. The module lookup presumably
    # resolves it by name through sys.modules (inspect.getmodule-style). If the
    # module object itself is assigned, that lookup fails and the conversion
    # rules are never consulted.
    test_fn.__module__ = tf_like.__name__

    self.assertFalse(conversion.is_allowlisted(tf_like.test_fn))

  def test_is_allowlisted_callable_allowlisted_call(self):
    self._allowlist_module('test_allowlisted_call')

    class TestClass(object):

      def __call__(self):
        pass

      def allowlisted_method(self):
        pass

    TestClass.__module__ = 'test_allowlisted_call'
    if six.PY2:
      TestClass.__call__.__func__.__module__ = 'test_allowlisted_call'
    else:
      TestClass.__call__.__module__ = 'test_allowlisted_call'

    class Subclass(TestClass):

      def converted_method(self):
        pass

    tc = Subclass()

    self.assertTrue(conversion.is_allowlisted(TestClass.__call__))
    self.assertTrue(conversion.is_allowlisted(tc))
    self.assertTrue(conversion.is_allowlisted(tc.__call__))
    self.assertTrue(conversion.is_allowlisted(tc.allowlisted_method))
    self.assertFalse(conversion.is_allowlisted(Subclass))
    self.assertFalse(conversion.is_allowlisted(tc.converted_method))

  def test_is_allowlisted_tfmethodwrapper(self):
    # Install the allowlisted module here instead of relying on state left
    # behind by another test, so this test also passes when run on its own.
    self._allowlist_module('test_allowlisted_call')

    class TestClass(object):

      def member_function(self):
        pass

    TestClass.__module__ = 'test_allowlisted_call'
    test_obj = TestClass()

    def test_fn(self):
      del self

    # Bind test_fn to a TfMethodTarget wrapping test_obj. The expectation is
    # that is_allowlisted resolves the owner through the target to TestClass,
    # which is allowlisted above. NOTE(review): confirm against conversion.
    bound_method = types.MethodType(
        test_fn,
        function.TfMethodTarget(
            weakref.ref(test_obj), test_obj.member_function))

    self.assertTrue(conversion.is_allowlisted(bound_method))

  def test_is_allowlisted_pybind(self):
    test_object = pybind_for_testing.TestClassDef()
    with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
      # TODO(mdan): This should return True for functions and methods.
      # Note: currently, native bindings are allowlisted by a separate check.
      self.assertFalse(conversion.is_allowlisted(test_object.method))
126
127
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when run as a script.
  test.main()
130