• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# ==============================================================================
16"""TensorFlow API compatibility tests.
17
18This test ensures all changes to the public API of TensorFlow are intended.
19
20If this test fails, it means a change has been made to the public API. Backwards
21incompatible changes are not allowed. You can run the test with
22"--update_goldens" flag set to "True" to update goldens when making changes to
23the public TF python API.
24"""
25
26import argparse
27import os
28import re
29import sys
30
31import tensorflow as tf
32
33from google.protobuf import message
34from google.protobuf import text_format
35
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.platform import resource_loader
38from tensorflow.python.platform import test
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.tools.api.lib import api_objects_pb2
41from tensorflow.tools.api.lib import python_object_to_proto_visitor
42from tensorflow.tools.common import public_api
43from tensorflow.tools.common import traverse
44
# Records whether the optional `tensorboard` package imports in this
# environment; the v2 API tests consult it when deciding whether to drop the
# TensorBoard-provided tf.summary symbols from the goldens.
# pylint: disable=g-import-not-at-top,unused-import
try:
  import tensorboard as _tb
except ImportError:
  _TENSORBOARD_AVAILABLE = False
else:
  _TENSORBOARD_AVAILABLE = True
# pylint: enable=g-import-not-at-top,unused-import
52
# Parsed command-line flags. This is assigned from argparse in the `__main__`
# block at the bottom of this file, and is None until then.
FLAGS = None
# Help text for the --update_goldens argparse flag (default False):
_UPDATE_GOLDENS_HELP = """
     Update stored golden files if API is updated. WARNING: All API changes
     have to be authorized by TensorFlow leads.
"""

# Help text for the --only_test_core_api argparse flag (default True):
_ONLY_TEST_CORE_API_HELP = """
    Some TF APIs are being moved outside of the tensorflow/ directory. There is
    no guarantee which versions of these APIs will be present when running this
    test. Therefore, do not error out on API changes in non-core TF code
    if this flag is set.
"""

# Help text for the --verbose_diffs argparse flag (default True):
_VERBOSE_DIFFS_HELP = """
     If set to true, print line by line diffs on all libraries. If set to
     false, only print which libraries have differences.
"""

# Golden directories for the v1 and v2 APIs. They are initialized by the
# _InitPathConstants function below, which reads FLAGS and therefore must run
# after the flags are parsed.
_API_GOLDEN_FOLDER_V1 = None
_API_GOLDEN_FOLDER_V2 = None
78
79
def _InitPathConstants():
  """Sets the module-level golden folder paths for the v1 and v2 APIs.

  The v2 folder is the directory holding the root `tensorflow.pbtxt` golden;
  the v1 folder is its sibling `v1` directory. Reads FLAGS.update_goldens, so
  it must be called after the command-line flags are parsed.
  """
  global _API_GOLDEN_FOLDER_V1, _API_GOLDEN_FOLDER_V2
  v2_root_golden = os.path.join(resource_loader.get_data_files_path(), '..',
                                'golden', 'v2', 'tensorflow.pbtxt')
  # When goldens are going to be written, resolve symbolic links up front so
  # the API directories (and any new files) are derived from the real
  # location of the root golden file.
  if FLAGS.update_goldens:
    v2_root_golden = os.path.realpath(v2_root_golden)
  v2_folder = os.path.dirname(v2_root_golden)
  v1_folder = os.path.normpath(os.path.join(v2_folder, '..', 'v1'))
  _API_GOLDEN_FOLDER_V1, _API_GOLDEN_FOLDER_V2 = v1_folder, v2_folder
93
94
# Data files whose contents are logged when API differences are found. They
# are read in ApiCompatibilityTest.__init__.
_TEST_README_FILE = resource_loader.get_path_to_datafile('README.txt')
_UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
    'API_UPDATE_WARNING.txt')

