• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# ==============================================================================
17"""TensorFlow API compatibility tests.
18
19This test ensures all changes to the public API of TensorFlow are intended.
20
21If this test fails, it means a change has been made to the public API. Backwards
22incompatible changes are not allowed. You can run the test with
23"--update_goldens" flag set to "True" to update goldens when making changes to
24the public TF python API.
25"""
26
27from __future__ import absolute_import
28from __future__ import division
29from __future__ import print_function
30
31import argparse
32import os
33import re
34import sys
35
36import six
37import tensorflow as tf
38
39from google.protobuf import message
40from google.protobuf import text_format
41
42from tensorflow.python.lib.io import file_io
43from tensorflow.python.platform import resource_loader
44from tensorflow.python.platform import test
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.tools.api.lib import api_objects_pb2
47from tensorflow.tools.api.lib import python_object_to_proto_visitor
48from tensorflow.tools.common import public_api
49from tensorflow.tools.common import traverse
50
# TensorBoard is an optional dependency: remember whether the import worked so
# that tests can adapt when the package is missing.
# pylint: disable=g-import-not-at-top,unused-import
try:
  import tensorboard as _tb
except ImportError:
  _TENSORBOARD_AVAILABLE = False
else:
  _TENSORBOARD_AVAILABLE = True
# pylint: enable=g-import-not-at-top,unused-import
58
# Parsed command-line flags. Assigned in the `__main__` block at the bottom of
# this file; stays None until then.
FLAGS = None
# DEFINE_boolean, update_goldens, default False:
_UPDATE_GOLDENS_HELP = """
     Update stored golden files if API is updated. WARNING: All API changes
     have to be authorized by TensorFlow leads.
"""

# DEFINE_boolean, only_test_core_api, default True:
_ONLY_TEST_CORE_API_HELP = """
    Some TF APIs are being moved outside of the tensorflow/ directory. There is
    no guarantee which versions of these APIs will be present when running this
    test. Therefore, do not error out on API changes in non-core TF code
    if this flag is set.
"""

# DEFINE_boolean, verbose_diffs, default True:
_VERBOSE_DIFFS_HELP = """
     If set to true, print line by line diffs on all libraries. If set to
     false, only print which libraries have differences.
"""

# Folders holding the v1 and v2 golden files. Both stay None until
# _InitPathConstants (below) runs, which needs FLAGS to be populated first.
_API_GOLDEN_FOLDER_V1 = None
_API_GOLDEN_FOLDER_V2 = None
84
85
def _InitPathConstants():
  """Computes the v1/v2 golden folders and stores them in the module globals.

  The folders are derived from the location of the root v2 golden file
  (`tensorflow.pbtxt`). Must be called after `FLAGS` has been parsed because it
  reads `FLAGS.update_goldens`.
  """
  global _API_GOLDEN_FOLDER_V1
  global _API_GOLDEN_FOLDER_V2

  v2_root_golden = os.path.join(
      resource_loader.get_data_files_path(), '..', 'golden', 'v2',
      'tensorflow.pbtxt')
  if FLAGS.update_goldens:
    # Resolve symbolic links up front, so that files created while updating
    # goldens are written relative to the real location of the root golden.
    v2_root_golden = os.path.realpath(v2_root_golden)

  v2_folder = os.path.dirname(v2_root_golden)
  v1_folder = os.path.normpath(os.path.join(v2_folder, '..', 'v1'))

  _API_GOLDEN_FOLDER_V2 = v2_folder
  _API_GOLDEN_FOLDER_V1 = v1_folder
99
100
# Paths (resolved through resource_loader) to the text files whose contents are
# logged by the test: the README when differences are found, and the warning
# when golden files are being updated.
_TEST_README_FILE = resource_loader.get_path_to_datafile('README.txt')
_UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
    'API_UPDATE_WARNING.txt')

# Packages under `tf` that are not traversed, and whose `tensorflow.<name>.*`
# golden files are ignored, when --only_test_core_api is set.
_NON_CORE_PACKAGES = ['estimator', 'keras']
# Additional API names whose `tensorflow.<name>.*` golden files are ignored for
# the v1 and v2 API respectively when --only_test_core_api is set.
_V1_APIS_FROM_KERAS = ['layers', 'nn.rnn_cell']
_V2_APIS_FROM_KERAS = ['initializers', 'losses', 'metrics', 'optimizers']

