1# Lint as: python2, python3 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for tf upgrader.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21import os 22import tempfile 23 24import six 25from tensorflow.python.framework import test_util 26from tensorflow.python.platform import test as test_lib 27from tensorflow.tools.compatibility import ast_edits 28from tensorflow.tools.compatibility import tf_upgrade 29 30 31class TestUpgrade(test_util.TensorFlowTestCase): 32 """Test various APIs that have been changed in 1.0. 33 34 We also test whether a converted file is executable. test_file_v0_11.py 35 aims to exhaustively test that API changes are convertible and actually 36 work when run with current TensorFlow. 37 """ 38 39 def _upgrade(self, old_file_text): 40 in_file = six.StringIO(old_file_text) 41 out_file = six.StringIO() 42 upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) 43 count, report, errors = ( 44 upgrader.process_opened_file("test.py", in_file, 45 "test_out.py", out_file)) 46 return count, report, errors, out_file.getvalue() 47 48 def testParseError(self): 49 _, report, unused_errors, unused_new_text = self._upgrade( 50 "import tensorflow as tf\na + \n") 51 self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1) 52 53 def testReport(self): 54 text = "tf.mul(a, b)\n" 55 _, report, unused_errors, unused_new_text = self._upgrade(text) 56 # This is not a complete test, but it is a sanity test that a report 57 # is generating information. 58 self.assertTrue( 59 six.ensure_str(report).find( 60 "Renamed function `tf.mul` to `tf.multiply`")) 61 62 def testRename(self): 63 text = "tf.mul(a, tf.sub(b, c))\n" 64 _, unused_report, unused_errors, new_text = self._upgrade(text) 65 self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n") 66 67 def testRenamePack(self): 68 text = "tf.pack(a)\n" 69 _, unused_report, unused_errors, new_text = self._upgrade(text) 70 self.assertEqual(new_text, "tf.stack(a)\n") 71 text = "tf.unpack(a)\n" 72 _, unused_report, unused_errors, new_text = self._upgrade(text) 73 self.assertEqual(new_text, "tf.unstack(a)\n") 74 75 def testReorder(self): 76 text = "tf.concat(a, b)\ntf.split(a, b, c)\n" 77 _, unused_report, unused_errors, new_text = self._upgrade(text) 78 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n" 79 "tf.split(axis=a, num_or_size_splits=b, value=c)\n") 80 81 def testConcatReorderWithKeywordArgs(self): 82 text = "tf.concat(concat_dim=a, values=b)\n" 83 _, unused_report, unused_errors, new_text = self._upgrade(text) 84 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") 85 text = "tf.concat(values=b, concat_dim=a)\n" 86 _, unused_report, unused_errors, new_text = self._upgrade(text) 87 self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n") 88 text = "tf.concat(a, values=b)\n" 89 _, unused_report, unused_errors, new_text = self._upgrade(text) 90 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") 91 92 def testConcatReorderNested(self): 93 text = "tf.concat(a, tf.concat(c, d))\n" 94 _, unused_report, unused_errors, new_text = self._upgrade(text) 95 self.assertEqual( 96 new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n") 97 98 def testInitializers(self): 99 text = ("tf.zeros_initializer;tf.zeros_initializer ()\n" 100 "tf.ones_initializer;tf.ones_initializer ()\n") 101 _, unused_report, unused_errors, new_text = self._upgrade(text) 102 self.assertEqual( 103 new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n" 104 "tf.ones_initializer();tf.ones_initializer ()\n") 105 106 def testKeyword(self): 107 text = "tf.reduce_any(a, reduction_indices=[1, 2])\n" 108 _, unused_report, unused_errors, new_text = self._upgrade(text) 109 self.assertEqual(new_text, "tf.reduce_any(a, axis=[1, 2])\n") 110 111 def testComplexExpression(self): 112 text = "(foo + bar)[a].word()" 113 _ = self._upgrade(text) 114 115 def testReverse(self): 116 text = "tf.reverse(a, b)\n" 117 _, unused_report, errors, new_text = self._upgrade(text) 118 self.assertEqual(new_text, new_text) 119 self.assertIn("tf.reverse requires manual check", errors[0]) 120 121 def testListComprehension(self): 122 def _test(input, output): # pylint: disable=redefined-builtin 123 _, unused_report, errors, new_text = self._upgrade(input) 124 self.assertEqual(new_text, output) 125 _test("tf.concat(0, \t[x for x in y])\n", 126 "tf.concat(axis=0, \tvalues=[x for x in y])\n") 127 _test("tf.concat(0,[x for x in y])\n", 128 "tf.concat(axis=0,values=[x for x in y])\n") 129 _test("tf.concat(0,[\nx for x in y])\n", 130 "tf.concat(axis=0,values=[\nx for x in y])\n") 131 _test("tf.concat(0,[\n \tx for x in y])\n", 132 "tf.concat(axis=0,values=[\n \tx for x in y])\n") 133 134 # TODO(aselle): Explicitly not testing command line interface and process_tree 135 # for now, since this is a one off utility. 136 137 138class TestUpgradeFiles(test_util.TensorFlowTestCase): 139 140 def testInplace(self): 141 """Check to make sure we don't have a file system race.""" 142 temp_file = tempfile.NamedTemporaryFile("w", delete=False) 143 original = "tf.mul(a, b)\n" 144 upgraded = "tf.multiply(a, b)\n" 145 temp_file.write(original) 146 temp_file.close() 147 upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) 148 upgrader.process_file(temp_file.name, temp_file.name) 149 self.assertAllEqual(open(temp_file.name).read(), upgraded) 150 os.unlink(temp_file.name) 151 152 153if __name__ == "__main__": 154 test_lib.main() 155