# Packages (relative to `tensorflow`) that are neither traversed nor compared
# against goldens when --only_test_core_api is set.
_NON_CORE_PACKAGES = ['estimator', 'keras']
# Keras-provided packages (relative to `tensorflow`) whose golden files are
# additionally filtered out under --only_test_core_api, for the v1 and v2 APIs
# respectively.
_V1_APIS_FROM_KERAS = ['layers', 'nn.rnn_cell']
_V2_APIS_FROM_KERAS = ['initializers', 'losses', 'metrics', 'optimizers']

# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
# Until then, alias tf.estimator into tf.compat.v1 and tf.compat.v2 so both
# API trees expose an `estimator` attribute. NOTE(review): this mutates the tf
# module at import time and assumes `tf.estimator` itself resolves; confirm
# this still holds for TF builds without the estimator package.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator
108
109
def _KeyToFilePath(key, api_version):
  """Maps an API object key to the path of its golden pbtxt file.

  Every upper-case letter in `key` is written as '-' followed by its
  lower-case form (e.g. 'GPUOptions' -> '-g-p-u-options'), presumably so that
  keys differing only in case get distinct names on case-insensitive file
  systems. Glob characters in `key` are passed through, so '*' yields a
  pattern matching every golden file in the folder.

  Args:
    key: a string used to determine the file path.
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.
      2 selects the v2 golden folder; anything else selects the v1 folder.

  Returns:
    A string of file path to the pbtxt file which describes the public API.
  """
  encoded_key = re.sub('([A-Z]{1})', lambda m: '-' + m.group(0).lower(), key)

  if api_version == 2:
    golden_dir = _API_GOLDEN_FOLDER_V2
  else:
    golden_dir = _API_GOLDEN_FOLDER_V1

  if key.startswith('tensorflow.experimental.numpy'):
    # tf.experimental.numpy goldens live under third_party/py/numpy instead.
    # The path jumps up one more level, and keeps the literal '../third_party'
    # component, so that Copybara can find the 'tensorflow/third_party'
    # string to replace.
    golden_dir = os.path.normpath(
        os.path.join(golden_dir, '..', '..', '..', '..', '../third_party',
                     'py', 'numpy', 'tf_numpy_api'))

  return os.path.join(golden_dir, '%s.pbtxt' % encoded_key)
138
139
140def _FileNameToKey(filename):
141  """From a given filename, construct a key we use for api objects."""
142
143  def _ReplaceDashWithCaps(matchobj):
144    match = matchobj.group(0)
145    return match[1].upper()
146
147  base_filename = os.path.basename(filename)
148  base_filename_without_ext = os.path.splitext(base_filename)[0]
149  api_object_key = re.sub('((-[a-z]){1})', _ReplaceDashWithCaps,
150                          base_filename_without_ext)
151  return api_object_key
152
153
154def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
155  """A Visitor that crashes on subclasses of generated proto classes."""
156  # If the traversed object is a proto Message class
157  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
158    return
159  if parent is message.Message:
160    return
161  # Check that it is a direct subclass of Message.
162  if message.Message not in parent.__bases__:
163    raise NotImplementedError(
164        'Object tf.%s is a subclass of a generated proto Message. '
165        'They are not yet supported by the API tools.' % path)
166
167
def _FilterNonCoreGoldenFiles(golden_file_list):
  """Returns `golden_file_list` without golden files of _NON_CORE_PACKAGES."""
  return _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_NON_CORE_PACKAGES)
171
172
def _FilterV1KerasRelatedGoldenFiles(golden_file_list):
  """Returns `golden_file_list` without golden files of _V1_APIS_FROM_KERAS."""
  return _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_V1_APIS_FROM_KERAS)
175
176
def _FilterV2KerasRelatedGoldenFiles(golden_file_list):
  """Returns `golden_file_list` without golden files of _V2_APIS_FROM_KERAS."""
  return _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_V2_APIS_FROM_KERAS)