# TODO(annarev): remove this once we test with a newer version of
# estimator that actually has a compat v1 version.
# Until then, expose `tf.estimator` under both compat namespaces whenever
# `tf.compat.v1` lacks an `estimator` attribute.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator
114
115
def _KeyToFilePath(key, api_version):
  """Maps an API object key to the path of its golden pbtxt file.

  Every upper-case ASCII letter in the key is rewritten as a dash followed by
  the lower-case letter (e.g. `tensorflow.Foo` becomes `tensorflow.-foo`), and
  `.pbtxt` is appended. The file lives in the golden folder for `api_version`.

  Args:
    key: a string used to determine the file path
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.

  Returns:
    A string of file path to the pbtxt file which describes the public API
  """
  dashed_key = re.sub('([A-Z]{1})',
                      lambda m: '-%s' % (m.group(0).lower()),
                      six.ensure_str(key))

  if api_version == 2:
    golden_dir = _API_GOLDEN_FOLDER_V2
  else:
    golden_dir = _API_GOLDEN_FOLDER_V1

  if key.startswith('tensorflow.experimental.numpy'):
    # Jumps up one more level in order to let Copybara find the
    # 'tensorflow/third_party' string to replace
    golden_dir = os.path.normpath(
        os.path.join(
            golden_dir, '..', '..', '..', '..', '../third_party',
            'py', 'numpy', 'tf_numpy_api'))

  return os.path.join(golden_dir, '%s.pbtxt' % dashed_key)
145
146
def _FileNameToKey(filename):
  """Recovers the API object key encoded in a golden file name.

  The directory part and the extension are dropped, and every dash followed by
  a lower-case ASCII letter is replaced with the corresponding upper-case
  letter (e.g. `tensorflow.-foo.pbtxt` becomes `tensorflow.Foo`).

  Args:
    filename: path of a golden file.

  Returns:
    The key string used for the API object described by that file.
  """
  stem, _ = os.path.splitext(os.path.basename(filename))
  # group(0) is '-x'; keep only the letter, upper-cased.
  return re.sub('((-[a-z]){1})',
                lambda m: m.group(0)[1].upper(),
                six.ensure_str(stem))
159
160
def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """Visitor that rejects classes deriving indirectly from proto `Message`.

  Args:
    path: dotted path of the visited object, used in the error message.
    parent: the visited object.
    unused_children: ignored.

  Raises:
    NotImplementedError: if `parent` is a `Message` subclass that does not list
      `message.Message` among its direct bases.
  """
  is_message_class = (
      isinstance(parent, type) and issubclass(parent, message.Message))
  # Anything that is not a proto Message class, and Message itself, is fine.
  if not is_message_class or parent is message.Message:
    return
  # Direct subclasses of Message are supported.
  if message.Message in parent.__bases__:
    return
  raise NotImplementedError(
      'Object tf.%s is a subclass of a generated proto Message. '
      'They are not yet supported by the API tools.' % path)
173
174
def _FilterNonCoreGoldenFiles(golden_file_list):
  """Drops the golden files that belong to the non-core packages.

  Args:
    golden_file_list: list of golden pbtxt file paths.

  Returns:
    The paths whose file name does not start with `tensorflow.<package>.` for
    any package listed in `_NON_CORE_PACKAGES`.
  """
  core_golden_files = _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_NON_CORE_PACKAGES)
  return core_golden_files
178
179
def _FilterV1KerasRelatedGoldenFiles(golden_file_list):
  """Drops the golden files of the APIs listed in `_V1_APIS_FROM_KERAS`.

  Args:
    golden_file_list: list of golden pbtxt file paths.

  Returns:
    The paths whose file name does not start with `tensorflow.<api>.` for any
    entry of `_V1_APIS_FROM_KERAS`.
  """
  remaining_golden_files = _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_V1_APIS_FROM_KERAS)
  return remaining_golden_files
182
183
def _FilterV2KerasRelatedGoldenFiles(golden_file_list):
  """Drops the golden files of the APIs listed in `_V2_APIS_FROM_KERAS`.

  Args:
    golden_file_list: list of golden pbtxt file paths.

  Returns:
    The paths whose file name does not start with `tensorflow.<api>.` for any
    entry of `_V2_APIS_FROM_KERAS`.
  """
  remaining_golden_files = _FilterGoldenFilesByPrefix(
      golden_file_list, package_prefixes=_V2_APIS_FROM_KERAS)
  return remaining_golden_files
186
187
def _FilterGoldenFilesByPrefix(golden_file_list, package_prefixes):
  """Removes golden files whose file name starts with an excluded package.

  Args:
    golden_file_list: list of golden pbtxt file paths, using '/' separators.
    package_prefixes: package names (without the leading `tensorflow.`) whose
      golden files must be dropped.

  Returns:
    A new list with the remaining paths, in their original order.
  """
  excluded_prefixes = ['tensorflow.%s.' % name for name in package_prefixes]

  def _IsExcluded(golden_path):
    # Only the part after the last '/' (the file name) is compared.
    return any(
        six.ensure_str(golden_path).rsplit('/')[-1].startswith(prefix)
        for prefix in excluded_prefixes)

  return [path for path in golden_file_list if not _IsExcluded(path)]
