# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow API compatibility tests.

This test ensures all changes to the public API of TensorFlow are intended.

If this test fails, it means a change has been made to the public API.
Backwards incompatible changes are not allowed. When making an intentional
change to the public TF Python API, run the test with the "--update_goldens"
flag set to "True" to regenerate the golden files.
"""

import argparse
import os
import re
import sys

import tensorflow as tf

from google.protobuf import message
from google.protobuf import text_format

from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.tools.api.lib import api_objects_pb2
from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse

# TensorBoard is optional. When it is missing, the TF 2.x summary symbols that
# it provides are excluded from the golden comparison (see the tests below).
# pylint: disable=g-import-not-at-top,unused-import
_TENSORBOARD_AVAILABLE = True
try:
  import tensorboard as _tb
except ImportError:
  _TENSORBOARD_AVAILABLE = False
# pylint: enable=g-import-not-at-top,unused-import

# Parsed command-line flags. Populated by the argparse parser in the
# `__main__` block at the bottom of this file; None until then.
FLAGS = None

# Help text for the argparse flag --update_goldens (boolean, default False).
_UPDATE_GOLDENS_HELP = """
     Update stored golden files if API is updated. WARNING: All API changes
     have to be authorized by TensorFlow leads.
"""

# Help text for the argparse flag --only_test_core_api (boolean, default True).
_ONLY_TEST_CORE_API_HELP = """
    Some TF APIs are being moved outside of the tensorflow/ directory. There is
    no guarantee which versions of these APIs will be present when running this
    test. Therefore, do not error out on API changes in non-core TF code
    if this flag is set.
"""

# Help text for the argparse flag --verbose_diffs (boolean, default True).
_VERBOSE_DIFFS_HELP = """
     If set to true, print line by line diffs on all libraries. If set to
     false, only print which libraries have differences.
"""

# Initialized with _InitPathConstants function below.
_API_GOLDEN_FOLDER_V1 = None
_API_GOLDEN_FOLDER_V2 = None


def _InitPathConstants():
  """Resolves the v1/v2 golden directories into module-level globals.

  Reads FLAGS, so it must be called after the command line has been parsed.
  When --update_goldens is set, the root golden path is resolved through
  symbolic links so that new golden files are written to the real location.
  """
  global _API_GOLDEN_FOLDER_V1
  global _API_GOLDEN_FOLDER_V2
  root_golden_path_v2 = os.path.join(resource_loader.get_data_files_path(),
                                     '..', 'golden', 'v2', 'tensorflow.pbtxt')

  if FLAGS.update_goldens:
    root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
  # Get API directories based on the root golden file. This way
  # we make sure to resolve symbolic links before creating new files.
  _API_GOLDEN_FOLDER_V2 = os.path.dirname(root_golden_path_v2)
  _API_GOLDEN_FOLDER_V1 = os.path.normpath(
      os.path.join(_API_GOLDEN_FOLDER_V2, '..', 'v1'))


_TEST_README_FILE = resource_loader.get_path_to_datafile('README.txt')
_UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
    'API_UPDATE_WARNING.txt')

_NON_CORE_PACKAGES = ['estimator', 'keras']
_V1_APIS_FROM_KERAS = ['layers', 'nn.rnn_cell']
_V2_APIS_FROM_KERAS = ['initializers', 'losses', 'metrics', 'optimizers']

# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator


def _KeyToFilePath(key, api_version):
  """From a given key, construct a filepath.

  Filepath will be inside golden folder for api_version. Every uppercase
  letter in the key is written as '-' followed by its lowercase form (e.g.
  'tensorflow.GPUOptions' -> 'tensorflow.-g-p-u-options.pbtxt'), so the
  file names stay unique on case-insensitive file systems.
  `_FileNameToKey` performs the inverse mapping.

  Args:
    key: a string used to determine the file path
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.

  Returns:
    A string of file path to the pbtxt file which describes the public API
  """
  case_insensitive_key = re.sub(
      r'[A-Z]', lambda match: '-' + match.group(0).lower(), key)
  api_folder = (
      _API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1)
  if key.startswith('tensorflow.experimental.numpy'):
    # Jumps up one more level in order to let Copybara find the
    # 'tensorflow/third_party' string to replace
    api_folder = os.path.join(
        api_folder, '..', '..', '..', '..', '../third_party',
        'py', 'numpy', 'tf_numpy_api')
    api_folder = os.path.normpath(api_folder)
  return os.path.join(api_folder, '%s.pbtxt' % case_insensitive_key)


def _FileNameToKey(filename):
  """From a given filename, construct a key we use for api objects.

  Inverse of the case mapping in `_KeyToFilePath`: the directory and
  extension are dropped and every '-x' (x a lowercase letter) becomes 'X'.
  """
  base_filename_without_ext = os.path.splitext(os.path.basename(filename))[0]
  return re.sub(r'-[a-z]', lambda match: match.group(0)[1].upper(),
                base_filename_without_ext)


def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """A Visitor that crashes on subclasses of generated proto classes.

  Args:
    path: the traversal path of `parent`; presumably relative to the `tf`
      module, since it is reported as 'tf.<path>'.
    parent: the object being visited.
    unused_children: ignored.

  Raises:
    NotImplementedError: if `parent` is a class deriving from
      `message.Message` other than through a direct base.
  """
  # If the traversed object is a proto Message class
  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
    return
  if parent is message.Message:
    return
  # Check that it is a direct subclass of Message.
  if message.Message not in parent.__bases__:
    raise NotImplementedError(
        'Object tf.%s is a subclass of a generated proto Message. '
        'They are not yet supported by the API tools.' % path)


def _FilterNonCoreGoldenFiles(golden_file_list):
  """Filter out non-core API pbtxt files."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _NON_CORE_PACKAGES)