179
180
181def _FilterGoldenFilesByPrefix(golden_file_list, package_prefixes):
182  filtered_file_list = []
183  filtered_package_prefixes = ['tensorflow.%s.' % p for p in package_prefixes]
184  for f in golden_file_list:
185    if any(
186        f.rsplit('/')[-1].startswith(pre) for pre in filtered_package_prefixes):
187      continue
188    filtered_file_list.append(f)
189  return filtered_file_list
190
191
192def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
193  """Filter out golden proto dict symbols that should be omitted."""
194  if not omit_golden_symbols_map:
195    return golden_proto_dict
196  filtered_proto_dict = dict(golden_proto_dict)
197  for key, symbol_list in omit_golden_symbols_map.items():
198    api_object = api_objects_pb2.TFAPIObject()
199    api_object.CopyFrom(filtered_proto_dict[key])
200    filtered_proto_dict[key] = api_object
201    module_or_class = None
202    if api_object.HasField('tf_module'):
203      module_or_class = api_object.tf_module
204    elif api_object.HasField('tf_class'):
205      module_or_class = api_object.tf_class
206    if module_or_class is not None:
207      for members in (module_or_class.member, module_or_class.member_method):
208        filtered_members = [m for m in members if m.name not in symbol_list]
209        # Two steps because protobuf repeated fields disallow slice assignment.
210        del members[:]
211        members.extend(filtered_members)
212  return filtered_proto_dict
213
214
def _GetTFNumpyGoldenPattern(api_version):
  """Returns a glob pattern matching every tf.experimental.numpy golden file.

  Args:
    api_version: TensorFlow API version (1 or 2) whose goldens to match.
  """
  numpy_pattern = _KeyToFilePath('tensorflow.experimental.numpy*', api_version)
  resources_root = resource_loader.get_root_dir_with_all_resources()
  return os.path.join(resources_root, numpy_pattern)
