1# Lint as: python2, python3 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# 16# ============================================================================== 17"""TensorFlow API compatibility tests. 18 19This test ensures all changes to the public API of TensorFlow are intended. 20 21If this test fails, it means a change has been made to the public API. Backwards 22incompatible changes are not allowed. You can run the test with 23"--update_goldens" flag set to "True" to update goldens when making changes to 24the public TF python API. 
25""" 26 27from __future__ import absolute_import 28from __future__ import division 29from __future__ import print_function 30 31import argparse 32import os 33import re 34import sys 35 36import six 37import tensorflow as tf 38 39from google.protobuf import message 40from google.protobuf import text_format 41 42from tensorflow.python.lib.io import file_io 43from tensorflow.python.platform import resource_loader 44from tensorflow.python.platform import test 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.tools.api.lib import api_objects_pb2 47from tensorflow.tools.api.lib import python_object_to_proto_visitor 48from tensorflow.tools.common import public_api 49from tensorflow.tools.common import traverse 50 51# pylint: disable=g-import-not-at-top,unused-import 52_TENSORBOARD_AVAILABLE = True 53try: 54 import tensorboard as _tb 55except ImportError: 56 _TENSORBOARD_AVAILABLE = False 57# pylint: enable=g-import-not-at-top,unused-import 58 59# FLAGS defined at the bottom: 60FLAGS = None 61# DEFINE_boolean, update_goldens, default False: 62_UPDATE_GOLDENS_HELP = """ 63 Update stored golden files if API is updated. WARNING: All API changes 64 have to be authorized by TensorFlow leads. 65""" 66 67# DEFINE_boolean, only_test_core_api, default False: 68_ONLY_TEST_CORE_API_HELP = """ 69 Some TF APIs are being moved outside of the tensorflow/ directory. There is 70 no guarantee which versions of these APIs will be present when running this 71 test. Therefore, do not error out on API changes in non-core TF code 72 if this flag is set. 73""" 74 75# DEFINE_boolean, verbose_diffs, default True: 76_VERBOSE_DIFFS_HELP = """ 77 If set to true, print line by line diffs on all libraries. If set to 78 false, only print which libraries have differences. 79""" 80 81# Initialized with _InitPathConstants function below. 
_API_GOLDEN_FOLDER_V1 = None
_API_GOLDEN_FOLDER_V2 = None


def _InitPathConstants():
  """Computes the v1/v2 golden folders and stores them in module globals.

  Reads FLAGS.update_goldens, so it must be called after the flags have been
  parsed. When goldens are being updated the root golden path is resolved with
  os.path.realpath first, so new files are created next to the real golden
  files rather than next to symbolic links.
  """
  global _API_GOLDEN_FOLDER_V1
  global _API_GOLDEN_FOLDER_V2
  root_golden_path_v2 = os.path.join(resource_loader.get_data_files_path(),
                                     '..', 'golden', 'v2', 'tensorflow.pbtxt')

  if FLAGS.update_goldens:
    root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
  # Get API directories based on the root golden file. This way
  # we make sure to resolve symbolic links before creating new files.
  _API_GOLDEN_FOLDER_V2 = os.path.dirname(root_golden_path_v2)
  _API_GOLDEN_FOLDER_V1 = os.path.normpath(
      os.path.join(_API_GOLDEN_FOLDER_V2, '..', 'v1'))


_TEST_README_FILE = resource_loader.get_path_to_datafile('README.txt')
_UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
    'API_UPDATE_WARNING.txt')

_NON_CORE_PACKAGES = ['estimator', 'keras']
_V1_APIS_FROM_KERAS = ['layers', 'nn.rnn_cell']
_V2_APIS_FROM_KERAS = ['initializers', 'losses', 'metrics', 'optimizers']

# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator


def _KeyToFilePath(key, api_version):
  """From a given key, construct a filepath.

  Filepath will be inside golden folder for api_version. Every upper-case ASCII
  letter in the key is replaced by a dash followed by its lower-case form, so
  that the resulting file names do not depend on letter case.

  Args:
    key: a string used to determine the file path
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.

  Returns:
    A string of file path to the pbtxt file which describes the public API
  """

  def _ReplaceCapsWithDash(matchobj):
    return '-%s' % matchobj.group(0).lower()

  # Normalize once so that both the regex and the prefix check below operate on
  # the same native string.
  key = six.ensure_str(key)
  case_insensitive_key = re.sub(r'[A-Z]', _ReplaceCapsWithDash, key)
  api_folder = (
      _API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1)
  if key.startswith('tensorflow.experimental.numpy'):
    # Jumps up one more level in order to let Copybara find the
    # 'tensorflow/third_party' string to replace
    api_folder = os.path.join(
        api_folder, '..', '..', '..', '..', '../third_party',
        'py', 'numpy', 'tf_numpy_api')
    api_folder = os.path.normpath(api_folder)
  return os.path.join(api_folder, '%s.pbtxt' % case_insensitive_key)


def _FileNameToKey(filename):
  """From a given filename, construct a key we use for api objects.

  This is the inverse of the renaming done by _KeyToFilePath: the directory and
  extension are dropped, and every dash followed by a lower-case ASCII letter
  is replaced by the upper-case form of that letter.

  Args:
    filename: path to a golden pbtxt file.

  Returns:
    The api object key as a string.
  """

  def _ReplaceDashWithCaps(matchobj):
    # The match is '-x'; keep only the letter, upper-cased.
    return matchobj.group(0)[1].upper()

  base_filename = os.path.basename(filename)
  base_filename_without_ext = os.path.splitext(base_filename)[0]
  return re.sub(r'-[a-z]', _ReplaceDashWithCaps,
                six.ensure_str(base_filename_without_ext))


def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """A Visitor that crashes on subclasses of generated proto classes.

  Args:
    path: path of the traversed object, used only in the error message.
    parent: the traversed object itself.
    unused_children: ignored.

  Raises:
    NotImplementedError: if `parent` is a Message subclass that does not list
      message.Message among its direct bases.
  """
  # Only proto Message classes are of interest.
  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
    return
  if parent is message.Message:
    return
  # Check that it is a direct subclass of Message.
  if message.Message not in parent.__bases__:
    raise NotImplementedError(
        'Object tf.%s is a subclass of a generated proto Message. '
        'They are not yet supported by the API tools.' % path)


def _FilterNonCoreGoldenFiles(golden_file_list):
  """Filter out non-core API pbtxt files."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _NON_CORE_PACKAGES)


