• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# =============================================================================
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# =============================================================================
16"""Test case base for testing proto operations."""
17
18# Python3 preparedness imports.
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import ctypes as ct
24import os
25
26from tensorflow.core.framework import types_pb2
27from tensorflow.python.kernel_tests.proto import test_example_pb2
28from tensorflow.python.platform import test
29
30
class ProtoOpTestBase(test.TestCase):
  """Base class for testing proto decoding and encoding ops.

  Each static `*_test_case` factory returns a `test_example_pb2.TestCase`
  whose repeated members are filled as follows:

    * `values`: the input `TestValue` messages, one per batch element.
    * `shapes`: the batch shape (in every case below its product equals the
      number of `values`).
    * `sizes`: how many entries each requested field has in each input value,
      ordered value-major, then by field (see `ragged_test_case`).
    * `fields`: for each requested field, its proto name, the output dtype and
      the expected output values.
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(ProtoOpTestBase, self).__init__(methodName)
    # Load the native test library shipped next to this file, if it exists.
    # NOTE(review): presumably this makes the test_example proto descriptors
    # available to the native ops under test -- confirm against the BUILD rule.
    lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
    if os.path.isfile(lib):
      ct.cdll.LoadLibrary(lib)

  @staticmethod
  def named_parameters(extension=True):
    """Return `(name, TestCase)` pairs for parameterized tests.

    Args:
      extension: Whether to include the proto-extension test case.

    Returns:
      A list of `(name, test_example_pb2.TestCase)` tuples.
    """
    parameters = [
        ("defaults", ProtoOpTestBase.defaults_test_case()),
        ("minmax", ProtoOpTestBase.minmax_test_case()),
        ("nested", ProtoOpTestBase.nested_test_case()),
        ("optional", ProtoOpTestBase.optional_test_case()),
        ("promote", ProtoOpTestBase.promote_test_case()),
        ("ragged", ProtoOpTestBase.ragged_test_case()),
        ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
        ("simple", ProtoOpTestBase.simple_test_case()),
    ]
    if extension:
      parameters.append(("extension", ProtoOpTestBase.extension_test_case()))
    return parameters

  @staticmethod
  def _add_field(test_case, name, dtype, value_attr=None, expected=()):
    """Append a field spec to `test_case.fields` and return it.

    `test_case.sizes` is deliberately not touched here: for batched cases the
    sizes are laid out per value, not per field, so callers append them.

    Args:
      test_case: The `test_example_pb2.TestCase` being built.
      name: Proto field name that the spec refers to.
      dtype: `types_pb2` DataType of the expected output.
      value_attr: Name of the repeated attribute of `field.value` that receives
        `expected`. If None, no values are added; the caller fills in
        message-typed expected values itself.
      expected: Iterable of expected output values, appended in order.

    Returns:
      The newly added field spec.
    """
    field = test_case.fields.add()
    field.name = name
    field.dtype = dtype
    if value_attr is not None:
      getattr(field.value, value_attr).extend(expected)
    return field

  @staticmethod
  def defaults_test_case():
    """One empty value, so each `*_with_default` field yields its default.

    Every field is absent from the input (size 0). The expected output is the
    single default value listed for that field.
    """
    test_case = test_example_pb2.TestCase()
    test_case.values.add()  # No fields specified, so we get all defaults.
    test_case.shapes.append(1)
    # (proto field name, output dtype, expected-value attribute, default)
    defaults = (
        ("double_value_with_default", types_pb2.DT_DOUBLE, "double_value",
         1.0),
        ("float_value_with_default", types_pb2.DT_FLOAT, "float_value", 2.0),
        ("int64_value_with_default", types_pb2.DT_INT64, "int64_value", 3),
        ("sfixed64_value_with_default", types_pb2.DT_INT64, "int64_value",
         11),
        ("sint64_value_with_default", types_pb2.DT_INT64, "int64_value", 13),
        ("uint64_value_with_default", types_pb2.DT_UINT64, "uint64_value", 4),
        ("fixed64_value_with_default", types_pb2.DT_UINT64, "uint64_value",
         6),
        ("int32_value_with_default", types_pb2.DT_INT32, "int32_value", 5),
        ("sfixed32_value_with_default", types_pb2.DT_INT32, "int32_value",
         10),
        ("sint32_value_with_default", types_pb2.DT_INT32, "int32_value", 12),
        ("uint32_value_with_default", types_pb2.DT_UINT32, "uint32_value", 9),
        ("fixed32_value_with_default", types_pb2.DT_UINT32, "uint32_value",
         7),
        ("bool_value_with_default", types_pb2.DT_BOOL, "bool_value", True),
        ("string_value_with_default", types_pb2.DT_STRING, "string_value",
         "a"),
        ("bytes_value_with_default", types_pb2.DT_STRING, "string_value",
         "a longer default string"),
        ("enum_value_with_default", types_pb2.DT_INT32, "enum_value",
         test_example_pb2.Color.GREEN),
    )
    for name, dtype, value_attr, default in defaults:
      test_case.sizes.append(0)  # The field is absent from the input value.
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr, [default])
    return test_case

  @staticmethod
  def minmax_test_case():
    """One value holding boundary values of each scalar type.

    The expected output of each field equals its input values.
    """
    doubles = (-1.7976931348623158e+308, 2.2250738585072014e-308,
               1.7976931348623158e+308)
    floats = (-3.402823466e+38, 1.175494351e-38, 3.402823466e+38)
    int64s = (-9223372036854775808, 9223372036854775807)
    uint64s = (0, 18446744073709551615)
    int32s = (-2147483648, 2147483647)
    uint32s = (0, 4294967295)
    bools = (False, True)
    strings = ("", "I refer to the infinite.")
    # (proto field name, output dtype, expected-value attribute, values)
    specs = (
        ("double_value", types_pb2.DT_DOUBLE, "double_value", doubles),
        ("float_value", types_pb2.DT_FLOAT, "float_value", floats),
        ("int64_value", types_pb2.DT_INT64, "int64_value", int64s),
        ("sfixed64_value", types_pb2.DT_INT64, "int64_value", int64s),
        ("sint64_value", types_pb2.DT_INT64, "int64_value", int64s),
        ("uint64_value", types_pb2.DT_UINT64, "uint64_value", uint64s),
        ("fixed64_value", types_pb2.DT_UINT64, "uint64_value", uint64s),
        ("int32_value", types_pb2.DT_INT32, "int32_value", int32s),
        ("sfixed32_value", types_pb2.DT_INT32, "int32_value", int32s),
        ("sint32_value", types_pb2.DT_INT32, "int32_value", int32s),
        ("uint32_value", types_pb2.DT_UINT32, "uint32_value", uint32s),
        ("fixed32_value", types_pb2.DT_UINT32, "uint32_value", uint32s),
        ("bool_value", types_pb2.DT_BOOL, "bool_value", bools),
        ("string_value", types_pb2.DT_STRING, "string_value", strings),
    )
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    for name, _, _, values in specs:
      getattr(value, name).extend(values)
    test_case.shapes.append(1)
    for name, dtype, value_attr, values in specs:
      test_case.sizes.append(len(values))
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr, values)
    return test_case

  @staticmethod
  def nested_test_case():
    """One value containing a single nested message, requested as DT_STRING."""
    test_case = test_example_pb2.TestCase()
    test_case.values.add().message_value.add().double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    field = ProtoOpTestBase._add_field(test_case, "message_value",
                                       types_pb2.DT_STRING)
    field.value.message_value.add().double_value = 23.5
    return test_case

  @staticmethod
  def optional_test_case():
    """`bool_value` is present; `double_value` is absent and expected as 0.0."""
    test_case = test_example_pb2.TestCase()
    test_case.values.add().bool_value.append(True)
    test_case.shapes.append(1)
    test_case.sizes.extend([1, 0])  # bool_value present, double_value absent.
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [True])
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [0.0])
    return test_case

  @staticmethod
  def promote_test_case():
    """32-bit integer fields requested with 64-bit output dtypes."""
    # (proto field name, output dtype, expected-value attribute, value)
    specs = (
        ("sint32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("sfixed32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("int32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("fixed32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
        ("uint32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
    )
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    for name, _, _, scalar in specs:
      getattr(value, name).append(scalar)
    test_case.shapes.append(1)
    for name, dtype, value_attr, scalar in specs:
      test_case.sizes.append(1)
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr, [scalar])
    return test_case

  @staticmethod
  def ragged_test_case():
    """Two values with different numbers of `double_value` entries.

    The expected `double_value` output has 2 slots per value. The second
    value's missing slot is filled with 0.0. `bool_value` has one entry per
    value.
    """
    test_case = test_example_pb2.TestCase()
    first = test_case.values.add()
    first.double_value.extend([23.5, 123.0])
    first.bool_value.append(True)
    second = test_case.values.add()
    second.double_value.append(3.1)
    second.bool_value.append(False)
    test_case.shapes.append(2)
    # Per value, in order: (double_value count, bool_value count).
    test_case.sizes.extend([2, 1, 1, 1])
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [23.5, 123.0, 3.1, 0.0])
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [True, False])
    return test_case

  @staticmethod
  def shaped_batch_test_case():
    """Six single-entry values arranged in a [3, 2] batch."""
    rows = ((23.5, True), (44.0, False), (3.14159, True),
            (1.414, True), (-32.2, False), (0.0001, True))
    test_case = test_example_pb2.TestCase()
    for double_value, bool_value in rows:
      value = test_case.values.add()
      value.double_value.append(double_value)
      value.bool_value.append(bool_value)
    test_case.shapes.extend([3, 2])
    # Every value holds exactly one entry of each of the two fields.
    test_case.sizes.extend([1] * (2 * len(rows)))
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [d for d, _ in rows])
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [b for _, b in rows])
    return test_case

  @staticmethod
  def extension_test_case():
    """One value whose only content is the `ext_value` extension."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.Extensions[test_example_pb2.ext_value].add().double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    # The extension is requested by its fully-qualified name.
    field = ProtoOpTestBase._add_field(
        test_case, test_example_pb2.ext_value.full_name, types_pb2.DT_STRING)
    expected = field.value.Extensions[test_example_pb2.ext_value].add()
    expected.double_value = 23.5
    return test_case

  @staticmethod
  def simple_test_case():
    """One value with a double, a bool and an enum (requested as DT_INT32)."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.double_value.append(23.5)
    value.bool_value.append(True)
    value.enum_value.append(test_example_pb2.Color.INDIGO)
    test_case.shapes.append(1)
    test_case.sizes.extend([1, 1, 1])
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [23.5])
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [True])
    ProtoOpTestBase._add_field(test_case, "enum_value", types_pb2.DT_INT32,
                               "enum_value", [test_example_pb2.Color.INDIGO])
    return test_case
447