def _FilterV1KerasRelatedGoldenFiles(golden_file_list):
  """Filter out v1 golden files for APIs that are provided by Keras."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _V1_APIS_FROM_KERAS)


def _FilterV2KerasRelatedGoldenFiles(golden_file_list):
  """Filter out v2 golden files for APIs that are provided by Keras."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _V2_APIS_FROM_KERAS)


def _FilterGoldenFilesByPrefix(golden_file_list, package_prefixes):
  """Drops golden files whose name starts with 'tensorflow.<prefix>.'.

  Args:
    golden_file_list: iterable of golden file paths.
    package_prefixes: package names relative to `tensorflow`, e.g. 'keras'.

  Returns:
    A new list with the matching files removed, preserving order.
  """
  filtered_package_prefixes = tuple(
      'tensorflow.%s.' % p for p in package_prefixes)
  # os.path.basename handles the platform's native path separator, unlike
  # splitting on '/' by hand.
  return [
      f for f in golden_file_list
      if not os.path.basename(f).startswith(filtered_package_prefixes)
  ]


def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
  """Filter out golden proto dict symbols that should be omitted.

  Args:
    golden_proto_dict: dict mapping API object keys to TFAPIObject protos.
    omit_golden_symbols_map: dict mapping an API object key to a list of
      member names to drop from that object's `member` and `member_method`
      fields. Keys without a golden entry are ignored.

  Returns:
    A new dict. Entries that are filtered are copies, so the protos in
    `golden_proto_dict` are never mutated.
  """
  if not omit_golden_symbols_map:
    return golden_proto_dict
  filtered_proto_dict = dict(golden_proto_dict)
  for key, symbol_list in omit_golden_symbols_map.items():
    if key not in filtered_proto_dict:
      # Nothing to omit (e.g. the golden file was filtered out earlier).
      continue
    api_object = api_objects_pb2.TFAPIObject()
    api_object.CopyFrom(filtered_proto_dict[key])
    filtered_proto_dict[key] = api_object
    module_or_class = None
    if api_object.HasField('tf_module'):
      module_or_class = api_object.tf_module
    elif api_object.HasField('tf_class'):
      module_or_class = api_object.tf_class
    if module_or_class is not None:
      for members in (module_or_class.member, module_or_class.member_method):
        filtered_members = [m for m in members if m.name not in symbol_list]
        # Two steps because protobuf repeated fields disallow slice assignment.
        del members[:]
        members.extend(filtered_members)
  return filtered_proto_dict


def _GetTFNumpyGoldenPattern(api_version):
  """Returns the glob pattern matching the tf.experimental.numpy goldens."""
  return os.path.join(resource_loader.get_root_dir_with_all_resources(),
                      _KeyToFilePath('tensorflow.experimental.numpy*',
                                     api_version))


