1# ============================================================================= 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================= 16"""Test case base for testing proto operations.""" 17 18# Python3 preparedness imports. 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import ctypes as ct 24import os 25 26from tensorflow.core.framework import types_pb2 27from tensorflow.python.kernel_tests.proto import test_example_pb2 28from tensorflow.python.platform import test 29 30 31class ProtoOpTestBase(test.TestCase): 32 """Base class for testing proto decoding and encoding ops.""" 33 34 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 35 super(ProtoOpTestBase, self).__init__(methodName) 36 lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") 37 if os.path.isfile(lib): 38 ct.cdll.LoadLibrary(lib) 39 40 @staticmethod 41 def named_parameters(extension=True): 42 parameters = [("defaults", ProtoOpTestBase.defaults_test_case()), 43 ("minmax", ProtoOpTestBase.minmax_test_case()), 44 ("nested", ProtoOpTestBase.nested_test_case()), 45 ("optional", ProtoOpTestBase.optional_test_case()), 46 ("promote", ProtoOpTestBase.promote_test_case()), 47 ("ragged", ProtoOpTestBase.ragged_test_case()), 48 ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), 49 ("simple", ProtoOpTestBase.simple_test_case())] 50 if extension: 51 parameters.append(("extension", ProtoOpTestBase.extension_test_case())) 52 return parameters 53 54 @staticmethod 55 def defaults_test_case(): 56 test_case = test_example_pb2.TestCase() 57 test_case.values.add() # No fields specified, so we get all defaults. 58 test_case.shapes.append(1) 59 test_case.sizes.append(0) 60 field = test_case.fields.add() 61 field.name = "double_value_with_default" 62 field.dtype = types_pb2.DT_DOUBLE 63 field.value.double_value.append(1.0) 64 test_case.sizes.append(0) 65 field = test_case.fields.add() 66 field.name = "float_value_with_default" 67 field.dtype = types_pb2.DT_FLOAT 68 field.value.float_value.append(2.0) 69 test_case.sizes.append(0) 70 field = test_case.fields.add() 71 field.name = "int64_value_with_default" 72 field.dtype = types_pb2.DT_INT64 73 field.value.int64_value.append(3) 74 test_case.sizes.append(0) 75 field = test_case.fields.add() 76 field.name = "sfixed64_value_with_default" 77 field.dtype = types_pb2.DT_INT64 78 field.value.int64_value.append(11) 79 test_case.sizes.append(0) 80 field = test_case.fields.add() 81 field.name = "sint64_value_with_default" 82 field.dtype = types_pb2.DT_INT64 83 field.value.int64_value.append(13) 84 test_case.sizes.append(0) 85 field = test_case.fields.add() 86 field.name = "uint64_value_with_default" 87 field.dtype = types_pb2.DT_UINT64 88 field.value.uint64_value.append(4) 89 test_case.sizes.append(0) 90 field = test_case.fields.add() 91 field.name = "fixed64_value_with_default" 92 field.dtype = types_pb2.DT_UINT64 93 field.value.uint64_value.append(6) 94 test_case.sizes.append(0) 95 field = test_case.fields.add() 96 field.name = "int32_value_with_default" 97 field.dtype = types_pb2.DT_INT32 98 field.value.int32_value.append(5) 99 test_case.sizes.append(0) 100 field = test_case.fields.add() 101 field.name = "sfixed32_value_with_default" 102 field.dtype = types_pb2.DT_INT32 103 field.value.int32_value.append(10) 104 test_case.sizes.append(0) 105 field = test_case.fields.add() 106 field.name = "sint32_value_with_default" 107 field.dtype = types_pb2.DT_INT32 108 field.value.int32_value.append(12) 109 test_case.sizes.append(0) 110 field = test_case.fields.add() 111 field.name = "uint32_value_with_default" 112 field.dtype = types_pb2.DT_UINT32 113 field.value.uint32_value.append(9) 114 test_case.sizes.append(0) 115 field = test_case.fields.add() 116 field.name = "fixed32_value_with_default" 117 field.dtype = types_pb2.DT_UINT32 118 field.value.uint32_value.append(7) 119 test_case.sizes.append(0) 120 field = test_case.fields.add() 121 field.name = "bool_value_with_default" 122 field.dtype = types_pb2.DT_BOOL 123 field.value.bool_value.append(True) 124 test_case.sizes.append(0) 125 field = test_case.fields.add() 126 field.name = "string_value_with_default" 127 field.dtype = types_pb2.DT_STRING 128 field.value.string_value.append("a") 129 test_case.sizes.append(0) 130 field = test_case.fields.add() 131 field.name = "bytes_value_with_default" 132 field.dtype = types_pb2.DT_STRING 133 field.value.string_value.append("a longer default string") 134 test_case.sizes.append(0) 135 field = test_case.fields.add() 136 field.name = "enum_value_with_default" 137 field.dtype = types_pb2.DT_INT32 138 field.value.enum_value.append(test_example_pb2.Color.GREEN) 139 return test_case 140 141 @staticmethod 142 def minmax_test_case(): 143 test_case = test_example_pb2.TestCase() 144 value = test_case.values.add() 145 value.double_value.append(-1.7976931348623158e+308) 146 value.double_value.append(2.2250738585072014e-308) 147 value.double_value.append(1.7976931348623158e+308) 148 value.float_value.append(-3.402823466e+38) 149 value.float_value.append(1.175494351e-38) 150 value.float_value.append(3.402823466e+38) 151 value.int64_value.append(-9223372036854775808) 152 value.int64_value.append(9223372036854775807) 153 value.sfixed64_value.append(-9223372036854775808) 154 value.sfixed64_value.append(9223372036854775807) 155 value.sint64_value.append(-9223372036854775808) 156 value.sint64_value.append(9223372036854775807) 157 value.uint64_value.append(0) 158 value.uint64_value.append(18446744073709551615) 159 value.fixed64_value.append(0) 160 value.fixed64_value.append(18446744073709551615) 161 value.int32_value.append(-2147483648) 162 value.int32_value.append(2147483647) 163 value.sfixed32_value.append(-2147483648) 164 value.sfixed32_value.append(2147483647) 165 value.sint32_value.append(-2147483648) 166 value.sint32_value.append(2147483647) 167 value.uint32_value.append(0) 168 value.uint32_value.append(4294967295) 169 value.fixed32_value.append(0) 170 value.fixed32_value.append(4294967295) 171 value.bool_value.append(False) 172 value.bool_value.append(True) 173 value.string_value.append("") 174 value.string_value.append("I refer to the infinite.") 175 test_case.shapes.append(1) 176 test_case.sizes.append(3) 177 field = test_case.fields.add() 178 field.name = "double_value" 179 field.dtype = types_pb2.DT_DOUBLE 180 field.value.double_value.append(-1.7976931348623158e+308) 181 field.value.double_value.append(2.2250738585072014e-308) 182 field.value.double_value.append(1.7976931348623158e+308) 183 test_case.sizes.append(3) 184 field = test_case.fields.add() 185 field.name = "float_value" 186 field.dtype = types_pb2.DT_FLOAT 187 field.value.float_value.append(-3.402823466e+38) 188 field.value.float_value.append(1.175494351e-38) 189 field.value.float_value.append(3.402823466e+38) 190 test_case.sizes.append(2) 191 field = test_case.fields.add() 192 field.name = "int64_value" 193 field.dtype = types_pb2.DT_INT64 194 field.value.int64_value.append(-9223372036854775808) 195 field.value.int64_value.append(9223372036854775807) 196 test_case.sizes.append(2) 197 field = test_case.fields.add() 198 field.name = "sfixed64_value" 199 field.dtype = types_pb2.DT_INT64 200 field.value.int64_value.append(-9223372036854775808) 201 field.value.int64_value.append(9223372036854775807) 202 test_case.sizes.append(2) 203 field = test_case.fields.add() 204 field.name = "sint64_value" 205 field.dtype = types_pb2.DT_INT64 206 field.value.int64_value.append(-9223372036854775808) 207 field.value.int64_value.append(9223372036854775807) 208 test_case.sizes.append(2) 209 field = test_case.fields.add() 210 field.name = "uint64_value" 211 field.dtype = types_pb2.DT_UINT64 212 field.value.uint64_value.append(0) 213 field.value.uint64_value.append(18446744073709551615) 214 test_case.sizes.append(2) 215 field = test_case.fields.add() 216 field.name = "fixed64_value" 217 field.dtype = types_pb2.DT_UINT64 218 field.value.uint64_value.append(0) 219 field.value.uint64_value.append(18446744073709551615) 220 test_case.sizes.append(2) 221 field = test_case.fields.add() 222 field.name = "int32_value" 223 field.dtype = types_pb2.DT_INT32 224 field.value.int32_value.append(-2147483648) 225 field.value.int32_value.append(2147483647) 226 test_case.sizes.append(2) 227 field = test_case.fields.add() 228 field.name = "sfixed32_value" 229 field.dtype = types_pb2.DT_INT32 230 field.value.int32_value.append(-2147483648) 231 field.value.int32_value.append(2147483647) 232 test_case.sizes.append(2) 233 field = test_case.fields.add() 234 field.name = "sint32_value" 235 field.dtype = types_pb2.DT_INT32 236 field.value.int32_value.append(-2147483648) 237 field.value.int32_value.append(2147483647) 238 test_case.sizes.append(2) 239 field = test_case.fields.add() 240 field.name = "uint32_value" 241 field.dtype = types_pb2.DT_UINT32 242 field.value.uint32_value.append(0) 243 field.value.uint32_value.append(4294967295) 244 test_case.sizes.append(2) 245 field = test_case.fields.add() 246 field.name = "fixed32_value" 247 field.dtype = types_pb2.DT_UINT32 248 field.value.uint32_value.append(0) 249 field.value.uint32_value.append(4294967295) 250 test_case.sizes.append(2) 251 field = test_case.fields.add() 252 field.name = "bool_value" 253 field.dtype = types_pb2.DT_BOOL 254 field.value.bool_value.append(False) 255 field.value.bool_value.append(True) 256 test_case.sizes.append(2) 257 field = test_case.fields.add() 258 field.name = "string_value" 259 field.dtype = types_pb2.DT_STRING 260 field.value.string_value.append("") 261 field.value.string_value.append("I refer to the infinite.") 262 return test_case 263 264 @staticmethod 265 def nested_test_case(): 266 test_case = test_example_pb2.TestCase() 267 value = test_case.values.add() 268 message_value = value.message_value.add() 269 message_value.double_value = 23.5 270 test_case.shapes.append(1) 271 test_case.sizes.append(1) 272 field = test_case.fields.add() 273 field.name = "message_value" 274 field.dtype = types_pb2.DT_STRING 275 message_value = field.value.message_value.add() 276 message_value.double_value = 23.5 277 return test_case 278 279 @staticmethod 280 def optional_test_case(): 281 test_case = test_example_pb2.TestCase() 282 value = test_case.values.add() 283 value.bool_value.append(True) 284 test_case.shapes.append(1) 285 test_case.sizes.append(1) 286 field = test_case.fields.add() 287 field.name = "bool_value" 288 field.dtype = types_pb2.DT_BOOL 289 field.value.bool_value.append(True) 290 test_case.sizes.append(0) 291 field = test_case.fields.add() 292 field.name = "double_value" 293 field.dtype = types_pb2.DT_DOUBLE 294 field.value.double_value.append(0.0) 295 return test_case 296 297 @staticmethod 298 def promote_test_case(): 299 test_case = test_example_pb2.TestCase() 300 value = test_case.values.add() 301 value.sint32_value.append(2147483647) 302 value.sfixed32_value.append(2147483647) 303 value.int32_value.append(2147483647) 304 value.fixed32_value.append(4294967295) 305 value.uint32_value.append(4294967295) 306 test_case.shapes.append(1) 307 test_case.sizes.append(1) 308 field = test_case.fields.add() 309 field.name = "sint32_value" 310 field.dtype = types_pb2.DT_INT64 311 field.value.int64_value.append(2147483647) 312 test_case.sizes.append(1) 313 field = test_case.fields.add() 314 field.name = "sfixed32_value" 315 field.dtype = types_pb2.DT_INT64 316 field.value.int64_value.append(2147483647) 317 test_case.sizes.append(1) 318 field = test_case.fields.add() 319 field.name = "int32_value" 320 field.dtype = types_pb2.DT_INT64 321 field.value.int64_value.append(2147483647) 322 test_case.sizes.append(1) 323 field = test_case.fields.add() 324 field.name = "fixed32_value" 325 field.dtype = types_pb2.DT_UINT64 326 field.value.uint64_value.append(4294967295) 327 test_case.sizes.append(1) 328 field = test_case.fields.add() 329 field.name = "uint32_value" 330 field.dtype = types_pb2.DT_UINT64 331 field.value.uint64_value.append(4294967295) 332 return test_case 333 334 @staticmethod 335 def ragged_test_case(): 336 test_case = test_example_pb2.TestCase() 337 value = test_case.values.add() 338 value.double_value.append(23.5) 339 value.double_value.append(123.0) 340 value.bool_value.append(True) 341 value = test_case.values.add() 342 value.double_value.append(3.1) 343 value.bool_value.append(False) 344 test_case.shapes.append(2) 345 test_case.sizes.append(2) 346 test_case.sizes.append(1) 347 test_case.sizes.append(1) 348 test_case.sizes.append(1) 349 field = test_case.fields.add() 350 field.name = "double_value" 351 field.dtype = types_pb2.DT_DOUBLE 352 field.value.double_value.append(23.5) 353 field.value.double_value.append(123.0) 354 field.value.double_value.append(3.1) 355 field.value.double_value.append(0.0) 356 field = test_case.fields.add() 357 field.name = "bool_value" 358 field.dtype = types_pb2.DT_BOOL 359 field.value.bool_value.append(True) 360 field.value.bool_value.append(False) 361 return test_case 362 363 @staticmethod 364 def shaped_batch_test_case(): 365 test_case = test_example_pb2.TestCase() 366 value = test_case.values.add() 367 value.double_value.append(23.5) 368 value.bool_value.append(True) 369 value = test_case.values.add() 370 value.double_value.append(44.0) 371 value.bool_value.append(False) 372 value = test_case.values.add() 373 value.double_value.append(3.14159) 374 value.bool_value.append(True) 375 value = test_case.values.add() 376 value.double_value.append(1.414) 377 value.bool_value.append(True) 378 value = test_case.values.add() 379 value.double_value.append(-32.2) 380 value.bool_value.append(False) 381 value = test_case.values.add() 382 value.double_value.append(0.0001) 383 value.bool_value.append(True) 384 test_case.shapes.append(3) 385 test_case.shapes.append(2) 386 for _ in range(12): 387 test_case.sizes.append(1) 388 field = test_case.fields.add() 389 field.name = "double_value" 390 field.dtype = types_pb2.DT_DOUBLE 391 field.value.double_value.append(23.5) 392 field.value.double_value.append(44.0) 393 field.value.double_value.append(3.14159) 394 field.value.double_value.append(1.414) 395 field.value.double_value.append(-32.2) 396 field.value.double_value.append(0.0001) 397 field = test_case.fields.add() 398 field.name = "bool_value" 399 field.dtype = types_pb2.DT_BOOL 400 field.value.bool_value.append(True) 401 field.value.bool_value.append(False) 402 field.value.bool_value.append(True) 403 field.value.bool_value.append(True) 404 field.value.bool_value.append(False) 405 field.value.bool_value.append(True) 406 return test_case 407 408 @staticmethod 409 def extension_test_case(): 410 test_case = test_example_pb2.TestCase() 411 value = test_case.values.add() 412 message_value = value.Extensions[test_example_pb2.ext_value].add() 413 message_value.double_value = 23.5 414 test_case.shapes.append(1) 415 test_case.sizes.append(1) 416 field = test_case.fields.add() 417 field.name = test_example_pb2.ext_value.full_name 418 field.dtype = types_pb2.DT_STRING 419 message_value = field.value.Extensions[test_example_pb2.ext_value].add() 420 message_value.double_value = 23.5 421 return test_case 422 423 @staticmethod 424 def simple_test_case(): 425 test_case = test_example_pb2.TestCase() 426 value = test_case.values.add() 427 value.double_value.append(23.5) 428 value.bool_value.append(True) 429 value.enum_value.append(test_example_pb2.Color.INDIGO) 430 test_case.shapes.append(1) 431 test_case.sizes.append(1) 432 field = test_case.fields.add() 433 field.name = "double_value" 434 field.dtype = types_pb2.DT_DOUBLE 435 field.value.double_value.append(23.5) 436 test_case.sizes.append(1) 437 field = test_case.fields.add() 438 field.name = "bool_value" 439 field.dtype = types_pb2.DT_BOOL 440 field.value.bool_value.append(True) 441 test_case.sizes.append(1) 442 field = test_case.fields.add() 443 field.name = "enum_value" 444 field.dtype = types_pb2.DT_INT32 445 field.value.enum_value.append(test_example_pb2.Color.INDIGO) 446 return test_case 447