• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# =============================================================================
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# =============================================================================
16"""Tests for decode_proto op."""
17
18# Python3 preparedness imports.
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import itertools
24
25from absl.testing import parameterized
26import numpy as np
27
28
29from google.protobuf import text_format
30
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors
33from tensorflow.python.kernel_tests.proto import proto_op_test_base as test_base
34from tensorflow.python.kernel_tests.proto import test_example_pb2
35
36
class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
  """Base class for testing proto decoding ops.

  Subclasses supply the module under test; every test here calls its
  `decode_proto` function through `self._decode_module`.
  """

  # Maps an output dtype to the TestValue field that holds the expected values
  # for a field decoded as that dtype (used by `_compareProtos`).
  _TF_TYPE_TO_PRIMITIVE_VALUE_FIELD = {
      dtypes.bool: 'bool_value',
      dtypes.float32: 'float_value',
      dtypes.float64: 'double_value',
      dtypes.int8: 'int8_value',
      dtypes.int32: 'int32_value',
      dtypes.int64: 'int64_value',
      dtypes.string: 'string_value',
      dtypes.uint8: 'uint8_value',
      dtypes.uint32: 'uint32_value',
      dtypes.uint64: 'uint64_value',
  }

  # Fields whose expected values are always read from TestValue.enum_value,
  # regardless of their output dtype.
  _ENUM_FIELD_NAMES = ('enum_value', 'enum_value_with_default')

  def __init__(self, decode_module, methodName='runTest'):  # pylint: disable=invalid-name
    """DecodeProtoOpTestBase initializer.

    Args:
      decode_module: a module providing the `decode_proto` function under test
      methodName: the name of the test method (same as for test.TestCase)
    """

    super(DecodeProtoOpTestBase, self).__init__(methodName)
    self._decode_module = decode_module

  def _compareValues(self, fd, vs, evs):
    """Compare lists/arrays of field values.

    Args:
      fd: FieldDescriptor of the compared field; its `cpp_type` selects the
        comparison rule.
      vs: decoded values (a flat, indexable sequence from the output tensor).
      evs: expected values taken from the source proto.
    """

    if len(vs) != len(evs):
      self.fail('Field %s decoded %d outputs, expected %d' %
                (fd.name, len(vs), len(evs)))
    for v, ev in zip(vs, evs):
      if fd.cpp_type == fd.CPPTYPE_FLOAT:
        # Special case fuzzy match for float32. TensorFlow seems to mess with
        # MAX_FLT slightly and the test doesn't work otherwise.
        # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
        # Numpy isclose() is better than assertIsClose() which uses an absolute
        # value comparison.
        self.assertTrue(
            np.isclose(v, ev), 'expected %r, actual %r' % (ev, v))
      elif fd.cpp_type == fd.CPPTYPE_STRING:
        # String tensor values are bytes. Proto `string` values are text, so
        # encode them as UTF-8 (the protobuf wire encoding for strings);
        # `bytes` field values are already bytes and are compared directly.
        if not isinstance(ev, bytes):
          ev = ev.encode('utf-8')
        self.assertEqual(v, ev)
      else:
        # Doubles and other types pass through unscathed.
        self.assertEqual(v, ev)

  def _compareProtos(self, batch_shape, sizes, fields, field_dict):
    """Compare protos of type TestValue.

    Args:
      batch_shape: the shape of the input tensor of serialized messages.
      sizes: int matrix of repeat counts returned by decode_proto. Unused
        here; `_runDecodeProtoTests` checks the repeat counts itself.
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      field_dict: map from field names to decoded numpy tensors of values
    """

    batch_shape = list(batch_shape)
    for field in fields:
      values = field_dict[field.name]
      self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)

      # Values has the same shape as the input plus an extra
      # dimension for repeats.
      self.assertEqual(list(values.shape)[:-1], batch_shape)

      # Nested messages (including the extension) are represented as TF
      # strings holding serialized PrimitiveValue protos.
      is_extension = 'ext_value' in field.name
      if is_extension or field.name == 'message_value':
        vs = []
        for buf in values.flat:
          msg = test_example_pb2.PrimitiveValue()
          msg.ParseFromString(buf)
          vs.append(msg)
        if is_extension:
          evs = field.value.Extensions[test_example_pb2.ext_value]
        else:
          evs = getattr(field.value, field.name)
        if len(vs) != len(evs):
          # Use field.name: for extensions there is no plain field descriptor
          # to take a name from.
          self.fail('Field %s decoded %d outputs, expected %d' %
                    (field.name, len(vs), len(evs)))
        for v, ev in zip(vs, evs):
          self.assertEqual(v, ev)
        continue

      fd = field.value.DESCRIPTOR.fields_by_name[field.name]
      if field.name in self._ENUM_FIELD_NAMES:
        tf_field_name = 'enum_value'
      else:
        tf_field_name = self._TF_TYPE_TO_PRIMITIVE_VALUE_FIELD.get(field.dtype)
      if tf_field_name is None:
        self.fail('Unhandled tensorflow type %s' % (field.dtype,))

      # Expected values come from the TestValue field matching the output
      # dtype, not from the field named `field.name`.
      self._compareValues(fd, values.flat,
                          getattr(field.value, tf_field_name))

  def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                           message_type, message_format, sanitize,
                           force_disordered=False):
    """Run decode tests on a batch of messages.

    Args:
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      case_sizes: expected sizes array
      batch_shape: the shape of the input tensor of serialized messages
        (list or tuple)
      batch: list of serialized messages
      message_type: descriptor name for messages
      message_format: format of messages, 'text' or 'binary'
      sanitize: whether to sanitize binary protobuf inputs
      force_disordered: whether to force fields encoded out of order.
    """

    batch_shape = list(batch_shape)

    if force_disordered:
      # Exercise code path that handles out-of-order fields by prepending extra
      # fields with tag numbers higher than any real field. Note that this won't
      # work with sanitization because that forces reserialization using a
      # trusted decoder and encoder.
      self.assertFalse(sanitize,
                       'force_disordered cannot be combined with sanitize')
      extra_fields = test_example_pb2.ExtraFields()
      extra_fields.string_value = 'IGNORE ME'
      extra_fields.bool_value = False
      extra_msg = extra_fields.SerializeToString()
      batch = [extra_msg + msg for msg in batch]

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.reshape(np.array(batch, dtype=object), batch_shape)

    field_names = [f.name for f in fields]
    output_types = [f.dtype for f in fields]

    with self.cached_session() as sess:
      sizes, vtensor = self._decode_module.decode_proto(
          batch,
          message_type=message_type,
          field_names=field_names,
          output_types=output_types,
          message_format=message_format,
          sanitize=sanitize)

      vlist = sess.run([sizes] + vtensor)
      sizes = vlist[0]
      # Values is a list of tensors, one for each field.
      value_tensors = vlist[1:]

      # Check that the repeat sizes have shape batch_shape + [num_fields].
      self.assertEqual(list(sizes.shape), batch_shape + [len(field_names)])

      # Check that the decoded sizes match the expected sizes.
      self.assertEqual(len(sizes.flat), len(case_sizes))
      self.assertAllEqual(sizes.flatten(),
                          np.array(case_sizes, dtype=np.int32))

      field_dict = dict(zip(field_names, value_tensors))

      self._compareProtos(batch_shape, sizes, fields, field_dict)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinary(self, case):
    """Decodes binary-serialized TestValue messages."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinaryDisordered(self, case):
    """Decodes binary messages with unknown high-tag fields prepended."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False,
        force_disordered=True)

  @parameterized.named_parameters(
      *test_base.ProtoOpTestBase.named_parameters(extension=False))
  def testPacked(self, case):
    """Decodes the same cases reserialized as PackedTestValue."""
    # Now try with the packed serialization.
    #
    # We test the packed representations by loading the same test case using
    # PackedTestValue instead of TestValue. To do this we rely on the text
    # format being the same for packed and unpacked fields, and reparse the
    # test message using the packed version of the proto.
    packed_batch = [
        # Note: float_format='.17g' is necessary to ensure preservation of
        # doubles and floats in text format.
        text_format.Parse(
            text_format.MessageToString(value, float_format='.17g'),
            test_example_pb2.PackedTestValue()).SerializeToString()
        for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        packed_batch,
        'tensorflow.contrib.proto.PackedTestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testText(self, case):
    """Decodes text-format TestValue messages."""
    # Note: float_format='.17g' is necessary to ensure preservation of
    # doubles and floats in text format.
    text_batch = [
        text_format.MessageToString(
            value, float_format='.17g') for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        text_batch,
        'tensorflow.contrib.proto.TestValue',
        'text',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testSanitizerGood(self, case):
    """Decodes well-formed binary messages with sanitization enabled."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=True)

  # Explicit one-element tuples: `(False)` is just `False`, not a tuple.
  @parameterized.parameters((False,), (True,))
  def testCorruptProtobuf(self, sanitize):
    """Checks that a non-protobuf input raises DataLossError."""
    corrupt_proto = 'This is not a binary protobuf'

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(corrupt_proto, dtype=object)
    msg_type = 'tensorflow.contrib.proto.TestCase'
    field_names = ['sizes']
    field_types = [dtypes.int32]

    with self.assertRaisesRegexp(
        errors.DataLossError, 'Unable to parse binary protobuf'
        '|Failed to consume entire buffer'):
      self.evaluate(
          self._decode_module.decode_proto(
              batch,
              message_type=msg_type,
              field_names=field_names,
              output_types=field_types,
              sanitize=sanitize))

  def testOutOfOrderRepeated(self):
    """Checks decoding when repeated fields are split and interleaved."""
    abc_msg = test_example_pb2.PrimitiveValue(string_value='abc')
    def_msg = test_example_pb2.PrimitiveValue(string_value='def')
    fragments = [
        test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(),
        test_example_pb2.TestValue(message_value=[abc_msg]).SerializeToString(),
        test_example_pb2.TestValue(message_value=[def_msg]).SerializeToString()
    ]
    # Serialized nested message carried by each fragment (None for the
    # double_value fragment), indexed like `fragments`.
    fragment_messages = [
        None,
        abc_msg.SerializeToString(),
        def_msg.SerializeToString(),
    ]
    all_fields_to_parse = ['double_value', 'message_value']
    field_types = {
        'double_value': dtypes.double,
        'message_value': dtypes.string,
    }
    # Test against all 3! permutations of fragments, and for each permutation
    # test parsing every subset of the 2 fields (none, each one, and both).
    for indices in itertools.permutations(range(len(fragments))):
      proto = b''.join(fragments[i] for i in indices)
      # Repeated message values are expected in the order their fragments
      # appear in this permutation.
      expected_message_values = [
          fragment_messages[i]
          for i in indices
          if fragment_messages[i] is not None
      ]

      expected_field_values = {
          'double_value': [[1.0]],
          'message_value': [expected_message_values],
      }

      # `+ 1` so the full set of fields is also decoded together.
      for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
        for comb in itertools.combinations(
            all_fields_to_parse, num_fields_to_parse):
          parsed_values = self.evaluate(
              self._decode_module.decode_proto(
                  [proto],
                  message_type='tensorflow.contrib.proto.TestValue',
                  field_names=comb,
                  output_types=[field_types[f] for f in comb],
                  sanitize=False)).values
          self.assertLen(parsed_values, len(comb))
          for field_name, parsed in zip(comb, parsed_values):
            self.assertAllEqual(parsed, expected_field_values[field_name],
                                'perm: {}, comb: {}'.format(indices, comb))
371