# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the tf_upgrade API upgrader."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import six
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade


class TestUpgrade(test_util.TensorFlowTestCase):
  """Test various APIs that have been changed in 1.0.

  We also test whether a converted file is executable. test_file_v0_11.py
  aims to exhaustively test that API changes are convertible and actually
  work when run with current TensorFlow.
  """

  def _upgrade(self, old_file_text):
    """Runs the upgrader over `old_file_text` using in-memory files.

    Args:
      old_file_text: Python source text to upgrade.

    Returns:
      A tuple `(count, report, errors, new_text)`. The first three values are
      passed through unchanged from `ASTCodeUpgrader.process_opened_file`;
      `new_text` is everything the upgrader wrote to the output file.
    """
    in_file = six.StringIO(old_file_text)
    out_file = six.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return count, report, errors, out_file.getvalue()

  def testParseError(self):
    _, report, unused_errors, unused_new_text = self._upgrade(
        "import tensorflow as tf\na + \n")
    self.assertIn("Failed to parse", report)

  def testReport(self):
    text = "tf.mul(a, b)\n"
    _, report, unused_errors, unused_new_text = self._upgrade(text)
    # This is not a complete test, but it is a sanity test that a report
    # is generating information. The exact quoting of the function names in
    # the rename message is deliberately not pinned here.
    # (The previous `assertTrue(report.find(...))` could never fail, because
    # `str.find` returns -1, which is truthy, when the text is absent.)
    self.assertIn("Renamed function", report)
    self.assertIn("tf.multiply", report)

  def testRename(self):
    text = "tf.mul(a, tf.sub(b, c))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")

  def testRenamePack(self):
    text = "tf.pack(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.stack(a)\n")
    text = "tf.unpack(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.unstack(a)\n")

  def testReorder(self):
    text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
                     "tf.split(axis=a, num_or_size_splits=b, value=c)\n")

  def testConcatReorderWithKeywordArgs(self):
    text = "tf.concat(concat_dim=a, values=b)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
    text = "tf.concat(values=b, concat_dim=a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
    text = "tf.concat(a, values=b)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")

  def testConcatReorderNested(self):
    text = "tf.concat(a, tf.concat(c, d))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")

  def testInitializers(self):
    text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
            "tf.ones_initializer;tf.ones_initializer ()\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
        "tf.ones_initializer();tf.ones_initializer ()\n")

  def testKeyword(self):
    text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.reduce_any(a, axis=[1, 2])\n")

  def testComplexExpression(self):
    # Source with no TensorFlow calls must parse and pass through unchanged.
    text = "(foo + bar)[a].word()"
    _, report, unused_errors, new_text = self._upgrade(text)
    self.assertNotIn("Failed to parse", report)
    self.assertEqual(new_text, text)

  def testReverse(self):
    # tf.reverse cannot be converted automatically, so the call is expected to
    # be left as-is and reported as needing a manual check.
    text = "tf.reverse(a, b)\n"
    _, unused_report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, text)
    # Fail with a clear message rather than an IndexError when no error is
    # reported at all.
    self.assertTrue(errors, "expected a manual-check error for tf.reverse")
    self.assertIn("tf.reverse requires manual check", errors[0])

  def testListComprehension(self):
    def _test(input_text, expected_text):
      _, unused_report, unused_errors, new_text = self._upgrade(input_text)
      self.assertEqual(new_text, expected_text)
    _test("tf.concat(0,  \t[x for x in y])\n",
          "tf.concat(axis=0,  \tvalues=[x for x in y])\n")
    _test("tf.concat(0,[x for x in y])\n",
          "tf.concat(axis=0,values=[x for x in y])\n")
    _test("tf.concat(0,[\nx for x in y])\n",
          "tf.concat(axis=0,values=[\nx for x in y])\n")
    _test("tf.concat(0,[\n \tx for x in y])\n",
          "tf.concat(axis=0,values=[\n \tx for x in y])\n")

  # TODO(aselle): Explicitly not testing command line interface and process_tree
  # for now, since this is a one off utility.


class TestUpgradeFiles(test_util.TensorFlowTestCase):
  """Tests for the upgrader's on-disk file processing."""

  def testInplace(self):
    """Check that upgrading a file in place (same input and output path) works.

    This guards against a file system race between reading the input and
    writing the output when both refer to the same file.
    """
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    # Remove the temporary file even if an assertion below fails.
    self.addCleanup(os.unlink, temp_file.name)
    original = "tf.mul(a, b)\n"
    upgraded = "tf.multiply(a, b)\n"
    temp_file.write(original)
    temp_file.close()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
    upgrader.process_file(temp_file.name, temp_file.name)
    with open(temp_file.name) as result_file:
      self.assertEqual(result_file.read(), upgraded)


if __name__ == "__main__":
  test_lib.main()