def _FilterV1KerasRelatedGoldenFiles(golden_file_list):
  """Filter out golden files of V1 APIs listed in _V1_APIS_FROM_KERAS."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _V1_APIS_FROM_KERAS)


def _FilterV2KerasRelatedGoldenFiles(golden_file_list):
  """Filter out golden files of V2 APIs listed in _V2_APIS_FROM_KERAS."""
  return _FilterGoldenFilesByPrefix(golden_file_list, _V2_APIS_FROM_KERAS)


def _FilterGoldenFilesByPrefix(golden_file_list, package_prefixes):
  """Drops golden files that belong to any of the given packages.

  Args:
    golden_file_list: iterable of golden file paths.
    package_prefixes: package names relative to `tensorflow`, e.g. 'keras'. A
      file is dropped when its base name starts with 'tensorflow.<name>.'.

  Returns:
    A new list with the remaining paths, in their original order.
  """
  # str.startswith accepts a tuple, so one call covers every prefix.
  excluded_prefixes = tuple('tensorflow.%s.' % p for p in package_prefixes)
  # os.path.basename understands the platform's path separators, unlike a
  # manual split on '/'.
  return [
      f for f in golden_file_list
      if not os.path.basename(six.ensure_str(f)).startswith(excluded_prefixes)
  ]


def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
  """Filter out golden proto dict symbols that should be omitted.

  Args:
    golden_proto_dict: dict mapping api object keys to TFAPIObject protos.
    omit_golden_symbols_map: dict mapping api object keys to lists of member
      names to remove from that object's `member` and `member_method` fields.
      May be empty or None, in which case the input dict is returned as is.

  Returns:
    A dict in which every entry named in omit_golden_symbols_map is replaced by
    a filtered copy; the input protos are not modified.

  Raises:
    KeyError: if a key of omit_golden_symbols_map is not in golden_proto_dict.
  """
  if not omit_golden_symbols_map:
    return golden_proto_dict
  filtered_proto_dict = dict(golden_proto_dict)
  for key, symbol_list in six.iteritems(omit_golden_symbols_map):
    # Work on a copy so the caller's proto stays untouched.
    api_object = api_objects_pb2.TFAPIObject()
    api_object.CopyFrom(filtered_proto_dict[key])
    filtered_proto_dict[key] = api_object
    module_or_class = None
    if api_object.HasField('tf_module'):
      module_or_class = api_object.tf_module
    elif api_object.HasField('tf_class'):
      module_or_class = api_object.tf_class
    if module_or_class is not None:
      for members in (module_or_class.member, module_or_class.member_method):
        filtered_members = [m for m in members if m.name not in symbol_list]
        # Two steps because protobuf repeated fields disallow slice assignment.
        del members[:]
        members.extend(filtered_members)
  return filtered_proto_dict


def _GetTFNumpyGoldenPattern(api_version):
  """Returns the glob pattern matching the tf.experimental.numpy goldens."""
  return os.path.join(resource_loader.get_root_dir_with_all_resources(),
                      _KeyToFilePath('tensorflow.experimental.numpy*',
                                     api_version))


class ApiCompatibilityTest(test.TestCase):
  """Compares the public TensorFlow API against the stored golden files."""

  def __init__(self, *args, **kwargs):
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    # Both messages are read once up front and logged when diffs are found.
    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    When differences are found and update_goldens is false the test fails.
    When update_goldens is true the golden files are rewritten instead: files
    of removed objects are deleted, files of new or changed objects are
    (re)written, and the test does not fail.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys

    # Do not truncate the proto diffs produced by assertProtoEquals below.
    self.maxDiff = None  # pylint: disable=invalid-name

    # Keys present on both sides whose protos differ; populated below.
    updated_keys = []

    # Sorted so that the report is stable from run to run.
    for key in sorted(expected_keys | actual_keys):
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    if not diffs:
      logging.info('No differences found between API and golden.')
      return

    # Diffs were found; handle them based on flags.
    diff_count = len(diffs)
    logging.error(self._test_readme_message)
    logging.error('%d differences found between API and golden.', diff_count)

    if update_goldens:
      # Write files if requested.
      logging.warning(self._update_golden_warning)

      # If the keys are only in expected, some objects are deleted.
      # Remove files.
      for key in sorted(only_in_expected):
        file_io.delete_file(_KeyToFilePath(key, api_version))

      # If the files are only in actual (current library), these are new
      # modules. Write them to files. Also record all updates in files.
      for key in sorted(only_in_actual | set(updated_keys)):
        file_io.write_string_to_file(
            _KeyToFilePath(key, api_version),
            text_format.MessageToString(actual_dict[key]))
    else:
      # Always name what differs; include the full diffs only when asked to.
      for d, verbose_d in zip(diffs, verbose_diffs):
        logging.error('    %s', d)
        if verbose:
          logging.error('    %s', verbose_d)
      # Fail if we cannot fix the test by updating goldens.
      self.fail('%d differences found between API and golden.' % diff_count)

  def testNoSubclassOfMessage(self):
    """Traverses `tf` and fails on indirect subclasses of proto Message."""
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # visitor.do_not_descend_map['tf'].append('keras')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

  def testNoSubclassOfMessageV1(self):
    """Same check as testNoSubclassOfMessage, rooted at tf.compat.v1."""
    if not hasattr(tf.compat, 'v1'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v1, visitor)

  def testNoSubclassOfMessageV2(self):
    """Same check as testNoSubclassOfMessage, rooted at tf.compat.v2."""
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v2, visitor)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_patterns,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    """Builds API protos from `root` and diffs them against the goldens.

    Args:
      root: the module to traverse, e.g. tf or tf.compat.v1.
      golden_file_patterns: glob pattern (or list of patterns) selecting the
        golden files to compare against.
      api_version: TensorFlow API version under test, 1 or 2.
      additional_private_map: optional dict merged into the visitor's
        private_map.
      omit_golden_symbols_map: optional dict of golden symbols to drop before
        diffing; see _FilterGoldenProtoDict.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'].append('contrib')
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # Do not descend into these numpy classes because their signatures may be
    # different between internal and OSS.
    public_api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
        'bool_', 'complex_', 'complex128', 'complex64', 'float_', 'float16',
        'float32', 'float64', 'inexact', 'int_', 'int16', 'int32', 'int64',
        'int8', 'object_', 'string_', 'uint16', 'uint32', 'uint64', 'uint8',
        'unicode_', 'iinfo']
    public_api_visitor.do_not_descend_map['tf'].append('keras')
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
      if api_version == 2:
        public_api_visitor.do_not_descend_map['tf'].extend(_V2_APIS_FROM_KERAS)
      else:
        public_api_visitor.do_not_descend_map['tf'].extend(['layers'])
        public_api_visitor.do_not_descend_map['tf.nn'] = ['rnn_cell']
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_patterns)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
      if api_version == 2:
        golden_file_list = _FilterV2KerasRelatedGoldenFiles(golden_file_list)
      else:
        golden_file_list = _FilterV1KerasRelatedGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  def testAPIBackwardsCompatibility(self):
    """Checks the top-level `tf` API against the matching golden version."""
    api_version = 1
    if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
      api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if (api_version == 2 and FLAGS.only_test_core_api and
        not _TENSORBOARD_AVAILABLE):
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]

    self._checkBackwardsCompatibility(
        tf,
        golden_file_patterns,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)

    # Check that V2 API does not have contrib
    self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))

  def testAPIBackwardsCompatibilityV1(self):
    """Checks tf.compat.v1 against the v1 golden files."""
    api_version = 1
    golden_file_patterns = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf.compat.v1,
        golden_file_patterns,
        api_version,
        additional_private_map={
            'tf': ['pywrap_tensorflow'],
            'tf.compat': ['v1', 'v2'],
        },
        omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']})

  def testAPIBackwardsCompatibilityV2(self):
    """Checks tf.compat.v2 against the v2 and tf.experimental.numpy goldens."""
    api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_patterns,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)


