• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
16"""Test case base for testing proto operations."""
17
18# Python3 preparedness imports.
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import ctypes as ct
24import os
25
26from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
27from tensorflow.core.framework import types_pb2
28from tensorflow.python.platform import test
29
30
class ProtoOpTestBase(test.TestCase):
  """Common base class for tests of the proto decoding and encoding ops.

  Besides the constructor, the class only exposes static builders that return
  pre-populated `test_example_pb2.TestCase` messages. Each builder fills in
  the input `values`, the batch `shapes`, the `sizes` entries and the `fields`
  specs (name, dtype and expected values).
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Initializes the test case and loads `libtestexample.so` if present.

    Args:
      methodName: Forwarded unchanged to the parent constructor.
    """
    super(ProtoOpTestBase, self).__init__(methodName)
    shared_object = os.path.join(
        os.path.dirname(__file__), "libtestexample.so")
    if not os.path.isfile(shared_object):
      # The library is optional: nothing is loaded when the file is absent.
      return
    # NOTE(review): presumably this makes the compiled test protos available
    # to the ops under test -- confirm against the build rules.
    ct.cdll.LoadLibrary(shared_object)

  @staticmethod
  def _add_field(test_case, name, dtype, value_attr, expected, size=None):
    """Appends one field spec to `test_case` and returns it.

    Args:
      test_case: The `test_example_pb2.TestCase` message being built.
      name: Value assigned to the new field's `name`.
      dtype: `types_pb2` enum value assigned to the new field's `dtype`.
      value_attr: Name of the repeated attribute on `field.value` that
        receives `expected`.
      expected: Iterable of values appended to that attribute, in order.
      size: When not None, appended to `test_case.sizes` before the field is
        added; when None, `sizes` is left untouched.

    Returns:
      The newly added field message.
    """
    if size is not None:
      test_case.sizes.append(size)
    spec = test_case.fields.add()
    spec.name = name
    spec.dtype = dtype
    getattr(spec.value, value_attr).extend(expected)
    return spec

  @staticmethod
  def named_parameters(extension=True):
    """Builds the list of `(name, test_case)` pairs for parameterized tests.

    Args:
      extension: When true, the "extension" case is appended after the others.

    Returns:
      A list of `(str, test_example_pb2.TestCase)` tuples.
    """
    builders = [
        ("defaults", ProtoOpTestBase.defaults_test_case),
        ("minmax", ProtoOpTestBase.minmax_test_case),
        ("nested", ProtoOpTestBase.nested_test_case),
        ("optional", ProtoOpTestBase.optional_test_case),
        ("promote", ProtoOpTestBase.promote_test_case),
        ("ragged", ProtoOpTestBase.ragged_test_case),
        ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case),
        ("simple", ProtoOpTestBase.simple_test_case),
    ]
    if extension:
      builders.append(("extension", ProtoOpTestBase.extension_test_case))
    return [(label, build()) for label, build in builders]

  @staticmethod
  def defaults_test_case():
    """Returns a case with one empty input message and 15 `*_with_default` fields.

    Every field is registered with a size of 0 and a single expected value.
    """
    # (field name, dtype, attribute on field.value, expected value)
    expectations = [
        ("double_value_with_default", types_pb2.DT_DOUBLE,
         "double_value", 1.0),
        ("float_value_with_default", types_pb2.DT_FLOAT,
         "float_value", 2.0),
        ("int64_value_with_default", types_pb2.DT_INT64,
         "int64_value", 3),
        ("sfixed64_value_with_default", types_pb2.DT_INT64,
         "int64_value", 11),
        ("sint64_value_with_default", types_pb2.DT_INT64,
         "int64_value", 13),
        ("uint64_value_with_default", types_pb2.DT_UINT64,
         "uint64_value", 4),
        ("fixed64_value_with_default", types_pb2.DT_UINT64,
         "uint64_value", 6),
        ("int32_value_with_default", types_pb2.DT_INT32,
         "int32_value", 5),
        ("sfixed32_value_with_default", types_pb2.DT_INT32,
         "int32_value", 10),
        ("sint32_value_with_default", types_pb2.DT_INT32,
         "int32_value", 12),
        ("uint32_value_with_default", types_pb2.DT_UINT32,
         "uint32_value", 9),
        ("fixed32_value_with_default", types_pb2.DT_UINT32,
         "uint32_value", 7),
        ("bool_value_with_default", types_pb2.DT_BOOL,
         "bool_value", True),
        ("string_value_with_default", types_pb2.DT_STRING,
         "string_value", "a"),
        ("bytes_value_with_default", types_pb2.DT_STRING,
         "string_value", "a longer default string"),
    ]
    test_case = test_example_pb2.TestCase()
    test_case.values.add()  # No fields specified, so we get all defaults.
    test_case.shapes.append(1)
    for name, dtype, value_attr, default in expectations:
      ProtoOpTestBase._add_field(
          test_case, name, dtype, value_attr, [default], size=0)
    return test_case

  @staticmethod
  def minmax_test_case():
    """Returns a case whose single message holds extreme values per field.

    The same samples are written to the input message and to the expected
    field values; each field's size is the number of samples it carries.
    """
    int64_limits = [-9223372036854775808, 9223372036854775807]
    uint64_limits = [0, 18446744073709551615]
    int32_limits = [-2147483648, 2147483647]
    uint32_limits = [0, 4294967295]
    # (field name, dtype, attribute on field.value, samples)
    extremes = [
        ("double_value", types_pb2.DT_DOUBLE, "double_value",
         [-1.7976931348623158e+308, 2.2250738585072014e-308,
          1.7976931348623158e+308]),
        ("float_value", types_pb2.DT_FLOAT, "float_value",
         [-3.402823466e+38, 1.175494351e-38, 3.402823466e+38]),
        ("int64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sfixed64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sint64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("uint64_value", types_pb2.DT_UINT64, "uint64_value", uint64_limits),
        ("fixed64_value", types_pb2.DT_UINT64, "uint64_value", uint64_limits),
        ("int32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sfixed32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sint32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("uint32_value", types_pb2.DT_UINT32, "uint32_value", uint32_limits),
        ("fixed32_value", types_pb2.DT_UINT32, "uint32_value", uint32_limits),
        ("bool_value", types_pb2.DT_BOOL, "bool_value", [False, True]),
        ("string_value", types_pb2.DT_STRING, "string_value",
         ["", "I refer to the infinite."]),
    ]
    test_case = test_example_pb2.TestCase()
    message = test_case.values.add()
    for name, _, _, samples in extremes:
      # On the input message the attribute is named after the field itself.
      getattr(message, name).extend(samples)
    test_case.shapes.append(1)
    for name, dtype, value_attr, samples in extremes:
      ProtoOpTestBase._add_field(
          test_case, name, dtype, value_attr, samples, size=len(samples))
    return test_case

  @staticmethod
  def nested_test_case():
    """Returns a case with a single `message_value` submessage (23.5)."""
    test_case = test_example_pb2.TestCase()
    inner = test_case.values.add().message_value.add()
    inner.double_value = 23.5
    test_case.shapes.append(1)
    # The submessage field is declared with DT_STRING as its dtype.
    spec = ProtoOpTestBase._add_field(
        test_case, "message_value", types_pb2.DT_STRING, "message_value", [],
        size=1)
    expected_inner = spec.value.message_value.add()
    expected_inner.double_value = 23.5
    return test_case

  @staticmethod
  def optional_test_case():
    """Returns a case where `bool_value` is set and `double_value` is not.

    The unset `double_value` field is registered with size 0 and an expected
    value of 0.0.
    """
    test_case = test_example_pb2.TestCase()
    test_case.values.add().bool_value.append(True)
    test_case.shapes.append(1)
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", [True],
        size=1)
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value", [0.0],
        size=0)
    return test_case

  @staticmethod
  def promote_test_case():
    """Returns a case reading 32-bit fields with 64-bit dtypes.

    Signed 32-bit inputs are expected as DT_INT64 and unsigned ones as
    DT_UINT64, each carrying its 32-bit maximum.
    """
    # (field name, requested dtype, attribute on field.value, value)
    promotions = [
        ("sint32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("sfixed32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("int32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("fixed32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
        ("uint32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
    ]
    test_case = test_example_pb2.TestCase()
    message = test_case.values.add()
    for name, _, _, number in promotions:
      getattr(message, name).append(number)
    test_case.shapes.append(1)
    for name, dtype, value_attr, number in promotions:
      ProtoOpTestBase._add_field(
          test_case, name, dtype, value_attr, [number], size=1)
    return test_case

  @staticmethod
  def ragged_test_case():
    """Returns a two-message case with unequal `double_value` counts."""
    test_case = test_example_pb2.TestCase()
    for doubles, flag in (([23.5, 123.0], True), ([3.1], False)):
      message = test_case.values.add()
      message.double_value.extend(doubles)
      message.bool_value.append(flag)
    test_case.shapes.append(2)
    test_case.sizes.extend([2, 1, 1, 1])
    # NOTE(review): the trailing 0.0 has no counterpart in the inputs above;
    # presumably it pads the shorter second message -- confirm.
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
        [23.5, 123.0, 3.1, 0.0])
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
        [True, False])
    return test_case

  @staticmethod
  def shaped_batch_test_case():
    """Returns a six-message case with shapes [3, 2] and twelve sizes of 1."""
    rows = [
        (23.5, True),
        (44.0, False),
        (3.14159, True),
        (1.414, True),
        (-32.2, False),
        (0.0001, True),
    ]
    test_case = test_example_pb2.TestCase()
    for number, flag in rows:
      message = test_case.values.add()
      message.double_value.append(number)
      message.bool_value.append(flag)
    test_case.shapes.extend([3, 2])
    test_case.sizes.extend([1] * 12)
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
        [number for number, _ in rows])
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
        [flag for _, flag in rows])
    return test_case

  @staticmethod
  def extension_test_case():
    """Returns a case with one `ext_value` extension submessage (23.5).

    The field is named after the extension's `full_name`.
    """
    ext = test_example_pb2.ext_value
    test_case = test_example_pb2.TestCase()
    inner = test_case.values.add().Extensions[ext].add()
    inner.double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    spec = test_case.fields.add()
    spec.name = ext.full_name
    spec.dtype = types_pb2.DT_STRING
    expected_inner = spec.value.Extensions[ext].add()
    expected_inner.double_value = 23.5
    return test_case

  @staticmethod
  def simple_test_case():
    """Returns a single-message case with one double and one bool value."""
    test_case = test_example_pb2.TestCase()
    message = test_case.values.add()
    message.double_value.append(23.5)
    message.bool_value.append(True)
    test_case.shapes.append(1)
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
        [23.5], size=1)
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", [True],
        size=1)
    return test_case
436