1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utility functions for comparing proto2 messages in Python. 17 18ProtoEq() compares two proto2 messages for equality. 19 20ClearDefaultValuedFields() recursively clears the fields that are set to their 21default values. This is useful for comparing protocol buffers where the 22semantics of unset fields and default valued fields are the same. 23 24assertProtoEqual() is useful for unit tests. It produces much more helpful 25output than assertEqual() for proto2 messages, e.g. this: 26 27 outer { 28 inner { 29- strings: "x" 30? ^ 31+ strings: "y" 32? ^ 33 } 34 } 35 36...compared to the default output from assertEqual() that looks like this: 37 38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc> 39 40Call it inside your unit test's googletest.TestCase subclasses like this: 41 42 from tensorflow.python.util.protobuf import compare 43 44 class MyTest(googletest.TestCase): 45 ... 46 def testXXX(self): 47 ... 48 compare.assertProtoEqual(self, a, b) 49 50Alternatively: 51 52 from tensorflow.python.util.protobuf import compare 53 54 class MyTest(compare.ProtoAssertions, googletest.TestCase): 55 ... 56 def testXXX(self): 57 ... 58 self.assertProtoEqual(a, b) 59""" 60 61from __future__ import absolute_import 62from __future__ import division 63from __future__ import print_function 64 65import collections 66import difflib 67 68import six 69 70from google.protobuf import descriptor 71from google.protobuf import descriptor_pool 72from google.protobuf import message 73from google.protobuf import text_format 74 75 76def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=invalid-name 77 normalize_numbers=False, msg=None): 78 """Fails with a useful error if a and b aren't equal. 79 80 Comparison of repeated fields matches the semantics of 81 unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter. 82 83 Args: 84 self: googletest.TestCase 85 a: proto2 PB instance, or text string representing one. 86 b: proto2 PB instance -- message.Message or subclass thereof. 87 check_initialized: boolean, whether to fail if either a or b isn't 88 initialized. 89 normalize_numbers: boolean, whether to normalize types and precision of 90 numbers before comparison. 91 msg: if specified, is used as the error message on failure. 92 """ 93 pool = descriptor_pool.Default() 94 if isinstance(a, six.string_types): 95 a = text_format.Merge(a, b.__class__(), descriptor_pool=pool) 96 97 for pb in a, b: 98 if check_initialized: 99 errors = pb.FindInitializationErrors() 100 if errors: 101 self.fail('Initialization errors: %s\n%s' % (errors, pb)) 102 if normalize_numbers: 103 NormalizeNumberFields(pb) 104 105 a_str = text_format.MessageToString(a, descriptor_pool=pool) 106 b_str = text_format.MessageToString(b, descriptor_pool=pool) 107 108 # Some Python versions would perform regular diff instead of multi-line 109 # diff if string is longer than 2**16. We substitute this behavior 110 # with a call to unified_diff instead to have easier-to-read diffs. 111 # For context, see: https://bugs.python.org/issue11763. 112 if len(a_str) < 2**16 and len(b_str) < 2**16: 113 self.assertMultiLineEqual(a_str, b_str, msg=msg) 114 else: 115 diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True), 116 b_str.splitlines(True))) 117 self.fail('%s : %s' % (msg, diff)) 118 119 120def NormalizeNumberFields(pb): 121 """Normalizes types and precisions of number fields in a protocol buffer. 122 123 Due to subtleties in the python protocol buffer implementation, it is possible 124 for values to have different types and precision depending on whether they 125 were set and retrieved directly or deserialized from a protobuf. This function 126 normalizes integer values to ints and longs based on width, 32-bit floats to 127 five digits of precision to account for python always storing them as 64-bit, 128 and ensures doubles are floating point for when they're set to integers. 129 130 Modifies pb in place. Recurses into nested objects. 131 132 Args: 133 pb: proto2 message. 134 135 Returns: 136 the given pb, modified in place. 137 """ 138 for desc, values in pb.ListFields(): 139 is_repeated = True 140 if desc.label is not descriptor.FieldDescriptor.LABEL_REPEATED: 141 is_repeated = False 142 values = [values] 143 144 normalized_values = None 145 146 # We force 32-bit values to int and 64-bit values to long to make 147 # alternate implementations where the distinction is more significant 148 # (e.g. the C++ implementation) simpler. 149 if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, 150 descriptor.FieldDescriptor.TYPE_UINT64, 151 descriptor.FieldDescriptor.TYPE_SINT64): 152 normalized_values = [int(x) for x in values] 153 elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, 154 descriptor.FieldDescriptor.TYPE_UINT32, 155 descriptor.FieldDescriptor.TYPE_SINT32, 156 descriptor.FieldDescriptor.TYPE_ENUM): 157 normalized_values = [int(x) for x in values] 158 elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: 159 normalized_values = [round(x, 6) for x in values] 160 elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: 161 normalized_values = [round(float(x), 7) for x in values] 162 163 if normalized_values is not None: 164 if is_repeated: 165 pb.ClearField(desc.name) 166 getattr(pb, desc.name).extend(normalized_values) 167 else: 168 setattr(pb, desc.name, normalized_values[0]) 169 170 if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or 171 desc.type == descriptor.FieldDescriptor.TYPE_GROUP): 172 if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 173 desc.message_type.has_options and 174 desc.message_type.GetOptions().map_entry): 175 # This is a map, only recurse if the values have a message type. 176 if (desc.message_type.fields_by_number[2].type == 177 descriptor.FieldDescriptor.TYPE_MESSAGE): 178 for v in six.itervalues(values): 179 NormalizeNumberFields(v) 180 else: 181 for v in values: 182 # recursive step 183 NormalizeNumberFields(v) 184 185 return pb 186 187 188def _IsMap(value): 189 return isinstance(value, collections.Mapping) 190 191 192def _IsRepeatedContainer(value): 193 if isinstance(value, six.string_types): 194 return False 195 try: 196 iter(value) 197 return True 198 except TypeError: 199 return False 200 201 202def ProtoEq(a, b): 203 """Compares two proto2 objects for equality. 204 205 Recurses into nested messages. Uses list (not set) semantics for comparing 206 repeated fields, ie duplicates and order matter. 207 208 Args: 209 a: A proto2 message or a primitive. 210 b: A proto2 message or a primitive. 211 212 Returns: 213 `True` if the messages are equal. 214 """ 215 def Format(pb): 216 """Returns a dictionary or unchanged pb bases on its type. 217 218 Specifically, this function returns a dictionary that maps tag 219 number (for messages) or element index (for repeated fields) to 220 value, or just pb unchanged if it's neither. 221 222 Args: 223 pb: A proto2 message or a primitive. 224 Returns: 225 A dict or unchanged pb. 226 """ 227 if isinstance(pb, message.Message): 228 return dict((desc.number, value) for desc, value in pb.ListFields()) 229 elif _IsMap(pb): 230 return dict(pb.items()) 231 elif _IsRepeatedContainer(pb): 232 return dict(enumerate(list(pb))) 233 else: 234 return pb 235 236 a, b = Format(a), Format(b) 237 238 # Base case 239 if not isinstance(a, dict) or not isinstance(b, dict): 240 return a == b 241 242 # This list performs double duty: it compares two messages by tag value *or* 243 # two repeated fields by element, in order. the magic is in the format() 244 # function, which converts them both to the same easily comparable format. 245 for tag in sorted(set(a.keys()) | set(b.keys())): 246 if tag not in a or tag not in b: 247 return False 248 else: 249 # Recursive step 250 if not ProtoEq(a[tag], b[tag]): 251 return False 252 253 # Didn't find any values that differed, so they're equal! 254 return True 255 256 257class ProtoAssertions(object): 258 """Mix this into a googletest.TestCase class to get proto2 assertions. 259 260 Usage: 261 262 class SomeTestCase(compare.ProtoAssertions, googletest.TestCase): 263 ... 264 def testSomething(self): 265 ... 266 self.assertProtoEqual(a, b) 267 268 See module-level definitions for method documentation. 269 """ 270 271 # pylint: disable=invalid-name 272 def assertProtoEqual(self, *args, **kwargs): 273 return assertProtoEqual(self, *args, **kwargs) 274