_TRUE_FLAG_VALUES = frozenset(['1', 'true', 't', 'yes', 'y'])
# The empty string is kept as false because bool('') is False, which is what
# the previous `type=bool` conversion produced for `--flag=`.
_FALSE_FLAG_VALUES = frozenset(['', '0', 'false', 'f', 'no', 'n'])


def _ParseBoolFlag(value):
  """Converts a command line string into a bool.

  argparse's `type=bool` calls bool() on the raw string, which yields True for
  every non-empty value, including "False". This converter interprets the text
  instead.

  Args:
    value: the raw flag value; a bool is passed through unchanged.

  Returns:
    True or False.

  Raises:
    argparse.ArgumentTypeError: if the value is not a recognized boolean
      spelling.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in _TRUE_FLAG_VALUES:
    return True
  if normalized in _FALSE_FLAG_VALUES:
    return False
  raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--update_goldens',
      type=_ParseBoolFlag,
      default=False,
      help=_UPDATE_GOLDENS_HELP)
  # TODO(mikecase): Create Estimator's own API compatibility test or
  # a more general API compatibility test for use for TF components.
  parser.add_argument(
      '--only_test_core_api',
      type=_ParseBoolFlag,
      default=True,  # only_test_core_api default value
      help=_ONLY_TEST_CORE_API_HELP)
  parser.add_argument(
      '--verbose_diffs',
      type=_ParseBoolFlag,
      default=True,
      help=_VERBOSE_DIFFS_HELP)
  FLAGS, unparsed = parser.parse_known_args()
  _InitPathConstants()

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()