1# ============================================================================= 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================= 16"""Tests for decode_proto op.""" 17 18# Python3 preparedness imports. 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import itertools 24 25from absl.testing import parameterized 26import numpy as np 27 28 29from google.protobuf import text_format 30 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import errors 33from tensorflow.python.kernel_tests.proto import proto_op_test_base as test_base 34from tensorflow.python.kernel_tests.proto import test_example_pb2 35 36 37class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): 38 """Base class for testing proto decoding ops.""" 39 40 def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name 41 """DecodeProtoOpTestBase initializer. 42 43 Args: 44 decode_module: a module containing the `decode_proto_op` method 45 methodName: the name of the test method (same as for test.TestCase) 46 """ 47 48 super(DecodeProtoOpTestBase, self).__init__(methodName) 49 self._decode_module = decode_module 50 51 def _compareValues(self, fd, vs, evs): 52 """Compare lists/arrays of field values.""" 53 54 if len(vs) != len(evs): 55 self.fail('Field %s decoded %d outputs, expected %d' % 56 (fd.name, len(vs), len(evs))) 57 for i, ev in enumerate(evs): 58 # Special case fuzzy match for float32. TensorFlow seems to mess with 59 # MAX_FLT slightly and the test doesn't work otherwise. 60 # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. 61 if fd.cpp_type == fd.CPPTYPE_FLOAT: 62 # Numpy isclose() is better than assertIsClose() which uses an absolute 63 # value comparison. 64 self.assertTrue( 65 np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) 66 elif fd.cpp_type == fd.CPPTYPE_STRING: 67 # In Python3 string tensor values will be represented as bytes, so we 68 # reencode the proto values to match that. 69 self.assertEqual(vs[i], ev.encode('ascii')) 70 else: 71 # Doubles and other types pass through unscathed. 72 self.assertEqual(vs[i], ev) 73 74 def _compareProtos(self, batch_shape, sizes, fields, field_dict): 75 """Compare protos of type TestValue. 76 77 Args: 78 batch_shape: the shape of the input tensor of serialized messages. 79 sizes: int matrix of repeat counts returned by decode_proto 80 fields: list of test_example_pb2.FieldSpec (types and expected values) 81 field_dict: map from field names to decoded numpy tensors of values 82 """ 83 84 # Check that expected values match. 85 for field in fields: 86 values = field_dict[field.name] 87 self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) 88 89 if 'ext_value' in field.name: 90 fd = test_example_pb2.PrimitiveValue() 91 else: 92 fd = field.value.DESCRIPTOR.fields_by_name[field.name] 93 94 # Values has the same shape as the input plus an extra 95 # dimension for repeats. 96 self.assertEqual(list(values.shape)[:-1], batch_shape) 97 98 # Nested messages are represented as TF strings, requiring 99 # some special handling. 100 if field.name == 'message_value' or 'ext_value' in field.name: 101 vs = [] 102 for buf in values.flat: 103 msg = test_example_pb2.PrimitiveValue() 104 msg.ParseFromString(buf) 105 vs.append(msg) 106 if 'ext_value' in field.name: 107 evs = field.value.Extensions[test_example_pb2.ext_value] 108 else: 109 evs = getattr(field.value, field.name) 110 if len(vs) != len(evs): 111 self.fail('Field %s decoded %d outputs, expected %d' % 112 (fd.name, len(vs), len(evs))) 113 for v, ev in zip(vs, evs): 114 self.assertEqual(v, ev) 115 continue 116 117 tf_type_to_primitive_value_field = { 118 dtypes.bool: 119 'bool_value', 120 dtypes.float32: 121 'float_value', 122 dtypes.float64: 123 'double_value', 124 dtypes.int8: 125 'int8_value', 126 dtypes.int32: 127 'int32_value', 128 dtypes.int64: 129 'int64_value', 130 dtypes.string: 131 'string_value', 132 dtypes.uint8: 133 'uint8_value', 134 dtypes.uint32: 135 'uint32_value', 136 dtypes.uint64: 137 'uint64_value', 138 } 139 if field.name in ['enum_value', 'enum_value_with_default']: 140 tf_field_name = 'enum_value' 141 else: 142 tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) 143 if tf_field_name is None: 144 self.fail('Unhandled tensorflow type %d' % field.dtype) 145 146 self._compareValues(fd, values.flat, 147 getattr(field.value, tf_field_name)) 148 149 def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, 150 message_type, message_format, sanitize, 151 force_disordered=False): 152 """Run decode tests on a batch of messages. 153 154 Args: 155 fields: list of test_example_pb2.FieldSpec (types and expected values) 156 case_sizes: expected sizes array 157 batch_shape: the shape of the input tensor of serialized messages 158 batch: list of serialized messages 159 message_type: descriptor name for messages 160 message_format: format of messages, 'text' or 'binary' 161 sanitize: whether to sanitize binary protobuf inputs 162 force_disordered: whether to force fields encoded out of order. 163 """ 164 165 if force_disordered: 166 # Exercise code path that handles out-of-order fields by prepending extra 167 # fields with tag numbers higher than any real field. Note that this won't 168 # work with sanitization because that forces reserialization using a 169 # trusted decoder and encoder. 170 assert not sanitize 171 extra_fields = test_example_pb2.ExtraFields() 172 extra_fields.string_value = 'IGNORE ME' 173 extra_fields.bool_value = False 174 extra_msg = extra_fields.SerializeToString() 175 batch = [extra_msg + msg for msg in batch] 176 177 # Numpy silently truncates the strings if you don't specify dtype=object. 178 batch = np.array(batch, dtype=object) 179 batch = np.reshape(batch, batch_shape) 180 181 field_names = [f.name for f in fields] 182 output_types = [f.dtype for f in fields] 183 184 with self.cached_session() as sess: 185 sizes, vtensor = self._decode_module.decode_proto( 186 batch, 187 message_type=message_type, 188 field_names=field_names, 189 output_types=output_types, 190 message_format=message_format, 191 sanitize=sanitize) 192 193 vlist = sess.run([sizes] + vtensor) 194 sizes = vlist[0] 195 # Values is a list of tensors, one for each field. 196 value_tensors = vlist[1:] 197 198 # Check that the repeat sizes are correct. 199 self.assertTrue( 200 np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) 201 202 # Check that the decoded sizes match the expected sizes. 203 self.assertEqual(len(sizes.flat), len(case_sizes)) 204 self.assertTrue( 205 np.all(sizes.flat == np.array( 206 case_sizes, dtype=np.int32))) 207 208 field_dict = dict(zip(field_names, value_tensors)) 209 210 self._compareProtos(batch_shape, sizes, fields, field_dict) 211 212 @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) 213 def testBinary(self, case): 214 batch = [value.SerializeToString() for value in case.values] 215 self._runDecodeProtoTests( 216 case.fields, 217 case.sizes, 218 list(case.shapes), 219 batch, 220 'tensorflow.contrib.proto.TestValue', 221 'binary', 222 sanitize=False) 223 224 @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) 225 def testBinaryDisordered(self, case): 226 batch = [value.SerializeToString() for value in case.values] 227 self._runDecodeProtoTests( 228 case.fields, 229 case.sizes, 230 list(case.shapes), 231 batch, 232 'tensorflow.contrib.proto.TestValue', 233 'binary', 234 sanitize=False, 235 force_disordered=True) 236 237 @parameterized.named_parameters( 238 *test_base.ProtoOpTestBase.named_parameters(extension=False)) 239 def testPacked(self, case): 240 # Now try with the packed serialization. 241 # 242 # We test the packed representations by loading the same test case using 243 # PackedTestValue instead of TestValue. To do this we rely on the text 244 # format being the same for packed and unpacked fields, and reparse the 245 # test message using the packed version of the proto. 246 packed_batch = [ 247 # Note: float_format='.17g' is necessary to ensure preservation of 248 # doubles and floats in text format. 249 text_format.Parse( 250 text_format.MessageToString(value, float_format='.17g'), 251 test_example_pb2.PackedTestValue()).SerializeToString() 252 for value in case.values 253 ] 254 255 self._runDecodeProtoTests( 256 case.fields, 257 case.sizes, 258 list(case.shapes), 259 packed_batch, 260 'tensorflow.contrib.proto.PackedTestValue', 261 'binary', 262 sanitize=False) 263 264 @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) 265 def testText(self, case): 266 # Note: float_format='.17g' is necessary to ensure preservation of 267 # doubles and floats in text format. 268 text_batch = [ 269 text_format.MessageToString( 270 value, float_format='.17g') for value in case.values 271 ] 272 273 self._runDecodeProtoTests( 274 case.fields, 275 case.sizes, 276 list(case.shapes), 277 text_batch, 278 'tensorflow.contrib.proto.TestValue', 279 'text', 280 sanitize=False) 281 282 @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) 283 def testSanitizerGood(self, case): 284 batch = [value.SerializeToString() for value in case.values] 285 self._runDecodeProtoTests( 286 case.fields, 287 case.sizes, 288 list(case.shapes), 289 batch, 290 'tensorflow.contrib.proto.TestValue', 291 'binary', 292 sanitize=True) 293 294 @parameterized.parameters((False), (True)) 295 def testCorruptProtobuf(self, sanitize): 296 corrupt_proto = 'This is not a binary protobuf' 297 298 # Numpy silently truncates the strings if you don't specify dtype=object. 299 batch = np.array(corrupt_proto, dtype=object) 300 msg_type = 'tensorflow.contrib.proto.TestCase' 301 field_names = ['sizes'] 302 field_types = [dtypes.int32] 303 304 with self.assertRaisesRegexp( 305 errors.DataLossError, 'Unable to parse binary protobuf' 306 '|Failed to consume entire buffer'): 307 self.evaluate( 308 self._decode_module.decode_proto( 309 batch, 310 message_type=msg_type, 311 field_names=field_names, 312 output_types=field_types, 313 sanitize=sanitize)) 314 315 def testOutOfOrderRepeated(self): 316 fragments = [ 317 test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(), 318 test_example_pb2.TestValue( 319 message_value=[test_example_pb2.PrimitiveValue( 320 string_value='abc')]).SerializeToString(), 321 test_example_pb2.TestValue( 322 message_value=[test_example_pb2.PrimitiveValue( 323 string_value='def')]).SerializeToString() 324 ] 325 all_fields_to_parse = ['double_value', 'message_value'] 326 field_types = { 327 'double_value': dtypes.double, 328 'message_value': dtypes.string, 329 } 330 # Test against all 3! permutations of fragments, and for each permutation 331 # test parsing all possible combination of 2 fields. 332 for indices in itertools.permutations(range(len(fragments))): 333 proto = b''.join(fragments[i] for i in indices) 334 for i in indices: 335 if i == 1: 336 expected_message_values = [ 337 test_example_pb2.PrimitiveValue( 338 string_value='abc').SerializeToString(), 339 test_example_pb2.PrimitiveValue( 340 string_value='def').SerializeToString(), 341 ] 342 break 343 if i == 2: 344 expected_message_values = [ 345 test_example_pb2.PrimitiveValue( 346 string_value='def').SerializeToString(), 347 test_example_pb2.PrimitiveValue( 348 string_value='abc').SerializeToString(), 349 ] 350 break 351 352 expected_field_values = { 353 'double_value': [[1.0]], 354 'message_value': [expected_message_values], 355 } 356 357 for num_fields_to_parse in range(len(all_fields_to_parse)): 358 for comb in itertools.combinations( 359 all_fields_to_parse, num_fields_to_parse): 360 parsed_values = self.evaluate( 361 self._decode_module.decode_proto( 362 [proto], 363 message_type='tensorflow.contrib.proto.TestValue', 364 field_names=comb, 365 output_types=[field_types[f] for f in comb], 366 sanitize=False)).values 367 self.assertLen(parsed_values, len(comb)) 368 for field_name, parsed in zip(comb, parsed_values): 369 self.assertAllEqual(parsed, expected_field_values[field_name], 370 'perm: {}, comb: {}'.format(indices, comb)) 371