1# ============================================================================= 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================= 16"""Test case base for testing proto operations.""" 17 18# Python3 preparedness imports. 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import ctypes as ct 24import os 25 26from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 27from tensorflow.core.framework import types_pb2 28from tensorflow.python.platform import test 29 30 31class ProtoOpTestBase(test.TestCase): 32 """Base class for testing proto decoding and encoding ops.""" 33 34 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 35 super(ProtoOpTestBase, self).__init__(methodName) 36 lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") 37 if os.path.isfile(lib): 38 ct.cdll.LoadLibrary(lib) 39 40 @staticmethod 41 def named_parameters(extension=True): 42 parameters = [("defaults", ProtoOpTestBase.defaults_test_case()), 43 ("minmax", ProtoOpTestBase.minmax_test_case()), 44 ("nested", ProtoOpTestBase.nested_test_case()), 45 ("optional", ProtoOpTestBase.optional_test_case()), 46 ("promote", ProtoOpTestBase.promote_test_case()), 47 ("ragged", ProtoOpTestBase.ragged_test_case()), 48 ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), 49 ("simple", ProtoOpTestBase.simple_test_case())] 50 if extension: 51 parameters.append(("extension", ProtoOpTestBase.extension_test_case())) 52 return parameters 53 54 @staticmethod 55 def defaults_test_case(): 56 test_case = test_example_pb2.TestCase() 57 test_case.values.add() # No fields specified, so we get all defaults. 58 test_case.shapes.append(1) 59 test_case.sizes.append(0) 60 field = test_case.fields.add() 61 field.name = "double_value_with_default" 62 field.dtype = types_pb2.DT_DOUBLE 63 field.value.double_value.append(1.0) 64 test_case.sizes.append(0) 65 field = test_case.fields.add() 66 field.name = "float_value_with_default" 67 field.dtype = types_pb2.DT_FLOAT 68 field.value.float_value.append(2.0) 69 test_case.sizes.append(0) 70 field = test_case.fields.add() 71 field.name = "int64_value_with_default" 72 field.dtype = types_pb2.DT_INT64 73 field.value.int64_value.append(3) 74 test_case.sizes.append(0) 75 field = test_case.fields.add() 76 field.name = "sfixed64_value_with_default" 77 field.dtype = types_pb2.DT_INT64 78 field.value.int64_value.append(11) 79 test_case.sizes.append(0) 80 field = test_case.fields.add() 81 field.name = "sint64_value_with_default" 82 field.dtype = types_pb2.DT_INT64 83 field.value.int64_value.append(13) 84 test_case.sizes.append(0) 85 field = test_case.fields.add() 86 field.name = "uint64_value_with_default" 87 field.dtype = types_pb2.DT_UINT64 88 field.value.uint64_value.append(4) 89 test_case.sizes.append(0) 90 field = test_case.fields.add() 91 field.name = "fixed64_value_with_default" 92 field.dtype = types_pb2.DT_UINT64 93 field.value.uint64_value.append(6) 94 test_case.sizes.append(0) 95 field = test_case.fields.add() 96 field.name = "int32_value_with_default" 97 field.dtype = types_pb2.DT_INT32 98 field.value.int32_value.append(5) 99 test_case.sizes.append(0) 100 field = test_case.fields.add() 101 field.name = "sfixed32_value_with_default" 102 field.dtype = types_pb2.DT_INT32 103 field.value.int32_value.append(10) 104 test_case.sizes.append(0) 105 field = test_case.fields.add() 106 field.name = "sint32_value_with_default" 107 field.dtype = types_pb2.DT_INT32 108 field.value.int32_value.append(12) 109 test_case.sizes.append(0) 110 field = test_case.fields.add() 111 field.name = "uint32_value_with_default" 112 field.dtype = types_pb2.DT_UINT32 113 field.value.uint32_value.append(9) 114 test_case.sizes.append(0) 115 field = test_case.fields.add() 116 field.name = "fixed32_value_with_default" 117 field.dtype = types_pb2.DT_UINT32 118 field.value.uint32_value.append(7) 119 test_case.sizes.append(0) 120 field = test_case.fields.add() 121 field.name = "bool_value_with_default" 122 field.dtype = types_pb2.DT_BOOL 123 field.value.bool_value.append(True) 124 test_case.sizes.append(0) 125 field = test_case.fields.add() 126 field.name = "string_value_with_default" 127 field.dtype = types_pb2.DT_STRING 128 field.value.string_value.append("a") 129 test_case.sizes.append(0) 130 field = test_case.fields.add() 131 field.name = "bytes_value_with_default" 132 field.dtype = types_pb2.DT_STRING 133 field.value.string_value.append("a longer default string") 134 return test_case 135 136 @staticmethod 137 def minmax_test_case(): 138 test_case = test_example_pb2.TestCase() 139 value = test_case.values.add() 140 value.double_value.append(-1.7976931348623158e+308) 141 value.double_value.append(2.2250738585072014e-308) 142 value.double_value.append(1.7976931348623158e+308) 143 value.float_value.append(-3.402823466e+38) 144 value.float_value.append(1.175494351e-38) 145 value.float_value.append(3.402823466e+38) 146 value.int64_value.append(-9223372036854775808) 147 value.int64_value.append(9223372036854775807) 148 value.sfixed64_value.append(-9223372036854775808) 149 value.sfixed64_value.append(9223372036854775807) 150 value.sint64_value.append(-9223372036854775808) 151 value.sint64_value.append(9223372036854775807) 152 value.uint64_value.append(0) 153 value.uint64_value.append(18446744073709551615) 154 value.fixed64_value.append(0) 155 value.fixed64_value.append(18446744073709551615) 156 value.int32_value.append(-2147483648) 157 value.int32_value.append(2147483647) 158 value.sfixed32_value.append(-2147483648) 159 value.sfixed32_value.append(2147483647) 160 value.sint32_value.append(-2147483648) 161 value.sint32_value.append(2147483647) 162 value.uint32_value.append(0) 163 value.uint32_value.append(4294967295) 164 value.fixed32_value.append(0) 165 value.fixed32_value.append(4294967295) 166 value.bool_value.append(False) 167 value.bool_value.append(True) 168 value.string_value.append("") 169 value.string_value.append("I refer to the infinite.") 170 test_case.shapes.append(1) 171 test_case.sizes.append(3) 172 field = test_case.fields.add() 173 field.name = "double_value" 174 field.dtype = types_pb2.DT_DOUBLE 175 field.value.double_value.append(-1.7976931348623158e+308) 176 field.value.double_value.append(2.2250738585072014e-308) 177 field.value.double_value.append(1.7976931348623158e+308) 178 test_case.sizes.append(3) 179 field = test_case.fields.add() 180 field.name = "float_value" 181 field.dtype = types_pb2.DT_FLOAT 182 field.value.float_value.append(-3.402823466e+38) 183 field.value.float_value.append(1.175494351e-38) 184 field.value.float_value.append(3.402823466e+38) 185 test_case.sizes.append(2) 186 field = test_case.fields.add() 187 field.name = "int64_value" 188 field.dtype = types_pb2.DT_INT64 189 field.value.int64_value.append(-9223372036854775808) 190 field.value.int64_value.append(9223372036854775807) 191 test_case.sizes.append(2) 192 field = test_case.fields.add() 193 field.name = "sfixed64_value" 194 field.dtype = types_pb2.DT_INT64 195 field.value.int64_value.append(-9223372036854775808) 196 field.value.int64_value.append(9223372036854775807) 197 test_case.sizes.append(2) 198 field = test_case.fields.add() 199 field.name = "sint64_value" 200 field.dtype = types_pb2.DT_INT64 201 field.value.int64_value.append(-9223372036854775808) 202 field.value.int64_value.append(9223372036854775807) 203 test_case.sizes.append(2) 204 field = test_case.fields.add() 205 field.name = "uint64_value" 206 field.dtype = types_pb2.DT_UINT64 207 field.value.uint64_value.append(0) 208 field.value.uint64_value.append(18446744073709551615) 209 test_case.sizes.append(2) 210 field = test_case.fields.add() 211 field.name = "fixed64_value" 212 field.dtype = types_pb2.DT_UINT64 213 field.value.uint64_value.append(0) 214 field.value.uint64_value.append(18446744073709551615) 215 test_case.sizes.append(2) 216 field = test_case.fields.add() 217 field.name = "int32_value" 218 field.dtype = types_pb2.DT_INT32 219 field.value.int32_value.append(-2147483648) 220 field.value.int32_value.append(2147483647) 221 test_case.sizes.append(2) 222 field = test_case.fields.add() 223 field.name = "sfixed32_value" 224 field.dtype = types_pb2.DT_INT32 225 field.value.int32_value.append(-2147483648) 226 field.value.int32_value.append(2147483647) 227 test_case.sizes.append(2) 228 field = test_case.fields.add() 229 field.name = "sint32_value" 230 field.dtype = types_pb2.DT_INT32 231 field.value.int32_value.append(-2147483648) 232 field.value.int32_value.append(2147483647) 233 test_case.sizes.append(2) 234 field = test_case.fields.add() 235 field.name = "uint32_value" 236 field.dtype = types_pb2.DT_UINT32 237 field.value.uint32_value.append(0) 238 field.value.uint32_value.append(4294967295) 239 test_case.sizes.append(2) 240 field = test_case.fields.add() 241 field.name = "fixed32_value" 242 field.dtype = types_pb2.DT_UINT32 243 field.value.uint32_value.append(0) 244 field.value.uint32_value.append(4294967295) 245 test_case.sizes.append(2) 246 field = test_case.fields.add() 247 field.name = "bool_value" 248 field.dtype = types_pb2.DT_BOOL 249 field.value.bool_value.append(False) 250 field.value.bool_value.append(True) 251 test_case.sizes.append(2) 252 field = test_case.fields.add() 253 field.name = "string_value" 254 field.dtype = types_pb2.DT_STRING 255 field.value.string_value.append("") 256 field.value.string_value.append("I refer to the infinite.") 257 return test_case 258 259 @staticmethod 260 def nested_test_case(): 261 test_case = test_example_pb2.TestCase() 262 value = test_case.values.add() 263 message_value = value.message_value.add() 264 message_value.double_value = 23.5 265 test_case.shapes.append(1) 266 test_case.sizes.append(1) 267 field = test_case.fields.add() 268 field.name = "message_value" 269 field.dtype = types_pb2.DT_STRING 270 message_value = field.value.message_value.add() 271 message_value.double_value = 23.5 272 return test_case 273 274 @staticmethod 275 def optional_test_case(): 276 test_case = test_example_pb2.TestCase() 277 value = test_case.values.add() 278 value.bool_value.append(True) 279 test_case.shapes.append(1) 280 test_case.sizes.append(1) 281 field = test_case.fields.add() 282 field.name = "bool_value" 283 field.dtype = types_pb2.DT_BOOL 284 field.value.bool_value.append(True) 285 test_case.sizes.append(0) 286 field = test_case.fields.add() 287 field.name = "double_value" 288 field.dtype = types_pb2.DT_DOUBLE 289 field.value.double_value.append(0.0) 290 return test_case 291 292 @staticmethod 293 def promote_test_case(): 294 test_case = test_example_pb2.TestCase() 295 value = test_case.values.add() 296 value.sint32_value.append(2147483647) 297 value.sfixed32_value.append(2147483647) 298 value.int32_value.append(2147483647) 299 value.fixed32_value.append(4294967295) 300 value.uint32_value.append(4294967295) 301 test_case.shapes.append(1) 302 test_case.sizes.append(1) 303 field = test_case.fields.add() 304 field.name = "sint32_value" 305 field.dtype = types_pb2.DT_INT64 306 field.value.int64_value.append(2147483647) 307 test_case.sizes.append(1) 308 field = test_case.fields.add() 309 field.name = "sfixed32_value" 310 field.dtype = types_pb2.DT_INT64 311 field.value.int64_value.append(2147483647) 312 test_case.sizes.append(1) 313 field = test_case.fields.add() 314 field.name = "int32_value" 315 field.dtype = types_pb2.DT_INT64 316 field.value.int64_value.append(2147483647) 317 test_case.sizes.append(1) 318 field = test_case.fields.add() 319 field.name = "fixed32_value" 320 field.dtype = types_pb2.DT_UINT64 321 field.value.uint64_value.append(4294967295) 322 test_case.sizes.append(1) 323 field = test_case.fields.add() 324 field.name = "uint32_value" 325 field.dtype = types_pb2.DT_UINT64 326 field.value.uint64_value.append(4294967295) 327 return test_case 328 329 @staticmethod 330 def ragged_test_case(): 331 test_case = test_example_pb2.TestCase() 332 value = test_case.values.add() 333 value.double_value.append(23.5) 334 value.double_value.append(123.0) 335 value.bool_value.append(True) 336 value = test_case.values.add() 337 value.double_value.append(3.1) 338 value.bool_value.append(False) 339 test_case.shapes.append(2) 340 test_case.sizes.append(2) 341 test_case.sizes.append(1) 342 test_case.sizes.append(1) 343 test_case.sizes.append(1) 344 field = test_case.fields.add() 345 field.name = "double_value" 346 field.dtype = types_pb2.DT_DOUBLE 347 field.value.double_value.append(23.5) 348 field.value.double_value.append(123.0) 349 field.value.double_value.append(3.1) 350 field.value.double_value.append(0.0) 351 field = test_case.fields.add() 352 field.name = "bool_value" 353 field.dtype = types_pb2.DT_BOOL 354 field.value.bool_value.append(True) 355 field.value.bool_value.append(False) 356 return test_case 357 358 @staticmethod 359 def shaped_batch_test_case(): 360 test_case = test_example_pb2.TestCase() 361 value = test_case.values.add() 362 value.double_value.append(23.5) 363 value.bool_value.append(True) 364 value = test_case.values.add() 365 value.double_value.append(44.0) 366 value.bool_value.append(False) 367 value = test_case.values.add() 368 value.double_value.append(3.14159) 369 value.bool_value.append(True) 370 value = test_case.values.add() 371 value.double_value.append(1.414) 372 value.bool_value.append(True) 373 value = test_case.values.add() 374 value.double_value.append(-32.2) 375 value.bool_value.append(False) 376 value = test_case.values.add() 377 value.double_value.append(0.0001) 378 value.bool_value.append(True) 379 test_case.shapes.append(3) 380 test_case.shapes.append(2) 381 for _ in range(12): 382 test_case.sizes.append(1) 383 field = test_case.fields.add() 384 field.name = "double_value" 385 field.dtype = types_pb2.DT_DOUBLE 386 field.value.double_value.append(23.5) 387 field.value.double_value.append(44.0) 388 field.value.double_value.append(3.14159) 389 field.value.double_value.append(1.414) 390 field.value.double_value.append(-32.2) 391 field.value.double_value.append(0.0001) 392 field = test_case.fields.add() 393 field.name = "bool_value" 394 field.dtype = types_pb2.DT_BOOL 395 field.value.bool_value.append(True) 396 field.value.bool_value.append(False) 397 field.value.bool_value.append(True) 398 field.value.bool_value.append(True) 399 field.value.bool_value.append(False) 400 field.value.bool_value.append(True) 401 return test_case 402 403 @staticmethod 404 def extension_test_case(): 405 test_case = test_example_pb2.TestCase() 406 value = test_case.values.add() 407 message_value = value.Extensions[test_example_pb2.ext_value].add() 408 message_value.double_value = 23.5 409 test_case.shapes.append(1) 410 test_case.sizes.append(1) 411 field = test_case.fields.add() 412 field.name = test_example_pb2.ext_value.full_name 413 field.dtype = types_pb2.DT_STRING 414 message_value = field.value.Extensions[test_example_pb2.ext_value].add() 415 message_value.double_value = 23.5 416 return test_case 417 418 @staticmethod 419 def simple_test_case(): 420 test_case = test_example_pb2.TestCase() 421 value = test_case.values.add() 422 value.double_value.append(23.5) 423 value.bool_value.append(True) 424 test_case.shapes.append(1) 425 test_case.sizes.append(1) 426 field = test_case.fields.add() 427 field.name = "double_value" 428 field.dtype = types_pb2.DT_DOUBLE 429 field.value.double_value.append(23.5) 430 test_case.sizes.append(1) 431 field = test_case.fields.add() 432 field.name = "bool_value" 433 field.dtype = types_pb2.DT_BOOL 434 field.value.bool_value.append(True) 435 return test_case 436