• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for comparing proto2 messages in Python.
17
18ProtoEq() compares two proto2 messages for equality.
19
20ClearDefaultValuedFields() recursively clears the fields that are set to their
21default values. This is useful for comparing protocol buffers where the
22semantics of unset fields and default valued fields are the same.
23
24assertProtoEqual() is useful for unit tests.  It produces much more helpful
25output than assertEqual() for proto2 messages, e.g. this:
26
27  outer {
28    inner {
29-     strings: "x"
30?               ^
31+     strings: "y"
32?               ^
33    }
34  }
35
36...compared to the default output from assertEqual() that looks like this:
37
38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
39
40Call it inside your unit test's googletest.TestCase subclasses like this:
41
42  from tensorflow.python.util.protobuf import compare
43
44  class MyTest(googletest.TestCase):
45    ...
46    def testXXX(self):
47      ...
48      compare.assertProtoEqual(self, a, b)
49
50Alternatively:
51
52  from tensorflow.python.util.protobuf import compare
53
54  class MyTest(compare.ProtoAssertions, googletest.TestCase):
55    ...
56    def testXXX(self):
57      ...
58      self.assertProtoEqual(a, b)
59"""
60
61from __future__ import absolute_import
62from __future__ import division
63from __future__ import print_function
64
65import collections
66import difflib
67
68import six
69
70from google.protobuf import descriptor
71from google.protobuf import descriptor_pool
72from google.protobuf import message
73from google.protobuf import text_format
74
75
def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
                     normalize_numbers=False, msg=None):
  """Fails with a useful error if a and b aren't equal.

  Both messages are rendered with text_format.MessageToString() and the
  resulting strings are compared, so a failure shows a line-oriented diff of
  the text form.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), i.e. order and extra duplicate fields
  matter.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one. A string is merged
      into a fresh instance of b's class before the comparison.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison. The normalization is applied to a and b in
      place.
    msg: if specified, is used as the error message on failure.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions would perform regular diff instead of multi-line
  # diff if string is longer than 2**16. We substitute this behavior
  # with a call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
    return

  diff = ''.join(difflib.unified_diff(a_str.splitlines(True),
                                      b_str.splitlines(True)))
  # unified_diff() produces no output for identical inputs. Only fail when
  # there is an actual difference; otherwise equal large messages would be
  # reported as unequal.
  if diff:
    if msg is None:
      # Avoid prefixing the report with the literal text "None".
      self.fail('\n' + diff)
    else:
      self.fail('%s : %s' % (msg, '\n' + diff))
118
119
def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This function
  applies the following normalization to every set field of pb:

    * int32/uint32/sint32, int64/uint64/sint64 and enum values are converted
      with int().
    * 32-bit floats are rounded to 6 decimal places, to account for python
      always storing them as 64-bit.
    * doubles are converted with float() (for when they were set to integers)
      and rounded to 7 decimal places.

  Modifies pb in place. Recurses into nested messages, groups and map values
  that are themselves messages.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  for desc, values in pb.ListFields():
    # Compare the label by value: it is a plain integer constant, so an
    # identity check would only work by accident of small-int caching.
    is_repeated = desc.label == descriptor.FieldDescriptor.LABEL_REPEATED
    if not is_repeated:
      # Wrap singular values so the code below can treat every field as a
      # sequence.
      values = [values]

    normalized_values = None

    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
                     descriptor.FieldDescriptor.TYPE_UINT64,
                     descriptor.FieldDescriptor.TYPE_SINT64,
                     descriptor.FieldDescriptor.TYPE_INT32,
                     descriptor.FieldDescriptor.TYPE_UINT32,
                     descriptor.FieldDescriptor.TYPE_SINT32,
                     descriptor.FieldDescriptor.TYPE_ENUM):
      normalized_values = [int(x) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if is_repeated:
        # Repeated fields cannot be assigned directly; clear and refill.
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if desc.type in (descriptor.FieldDescriptor.TYPE_MESSAGE,
                     descriptor.FieldDescriptor.TYPE_GROUP):
      is_map = (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
                desc.message_type.has_options and
                desc.message_type.GetOptions().map_entry)
      if is_map:
        # This is a map, only recurse if the values have a message type.
        # Field number 2 of the map entry message is looked up as the value
        # field.
        if (desc.message_type.fields_by_number[2].type ==
            descriptor.FieldDescriptor.TYPE_MESSAGE):
          for v in six.itervalues(values):
            NormalizeNumberFields(v)
      else:
        for v in values:
          # recursive step
          NormalizeNumberFields(v)

  return pb
186
187
def _IsMap(value):
  """Returns True if value is an instance of the Mapping abstract base class."""
  # The `collections.Mapping` alias was removed in Python 3.10; the ABC lives
  # in `collections.abc`. Fall back to `collections` itself for Python 2,
  # which has no `collections.abc` module.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    import collections as collections_abc  # pylint: disable=g-import-not-at-top
  return isinstance(value, collections_abc.Mapping)
190
191
def _IsRepeatedContainer(value):
  """Returns True if value is iterable and is not a string type.

  Strings are iterable too, but they are treated as scalar values, so they are
  rejected before the iterability probe.
  """
  if not isinstance(value, six.string_types):
    try:
      iter(value)
    except TypeError:
      # Not iterable at all: falls through to the False result below.
      pass
    else:
      return True
  return False
200
201
def ProtoEq(a, b):
  """Returns whether two proto2 objects are equal.

  Nested messages are compared recursively. Repeated fields are compared with
  list (not set) semantics, i.e. both duplicates and element order matter.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal.
  """

  def ToComparable(obj):
    """Converts obj into a plain dict when it is a container.

    A message becomes {field number: value} for the fields reported by
    ListFields(), a mapping becomes a dict of its items, and a repeated
    container becomes {element index: element}. Any other value is returned
    unchanged.

    Args:
      obj: A proto2 message or a primitive.
    Returns:
      A dict or unchanged obj.
    """
    if isinstance(obj, message.Message):
      return {desc.number: value for desc, value in obj.ListFields()}
    if _IsMap(obj):
      return dict(obj.items())
    if _IsRepeatedContainer(obj):
      return dict(enumerate(obj))
    return obj

  lhs = ToComparable(a)
  rhs = ToComparable(b)

  # Base case: at least one side is not a container, so compare directly.
  if not (isinstance(lhs, dict) and isinstance(rhs, dict)):
    return lhs == rhs

  # Thanks to ToComparable(), the same walk handles two messages (keyed by tag
  # number) and two repeated fields (keyed by element index, in order). A key
  # missing from either side, or any pair of values that differ, makes the
  # objects unequal.
  all_keys = sorted(set(lhs) | set(rhs))
  return all(
      key in lhs and key in rhs and ProtoEq(lhs[key], rhs[key])
      for key in all_keys)
255
256
class ProtoAssertions(object):
  """Mixin that adds proto2 assertions to a googletest.TestCase class.

  Example:

  class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProtoEqual(a, b)

  The methods forward to the module-level functions of the same name; see
  those for the full documentation.
  """

  def assertProtoEqual(self, *args, **kwargs):  # pylint: disable=invalid-name
    """Calls the module-level assertProtoEqual() with self as the test case."""
    return assertProtoEqual(self, *args, **kwargs)
274