# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.common.public_api."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import googletest
from tensorflow.tools.common import public_api


class PublicApiTest(googletest.TestCase):
  """Exercises `public_api.PublicAPIVisitor` through a recording visitor.

  Each test wraps a `TestVisitor` in a `PublicAPIVisitor`, invokes the wrapper
  once with a `(path, parent, children)` triple, and then inspects both what
  the wrapped visitor received and what is left in the caller's `children`
  list afterwards.
  """

  class TestVisitor(object):
    """Callable that records the arguments it is invoked with.

    Attributes are documented inline in `__init__`.
    """

    def __init__(self):
      # Every `path` this visitor has been called with.
      self.symbols = set()
      # `parent` argument of the most recent call (None before any call).
      self.last_parent = None
      # Copy of the `children` argument of the most recent call (None before
      # any call).
      self.last_children = None

    def __call__(self, path, parent, children):
      """Records `path`, `parent`, and a snapshot of `children`."""
      self.symbols.add(path)
      self.last_parent = parent
      # Store a copy rather than the list itself: the tests below rely on
      # seeing the contents as they were at call time, even if the caller's
      # list is modified afterwards.
      self.last_children = list(children)

  def test_call_forward(self):
    """Path, parent, and all-public children reach the wrapped visitor."""
    recorder = self.TestVisitor()
    child_list = [('name1', 'thing1'), ('name2', 'thing2')]
    wrapped = public_api.PublicAPIVisitor(recorder)
    wrapped('test', 'dummy', child_list)

    self.assertEqual({'test'}, recorder.symbols)
    self.assertEqual('dummy', recorder.last_parent)
    expected_children = [('name1', 'thing1'), ('name2', 'thing2')]
    self.assertEqual(expected_children, recorder.last_children)

  def test_private_child_removal(self):
    """An underscore-prefixed child is gone before the visitor is called."""
    recorder = self.TestVisitor()
    child_list = [('name1', 'thing1'), ('_name2', 'thing2')]
    wrapped = public_api.PublicAPIVisitor(recorder)
    wrapped('test', 'dummy', child_list)

    public_only = [('name1', 'thing1')]
    # The wrapped visitor never saw the private symbol ...
    self.assertEqual(public_only, recorder.last_children)
    # ... and it was also removed from the caller's list.
    self.assertEqual(public_only, child_list)

  def test_no_descent_child_removal(self):
    """A 'mock' child is seen by the visitor but removed afterwards."""
    recorder = self.TestVisitor()
    child_list = [('name1', 'thing1'), ('mock', 'thing2')]
    wrapped = public_api.PublicAPIVisitor(recorder)
    wrapped('test', 'dummy', child_list)

    # Not-to-be-descended-into symbols are still present when the wrapped
    # visitor is called ...
    seen_by_visitor = [('name1', 'thing1'), ('mock', 'thing2')]
    self.assertEqual(seen_by_visitor, recorder.last_children)
    # ... and are removed from the caller's list only after that call.
    self.assertEqual([('name1', 'thing1')], child_list)


if __name__ == '__main__':
  googletest.main()