219
220
class ApiCompatibilityTest(test.TestCase):
  """Compares the public TF Python API against the stored golden files."""

  def __init__(self, *args, **kwargs):
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    # Texts logged by _AssertProtoDictEquals when differences are found.
    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different. Only affects the failure report (update_goldens=False).
      update_goldens: Whether to update goldens when there are diffs found.
        Removed objects have their golden files deleted; added and changed
        objects have theirs (re)written from `actual_dict`.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test. Selects the golden folder
        that files are written to or deleted from.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # Keys present on both sides whose protos differ. Populated below.
    updated_keys = []

    # Do not truncate assertProtoEquals diffs; they are the debugging output.
    self.maxDiff = None  # pylint: disable=invalid-name

    # Sorted so diffs are reported (and goldens written) in a stable order.
    for key in sorted(all_keys):
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases are covered above. Record any difference found.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    if not diffs:
      logging.info('No differences found between API and golden.')
      return

    # Diffs were found; handle them based on flags.
    diff_count = len(diffs)
    logging.error(self._test_readme_message)
    logging.error('%d differences found between API and golden.', diff_count)

    if update_goldens:
      # Write files if requested.
      logging.warning(self._update_golden_warning)

      # If the keys are only in expected, some objects are deleted.
      # Remove files.
      for key in sorted(only_in_expected):
        filepath = _KeyToFilePath(key, api_version)
        file_io.delete_file(filepath)

      # If the files are only in actual (current library), these are new
      # modules. Write them to files. Also record all updates in files.
      for key in sorted(only_in_actual | set(updated_keys)):
        filepath = _KeyToFilePath(key, api_version)
        file_io.write_string_to_file(
            filepath, text_format.MessageToString(actual_dict[key]))
      return

    # Log the differences to help debugging. The full proto diff is logged
    # only in verbose mode, and only when it adds information beyond the
    # one-line summary (for added/removed objects the two are identical).
    for d, verbose_d in zip(diffs, verbose_diffs):
      logging.error('    %s', d)
      if verbose and verbose_d != d:
        logging.error('    %s', verbose_d)
    # Fail if we cannot fix the test by updating goldens.
    self.fail('%d differences found between API and golden.' % diff_count)

  def _checkNoSubclassOfMessage(self, root, skip_non_core_packages=False):
    """Traverses `root` and fails on indirect subclasses of proto Message.

    Args:
      root: the module to traverse.
      skip_non_core_packages: if True, also do not descend into
        _NON_CORE_PACKAGES.
    """
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if skip_non_core_packages:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(root, visitor)

  def testNoSubclassOfMessage(self):
    self._checkNoSubclassOfMessage(tf)

  def testNoSubclassOfMessageV1(self):
    if not hasattr(tf.compat, 'v1'):
      return
    self._checkNoSubclassOfMessage(
        tf.compat.v1, skip_non_core_packages=FLAGS.only_test_core_api)

  def testNoSubclassOfMessageV2(self):
    if not hasattr(tf.compat, 'v2'):
      return
    self._checkNoSubclassOfMessage(
        tf.compat.v2, skip_non_core_packages=FLAGS.only_test_core_api)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_patterns,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    """Compares the API reachable from `root` against golden files.

    Args:
      root: the module to traverse (e.g. tf, tf.compat.v1 or tf.compat.v2).
      golden_file_patterns: a glob pattern, or a list of patterns, passed to
        file_io.get_matching_files to find the golden pbtxt files.
      api_version: 1 or 2; selects which packages are skipped and which
        golden folder is updated.
      additional_private_map: optional dict merged into the visitor's
        private_map to hide extra symbols.
      omit_golden_symbols_map: optional dict mapping golden keys to member
        names that are dropped from those goldens before comparing.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'].append('contrib')
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # Do not descend into these numpy classes because their signatures may be
    # different between internal and OSS.
    public_api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
        'bool_', 'complex_', 'complex128', 'complex64', 'float_', 'float16',
        'float32', 'float64', 'inexact', 'int_', 'int16', 'int32', 'int64',
        'int8', 'object_', 'string_', 'uint16', 'uint32', 'uint64', 'uint8',
        'unicode_', 'iinfo']
    public_api_visitor.do_not_descend_map['tf'].append('keras')
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
      if api_version == 2:
        public_api_visitor.do_not_descend_map['tf'].extend(_V2_APIS_FROM_KERAS)
      else:
        public_api_visitor.do_not_descend_map['tf'].extend(['layers'])
        public_api_visitor.do_not_descend_map['tf.nn'] = ['rnn_cell']
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_patterns)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
      if api_version == 2:
        golden_file_list = _FilterV2KerasRelatedGoldenFiles(golden_file_list)
      else:
        golden_file_list = _FilterV1KerasRelatedGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. If the test is run to update goldens, only report diffs but
    # do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  def testAPIBackwardsCompatibility(self):
    api_version = 1
    if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
      api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if (api_version == 2 and FLAGS.only_test_core_api and
        not _TENSORBOARD_AVAILABLE):
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]

    self._checkBackwardsCompatibility(
        tf,
        golden_file_patterns,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)

    # Check that V2 API does not have contrib
    self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))

  def testAPIBackwardsCompatibilityV1(self):
    api_version = 1
    golden_file_patterns = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf.compat.v1,
        golden_file_patterns,
        api_version,
        additional_private_map={
            'tf': ['pywrap_tensorflow'],
            'tf.compat': ['v1', 'v2'],
        },
        omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']})

  def testAPIBackwardsCompatibilityV2(self):
    api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_patterns,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)
485
486
487if __name__ == '__main__':
488  parser = argparse.ArgumentParser()
489  parser.add_argument(
490      '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP)
491  # TODO(mikecase): Create Estimator's own API compatibility test or
492  # a more general API compatibility test for use for TF components.
493  parser.add_argument(
494      '--only_test_core_api',
495      type=bool,
496      default=True,  # only_test_core_api default value
497      help=_ONLY_TEST_CORE_API_HELP)
498  parser.add_argument(
499      '--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP)
500  FLAGS, unparsed = parser.parse_known_args()
501  _InitPathConstants()
502
503  # Now update argv, so that unittest library does not get confused.
504  sys.argv = [sys.argv[0]] + unparsed
505  test.main()
506