198
199
def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
  """Returns the golden protos with the requested symbols stripped out.

  Args:
    golden_proto_dict: dict mapping API object keys to TFAPIObject protos.
    omit_golden_symbols_map: dict mapping an API object key to the names of the
      members / member methods to remove from that object. May be empty or
      None.

  Returns:
    `golden_proto_dict` itself when there is nothing to omit. Otherwise a
    shallow copy of it in which every key of `omit_golden_symbols_map` maps to
    a filtered copy of the original proto.
  """
  if not omit_golden_symbols_map:
    return golden_proto_dict

  result = dict(golden_proto_dict)
  for key, omitted_names in six.iteritems(omit_golden_symbols_map):
    # Filter a copy so that the proto held by the input dict stays untouched.
    pruned_object = api_objects_pb2.TFAPIObject()
    pruned_object.CopyFrom(result[key])
    result[key] = pruned_object

    if pruned_object.HasField('tf_module'):
      container = pruned_object.tf_module
    elif pruned_object.HasField('tf_class'):
      container = pruned_object.tf_class
    else:
      # Neither a module nor a class: there are no member lists to filter.
      continue

    for repeated_field in (container.member, container.member_method):
      kept = [m for m in repeated_field if m.name not in omitted_names]
      # Two steps because protobuf repeated fields disallow slice assignment.
      del repeated_field[:]
      repeated_field.extend(kept)
  return result
221
222
def _GetTFNumpyGoldenPattern(api_version):
  """Builds the file pattern for the `tensorflow.experimental.numpy*` goldens.

  Args:
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.

  Returns:
    The golden path for the key `tensorflow.experimental.numpy*`, joined onto
    the resource root directory.
  """
  resource_root = resource_loader.get_root_dir_with_all_resources()
  numpy_golden_glob = _KeyToFilePath('tensorflow.experimental.numpy*',
                                     api_version)
  return os.path.join(resource_root, numpy_golden_glob)
