# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf upgrader."""

import io
import os
import tempfile

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade


class TestUpgrade(test_util.TensorFlowTestCase):
  """Test various APIs that have been changed in 1.0.

  We also test whether a converted file is executable. test_file_v0_11.py
  aims to exhaustively test that API changes are convertible and actually
  work when run with current TensorFlow.
  """

  def _upgrade(self, old_file_text):
    """Runs the upgrader over a source string held in memory.

    Args:
      old_file_text: Source code to upgrade, as a string.

    Returns:
      A `(count, report, errors, new_text)` tuple. The first three items are
      passed through unchanged from `process_opened_file`; `new_text` is
      everything the upgrader wrote to the in-memory output file.
    """
    in_file = io.StringIO(old_file_text)
    out_file = io.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return count, report, errors, out_file.getvalue()

  def testParseError(self):
    """Unparsable input is mentioned in the report."""
    _, report, unused_errors, unused_new_text = self._upgrade(
        "import tensorflow as tf\na + \n")
    self.assertIn("Failed to parse", report)

  def testReport(self):
    """A rename shows up in the generated report."""
    text = "tf.mul(a, b)\n"
    _, report, unused_errors, unused_new_text = self._upgrade(text)
    # This is not a complete test, but it is a sanity test that a report
    # is generating information.
    #
    # The previous `assertTrue(report.find(...))` could not detect a missing
    # entry: `str.find` returns -1 (truthy) when the substring is absent.
    # NOTE(review): the exact wording and quoting of the rename message is
    # owned by ast_edits, so the pattern tolerates any text around the two
    # names -- confirm against the current log format.
    self.assertRegex(report, r"Renamed .*tf\.mul\b.* to .*tf\.multiply")

  def testRename(self):
    """Renamed functions are rewritten, including nested calls."""
    text = "tf.mul(a, tf.sub(b, c))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")

  def testRenamePack(self):
    """tf.pack / tf.unpack become tf.stack / tf.unstack."""
    text = "tf.pack(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.stack(a)\n")
    text = "tf.unpack(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.unstack(a)\n")

  def testReorder(self):
    """Positional arguments of tf.concat and tf.split gain keywords."""
    text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
                     "tf.split(axis=a, num_or_size_splits=b, value=c)\n")

  def testConcatReorderWithKeywordArgs(self):
    """tf.concat keyword and mixed positional/keyword calls are upgraded."""
    text = "tf.concat(concat_dim=a, values=b)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
    text = "tf.concat(values=b, concat_dim=a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
    text = "tf.concat(a, values=b)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")

  def testConcatReorderNested(self):
    """A tf.concat nested inside another tf.concat is upgraded too."""
    text = "tf.concat(a, tf.concat(c, d))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")

  def testInitializers(self):
    """Bare initializer references gain a call; existing calls are kept."""
    text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
            "tf.ones_initializer;tf.ones_initializer ()\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
        "tf.ones_initializer();tf.ones_initializer ()\n")

  def testKeyword(self):
    """The `reduction_indices` keyword is renamed to `axis`."""
    text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.reduce_any(a, axis=[1, 2])\n")

  def testComplexExpression(self):
    """Upgrading a non-trivial expression completes without raising."""
    text = "(foo + bar)[a].word()"
    self._upgrade(text)

  def testReverse(self):
    """tf.reverse is flagged for manual review."""
    text = "tf.reverse(a, b)\n"
    _, unused_report, errors, new_text = self._upgrade(text)
    # The original assertion compared `new_text` with itself and so could
    # never fail.
    # NOTE(review): this assumes the upgrader only warns about tf.reverse and
    # leaves the call text untouched -- confirm against tf_upgrade's spec.
    self.assertEqual(new_text, text)
    self.assertIn("tf.reverse requires manual check", errors[0])

  def testListComprehension(self):
    """Whitespace around a list-comprehension argument is preserved."""

    def _test(old_text, expected):
      _, unused_report, unused_errors, new_text = self._upgrade(old_text)
      self.assertEqual(new_text, expected)

    _test("tf.concat(0,  \t[x for x in y])\n",
          "tf.concat(axis=0,  \tvalues=[x for x in y])\n")
    _test("tf.concat(0,[x for x in y])\n",
          "tf.concat(axis=0,values=[x for x in y])\n")
    _test("tf.concat(0,[\nx for x in y])\n",
          "tf.concat(axis=0,values=[\nx for x in y])\n")
    _test("tf.concat(0,[\n \tx for x in y])\n",
          "tf.concat(axis=0,values=[\n \tx for x in y])\n")

  # TODO(aselle): Explicitly not testing command line interface and process_tree
  # for now, since this is a one off utility.


class TestUpgradeFiles(test_util.TensorFlowTestCase):
  """Tests that run the upgrader against real files on disk."""

  def testInplace(self):
    """Check to make sure we don't have a file system race."""
    original = "tf.mul(a, b)\n"
    upgraded = "tf.multiply(a, b)\n"
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    # Register removal immediately so the file is cleaned up even when the
    # upgrade or the assertion below fails.
    self.addCleanup(os.unlink, temp_file.name)
    with temp_file:
      temp_file.write(original)
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
    # Same path for input and output: the upgrade happens in place.
    upgrader.process_file(temp_file.name, temp_file.name)
    with open(temp_file.name) as upgraded_file:
      self.assertEqual(upgraded_file.read(), upgraded)


if __name__ == "__main__":
  test_lib.main()