class ApiCompatibilityTest(test.TestCase):
  """Compares the live TF Python API against the stored golden files."""

  def __init__(self, *args, **kwargs):
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # Keys present on both sides whose protos differ; populated below.
    updated_keys = []

    # Do not truncate proto diffs in assertion messages.
    self.maxDiff = None  # pylint: disable=invalid-name

    # Sorted so that the reported differences come out in a stable order.
    for key in sorted(all_keys):
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
      diff_count = len(diffs)
      logging.error(self._test_readme_message)
      logging.error('%d differences found between API and golden.', diff_count)

      if update_goldens:
        # Write files if requested.
        logging.warning(self._update_golden_warning)

        # If the keys are only in expected, some objects are deleted.
        # Remove files.
        for key in only_in_expected:
          filepath = _KeyToFilePath(key, api_version)
          file_io.delete_file(filepath)

        # If the files are only in actual (current library), these are new
        # modules. Write them to files. Also record all updates in files.
        for key in only_in_actual | set(updated_keys):
          filepath = _KeyToFilePath(key, api_version)
          file_io.write_string_to_file(
              filepath, text_format.MessageToString(actual_dict[key]))
      else:
        # Include the actual differences to help debugging. The full proto
        # diff is only logged in verbose mode; for added/removed objects it is
        # identical to the summary line, so it is not repeated.
        for d, verbose_d in zip(diffs, verbose_diffs):
          logging.error('    %s', d)
          if verbose and verbose_d != d:
            logging.error('    %s', verbose_d)
        # Fail if we cannot fix the test by updating goldens.
        self.fail('%d differences found between API and golden.' % diff_count)

    else:
      logging.info('No differences found between API and golden.')

  def testNoSubclassOfMessage(self):
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # visitor.do_not_descend_map['tf'].append('keras')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

  def testNoSubclassOfMessageV1(self):
    if not hasattr(tf.compat, 'v1'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v1, visitor)

  def testNoSubclassOfMessageV2(self):
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v2, visitor)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_patterns,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    """Traverses `root`, reads the goldens and asserts they match.

    Args:
      root: the module to traverse (e.g. `tf` or `tf.compat.v1`).
      golden_file_patterns: a glob pattern or list of patterns for goldens.
      api_version: TensorFlow API version being tested (1 or 2).
      additional_private_map: extra entries merged into the visitor's
        private_map.
      omit_golden_symbols_map: passed to `_FilterGoldenProtoDict`.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'].append('contrib')
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # Do not descend into these numpy classes because their signatures may be
    # different between internal and OSS.
    public_api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
        'bool_', 'complex_', 'complex128', 'complex64', 'float_', 'float16',
        'float32', 'float64', 'inexact', 'int_', 'int16', 'int32', 'int64',
        'int8', 'object_', 'string_', 'uint16', 'uint32', 'uint64', 'uint8',
        'unicode_', 'iinfo']
    public_api_visitor.do_not_descend_map['tf'].append('keras')
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
      if api_version == 2:
        public_api_visitor.do_not_descend_map['tf'].extend(_V2_APIS_FROM_KERAS)
      else:
        public_api_visitor.do_not_descend_map['tf'].extend(['layers'])
        public_api_visitor.do_not_descend_map['tf.nn'] = ['rnn_cell']
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_patterns)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
      if api_version == 2:
        golden_file_list = _FilterV2KerasRelatedGoldenFiles(golden_file_list)
      else:
        golden_file_list = _FilterV1KerasRelatedGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  def testAPIBackwardsCompatibility(self):
    api_version = 1
    if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
      api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if (api_version == 2 and FLAGS.only_test_core_api and
        not _TENSORBOARD_AVAILABLE):
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]

    self._checkBackwardsCompatibility(
        tf,
        golden_file_patterns,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)

    # Check that V2 API does not have contrib
    self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))

  def testAPIBackwardsCompatibilityV1(self):
    api_version = 1
    golden_file_patterns = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf.compat.v1,
        golden_file_patterns,
        api_version,
        additional_private_map={
            'tf': ['pywrap_tensorflow'],
            'tf.compat': ['v1', 'v2'],
        },
        omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']})

  def testAPIBackwardsCompatibilityV2(self):
    api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_patterns,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)


def _ParseBoolFlag(value):
  """Parses a boolean command-line flag value.

  argparse's `type=bool` calls bool() on the raw string, so any non-empty
  value, including 'False' and '0', would become True. Parse the common
  spellings explicitly instead.

  Args:
    value: the raw flag string (or an already-parsed bool).

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: if `value` is not a recognized boolean.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', 'off', '0'):
    return False
  raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--update_goldens',
      type=_ParseBoolFlag,
      default=False,
      help=_UPDATE_GOLDENS_HELP)
  # TODO(mikecase): Create Estimator's own API compatibility test or
  # a more general API compatibility test for use for TF components.
  parser.add_argument(
      '--only_test_core_api',
      type=_ParseBoolFlag,
      default=True,  # only_test_core_api default value
      help=_ONLY_TEST_CORE_API_HELP)
  parser.add_argument(
      '--verbose_diffs',
      type=_ParseBoolFlag,
      default=True,
      help=_VERBOSE_DIFFS_HELP)
  FLAGS, unparsed = parser.parse_known_args()
  _InitPathConstants()

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()