227
228
class ApiCompatibilityTest(test.TestCase):
  """Compares the public TensorFlow Python API with the stored golden files.

  Each test traverses an API root (`tf`, `tf.compat.v1` or `tf.compat.v2`),
  converts what it finds to TFAPIObject protos and diffs them against the
  protos parsed from the golden pbtxt files. Depending on
  `FLAGS.update_goldens`, differences either fail the test or are written back
  to the golden files.
  """

  def __init__(self, *args, **kwargs):
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    # Text logged as a warning whenever golden files are being rewritten.
    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    # Text logged as an error whenever differences are found.
    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.

    Raises:
      AssertionError: (through `self.fail`) if differences are found and
        `update_goldens` is false.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # This will be populated below.
    updated_keys = []

    # Do not truncate the diffs produced by assertProtoEquals.
    self.maxDiff = None  # pylint: disable=invalid-name

    # Sorted so that differences are reported in the same order on every run.
    for key in sorted(all_keys):
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
      diff_count = len(diffs)
      logging.error(self._test_readme_message)
      logging.error('%d differences found between API and golden.', diff_count)

      if update_goldens:
        # Write files if requested.
        logging.warning(self._update_golden_warning)

        # If the keys are only in expected, some objects are deleted.
        # Remove files.
        for key in only_in_expected:
          filepath = _KeyToFilePath(key, api_version)
          file_io.delete_file(filepath)

        # If the files are only in actual (current library), these are new
        # modules. Write them to files. Also record all updates in files.
        for key in only_in_actual | set(updated_keys):
          filepath = _KeyToFilePath(key, api_version)
          file_io.write_string_to_file(
              filepath, text_format.MessageToString(actual_dict[key]))
      else:
        # Include the actual differences to help debugging.
        for d, verbose_d in zip(diffs, verbose_diffs):
          logging.error('    %s', d)
          # The line by line diff is only printed on request, and only when it
          # adds something to the summary line logged just above.
          if verbose and verbose_d != d:
            logging.error('    %s', verbose_d)
        # Fail if we cannot fix the test by updating goldens.
        self.fail('%d differences found between API and golden.' % diff_count)

    else:
      logging.info('No differences found between API and golden.')

  def _GetTensorBoardOmitSymbolsMap(self):
    """Returns the v2 golden symbols to ignore when TensorBoard is missing.

    Returns:
      A dict mapping a golden key to the member names to drop from it. It is
      empty unless `FLAGS.only_test_core_api` is set and the `tensorboard`
      import failed.
    """
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text'
      ]
    return omit_golden_symbols_map

  def testNoSubclassOfMessage(self):
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # visitor.do_not_descend_map['tf'].append('keras')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

  def testNoSubclassOfMessageV1(self):
    if not hasattr(tf.compat, 'v1'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v1, visitor)

  def testNoSubclassOfMessageV2(self):
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v2, visitor)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_patterns,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    """Diffs the API reachable from `root` against the matching golden files.

    Args:
      root: the module to traverse (e.g. `tf` or `tf.compat.v1`).
      golden_file_patterns: a file pattern, or a list of file patterns, that
        selects the golden pbtxt files to compare against.
      api_version: TensorFlow API version under test, 1 or 2.
      additional_private_map: optional dict merged into the visitor's
        `private_map`.
      omit_golden_symbols_map: optional dict mapping a golden key to member
        names that must be removed from that golden before comparing.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'].append('contrib')
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # Do not descend into these numpy classes because their signatures may be
    # different between internal and OSS.
    public_api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
        'bool_', 'complex_', 'complex128', 'complex64', 'float_', 'float16',
        'float32', 'float64', 'inexact', 'int_', 'int16', 'int32', 'int64',
        'int8', 'object_', 'string_', 'uint16', 'uint32', 'uint64', 'uint8',
        'unicode_', 'iinfo']
    public_api_visitor.do_not_descend_map['tf'].append('keras')
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
      if api_version == 2:
        public_api_visitor.do_not_descend_map['tf'].extend(_V2_APIS_FROM_KERAS)
      else:
        public_api_visitor.do_not_descend_map['tf'].extend(['layers'])
        public_api_visitor.do_not_descend_map['tf.nn'] = ['rnn_cell']
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_patterns)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
      if api_version == 2:
        golden_file_list = _FilterV2KerasRelatedGoldenFiles(golden_file_list)
      else:
        golden_file_list = _FilterV1KerasRelatedGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  def testAPIBackwardsCompatibility(self):
    api_version = 1
    if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
      api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    omit_golden_symbols_map = {}
    if api_version == 2:
      omit_golden_symbols_map = self._GetTensorBoardOmitSymbolsMap()

    self._checkBackwardsCompatibility(
        tf,
        golden_file_patterns,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)

    # Check that V2 API does not have contrib
    self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))

  def testAPIBackwardsCompatibilityV1(self):
    api_version = 1
    golden_file_patterns = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf.compat.v1,
        golden_file_patterns,
        api_version,
        additional_private_map={
            'tf': ['pywrap_tensorflow'],
            'tf.compat': ['v1', 'v2'],
        },
        omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']})

  def testAPIBackwardsCompatibilityV2(self):
    api_version = 2
    golden_file_patterns = [
        os.path.join(resource_loader.get_root_dir_with_all_resources(),
                     _KeyToFilePath('*', api_version)),
        _GetTFNumpyGoldenPattern(api_version)]
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_patterns,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=self._GetTensorBoardOmitSymbolsMap())
494
def _StrToBool(value):
  """Converts the textual value of a boolean command-line flag to a bool.

  Using `type=bool` with argparse is a trap: `bool('False')` is True, so any
  non-empty value would switch the flag on. This converter understands the
  usual spellings instead.

  Args:
    value: the raw flag value; a bool is returned unchanged.

  Returns:
    True for 'true', 't', 'yes', 'y' or '1'; False for 'false', 'f', 'no',
    'n', '0' or the empty string. Case and surrounding whitespace are ignored.

  Raises:
    argparse.ArgumentTypeError: if the value is none of the above.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  # The empty string stays False, as it was with `type=bool`.
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value, got %r.' % (value,))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--update_goldens',
      type=_StrToBool,
      default=False,
      help=_UPDATE_GOLDENS_HELP)
  # TODO(mikecase): Create Estimator's own API compatibility test or
  # a more general API compatibility test for use for TF components.
  parser.add_argument(
      '--only_test_core_api',
      type=_StrToBool,
      default=True,  # only_test_core_api default value
      help=_ONLY_TEST_CORE_API_HELP)
  parser.add_argument(
      '--verbose_diffs',
      type=_StrToBool,
      default=True,
      help=_VERBOSE_DIFFS_HELP)
  FLAGS, unparsed = parser.parse_known_args()
  # The golden folders depend on FLAGS.update_goldens, so resolve them only
  # after the flags have been parsed.
  _InitPathConstants()

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()
514