# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras generic Python utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

import numpy as np

from tensorflow.python import keras
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import test


class HasArgTest(test.TestCase):
  """Tests for `generic_utils.has_arg`."""

  def test_has_arg(self):
    """Checks named args, *args, **kwargs (with/without accept_all), partials."""

    def f_x(x):
      return x

    def f_x_args(x, *args):
      _ = args
      return x

    def f_x_kwargs(x, **kwargs):
      _ = kwargs
      return x

    def f(a, b, c):
      return a + b + c

    partial_f = partial(f, b=1)

    self.assertTrue(keras.utils.generic_utils.has_arg(
        f_x, 'x', accept_all=False))
    self.assertFalse(keras.utils.generic_utils.has_arg(
        f_x, 'y', accept_all=False))
    self.assertTrue(keras.utils.generic_utils.has_arg(
        f_x_args, 'x', accept_all=False))
    self.assertFalse(keras.utils.generic_utils.has_arg(
        f_x_args, 'y', accept_all=False))
    self.assertTrue(keras.utils.generic_utils.has_arg(
        f_x_kwargs, 'x', accept_all=False))
    self.assertFalse(keras.utils.generic_utils.has_arg(
        f_x_kwargs, 'y', accept_all=False))
    # With accept_all=True, a **kwargs catch-all counts as accepting 'y'.
    self.assertTrue(keras.utils.generic_utils.has_arg(
        f_x_kwargs, 'y', accept_all=True))
    self.assertTrue(
        keras.utils.generic_utils.has_arg(partial_f, 'c', accept_all=True))


class TestCustomObjectScope(test.TestCase):
  """Tests for `generic_utils.custom_object_scope`."""

  def test_custom_object_scope(self):
    """Objects registered in the scope are resolvable by their string name."""

    def custom_fn():
      pass

    class CustomClass(object):
      pass

    with keras.utils.generic_utils.custom_object_scope(
        {'CustomClass': CustomClass, 'custom_fn': custom_fn}):
      act = keras.activations.get('custom_fn')
      self.assertEqual(act, custom_fn)
      cl = keras.regularizers.get('CustomClass')
      self.assertEqual(cl.__class__, CustomClass)


class SerializeKerasObjectTest(test.TestCase):
  """Tests for Keras object (de)serialization and custom-object registration."""

  def test_serialize_none(self):
    """None serializes to None and deserializes back to None."""
    serialized = keras.utils.generic_utils.serialize_keras_object(None)
    self.assertIsNone(serialized)
    deserialized = keras.utils.generic_utils.deserialize_keras_object(
        serialized)
    self.assertIsNone(deserialized)

  def test_serialize_custom_class_with_default_name(self):
    """Default registration uses 'Custom>ClassName'; duplicates are rejected."""

    @keras.utils.generic_utils.register_keras_serializable()
    class TestClass(object):

      def __init__(self, value):
        self._value = value

      def get_config(self):
        return {'value': self._value}

    serialized_name = 'Custom>TestClass'
    inst = TestClass(value=10)
    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
    self.assertEqual(serialized_name, class_name)
    config = keras.utils.generic_utils.serialize_keras_object(inst)
    self.assertEqual(class_name, config['class_name'])
    new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
    self.assertIsNot(inst, new_inst)
    self.assertIsInstance(new_inst, TestClass)
    self.assertEqual(10, new_inst._value)

    # Make sure registering a new class with same name will fail.
    with self.assertRaisesRegex(ValueError, '.*has already been registered.*'):
      @keras.utils.generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
      class TestClass(object):

        def __init__(self, value):
          self._value = value

        def get_config(self):
          return {'value': self._value}

  def test_serialize_custom_class_with_custom_name(self):
    """Explicit package/name registration yields 'Package>Name' lookups."""

    @keras.utils.generic_utils.register_keras_serializable(
        'TestPackage', 'CustomName')
    class OtherTestClass(object):

      def __init__(self, val):
        self._val = val

      def get_config(self):
        return {'val': self._val}

    serialized_name = 'TestPackage>CustomName'
    inst = OtherTestClass(val=5)
    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
    self.assertEqual(serialized_name, class_name)
    fn_class_name = keras.utils.generic_utils.get_registered_name(
        OtherTestClass)
    self.assertEqual(fn_class_name, class_name)

    cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
    self.assertEqual(OtherTestClass, cls)

    config = keras.utils.generic_utils.serialize_keras_object(inst)
    self.assertEqual(class_name, config['class_name'])
    new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
    self.assertIsNot(inst, new_inst)
    self.assertIsInstance(new_inst, OtherTestClass)
    self.assertEqual(5, new_inst._val)

  def test_serialize_custom_function(self):
    """A registered function serializes to its name string and round-trips."""

    @keras.utils.generic_utils.register_keras_serializable()
    def my_fn():
      return 42

    serialized_name = 'Custom>my_fn'
    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
    self.assertEqual(serialized_name, class_name)
    fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
    self.assertEqual(fn_class_name, class_name)

    # Functions serialize to a bare name string rather than a config dict.
    config = keras.utils.generic_utils.serialize_keras_object(my_fn)
    self.assertEqual(class_name, config)
    fn = keras.utils.generic_utils.deserialize_keras_object(config)
    self.assertEqual(42, fn())

    fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
    self.assertEqual(42, fn_2())

  def test_serialize_custom_class_without_get_config_fails(self):
    """Registering a class that lacks `get_config` raises ValueError."""

    with self.assertRaisesRegex(
        ValueError, 'Cannot register a class that does '
        'not have a get_config.*'):

      @keras.utils.generic_utils.register_keras_serializable(  # pylint: disable=unused-variable
          'TestPackage', 'TestClass')
      class TestClass(object):

        def __init__(self, value):
          self._value = value

  def test_serializable_object(self):
    """A custom object in a layer config round-trips via `custom_objects`."""

    class SerializableInt(int):
      """A serializable object to pass out of a test layer's config."""

      def __new__(cls, value):
        return int.__new__(cls, value)

      def get_config(self):
        return {'value': int(self)}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    layer = keras.layers.Dense(
        SerializableInt(3),
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config, custom_objects={'SerializableInt': SerializableInt})
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertEqual(new_layer.bias_regularizer.__class__,
                     keras.regularizers.L2)
    self.assertEqual(new_layer.units.__class__, SerializableInt)
    self.assertEqual(new_layer.units, 3)

  def test_nested_serializable_object(self):
    """Custom objects nested inside another custom object are restored."""

    class SerializableInt(int):
      """A serializable object to pass out of a test layer's config."""

      def __new__(cls, value):
        return int.__new__(cls, value)

      def get_config(self):
        return {'value': int(self)}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    class SerializableNestedInt(int):
      """A serializable object containing another serializable object."""

      def __new__(cls, value, int_obj):
        obj = int.__new__(cls, value)
        obj.int_obj = int_obj
        return obj

      def get_config(self):
        return {'value': int(self), 'int_obj': self.int_obj}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    nested_int = SerializableInt(4)
    layer = keras.layers.Dense(
        SerializableNestedInt(3, nested_int),
        name='SerializableNestedInt',
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config,
        custom_objects={
            'SerializableInt': SerializableInt,
            'SerializableNestedInt': SerializableNestedInt
        })
    # Make sure a plain string field is not converted into a custom object,
    # even if its value matches a custom object's name.
    self.assertEqual(new_layer.name, 'SerializableNestedInt')
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertEqual(new_layer.bias_regularizer.__class__,
                     keras.regularizers.L2)
    self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
    self.assertEqual(new_layer.units, 3)
    self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
    self.assertEqual(new_layer.units.int_obj, 4)

  def test_nested_serializable_fn(self):
    """A function nested in a custom object's config is restored by name."""

    def serializable_fn(x):
      """A serializable function to pass out of a test layer's config."""
      return x

    class SerializableNestedInt(int):
      """A serializable object containing a serializable function."""

      def __new__(cls, value, fn):
        obj = int.__new__(cls, value)
        obj.fn = fn
        return obj

      def get_config(self):
        return {'value': int(self), 'fn': self.fn}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    layer = keras.layers.Dense(
        SerializableNestedInt(3, serializable_fn),
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config,
        custom_objects={
            'serializable_fn': serializable_fn,
            'SerializableNestedInt': SerializableNestedInt
        })
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertIsInstance(new_layer.bias_regularizer, keras.regularizers.L2)
    self.assertIsInstance(new_layer.units, SerializableNestedInt)
    self.assertEqual(new_layer.units, 3)
    self.assertIs(new_layer.units.fn, serializable_fn)

  def test_serializable_with_old_config(self):
    """A Sequential config from tf-1.2.1 still loads and matches a new model."""
    # model config generated by tf-1.2.1
    old_model_config = {
        'class_name':
            'Sequential',
        'config': [{
            'class_name': 'Dense',
            'config': {
                'name': 'dense_1',
                'trainable': True,
                'batch_input_shape': [None, 784],
                'dtype': 'float32',
                'units': 32,
                'activation': 'linear',
                'use_bias': True,
                'kernel_initializer': {
                    'class_name': 'Ones',
                    'config': {
                        'dtype': 'float32'
                    }
                },
                'bias_initializer': {
                    'class_name': 'Zeros',
                    'config': {
                        'dtype': 'float32'
                    }
                },
                'kernel_regularizer': None,
                'bias_regularizer': None,
                'activity_regularizer': None,
                'kernel_constraint': None,
                'bias_constraint': None
            }
        }]
    }
    old_model = keras.utils.generic_utils.deserialize_keras_object(
        old_model_config, module_objects={'Sequential': keras.Sequential})
    new_model = keras.Sequential([
        keras.layers.Dense(32, input_dim=784, kernel_initializer='Ones'),
    ])
    input_data = np.random.normal(2, 1, (5, 784))
    output = old_model.predict(input_data)
    expected_output = new_model.predict(input_data)
    self.assertAllEqual(output, expected_output)

  def test_deserialize_unknown_object(self):
    """Unknown custom layers fail with a `custom_objects` hint until supplied."""

    class CustomLayer(keras.layers.Layer):
      pass

    layer = CustomLayer()
    config = keras.utils.generic_utils.serialize_keras_object(layer)
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12; use
    # `assertRaisesRegex` like the rest of this file.
    with self.assertRaisesRegex(ValueError,
                                'passed to the `custom_objects` arg'):
      keras.utils.generic_utils.deserialize_keras_object(config)
    restored = keras.utils.generic_utils.deserialize_keras_object(
        config, custom_objects={'CustomLayer': CustomLayer})
    self.assertIsInstance(restored, CustomLayer)


class SliceArraysTest(test.TestCase):
  """Tests for `generic_utils.slice_arrays`."""

  def test_slice_arrays(self):
    """Plain (non-array) list entries slice to None for any start/stop."""
    input_a = [1, 2, 3]
    self.assertEqual(
        keras.utils.generic_utils.slice_arrays(input_a, start=0),
        [None, None, None])
    self.assertEqual(
        keras.utils.generic_utils.slice_arrays(input_a, stop=3),
        [None, None, None])
    self.assertEqual(
        keras.utils.generic_utils.slice_arrays(input_a, start=0, stop=1),
        [None, None, None])


# object() alone isn't compatible with WeakKeyDictionary, which we use to
# track shared configs.
class MaybeSharedObject(object):
  pass


class SharedObjectScopeTest(test.TestCase):
  """Tests for the shared-object saving and loading scopes."""

  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
    """An object whose config is created once gets no shared-object id."""
    with generic_utils.SharedObjectSavingScope() as scope:
      single_object = MaybeSharedObject()
      self.assertIsNone(scope.get_config(single_object))
      single_object_config = scope.create_config({}, single_object)
      self.assertIsNotNone(single_object_config)
      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
                       single_object_config)

  def test_shared_object_saving_scope_shared_object_exports_id(self):
    """Repeated lookups of the same object return one config carrying an id."""
    with generic_utils.SharedObjectSavingScope() as scope:
      shared_object = MaybeSharedObject()
      self.assertIsNone(scope.get_config(shared_object))
      scope.create_config({}, shared_object)
      first_object_config = scope.get_config(shared_object)
      second_object_config = scope.get_config(shared_object)
      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
                    first_object_config)
      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
                    second_object_config)
      self.assertIs(first_object_config, second_object_config)

  def test_shared_object_loading_scope_noop(self):
    """Outside a loading scope, `set` is ignored and `get` returns None."""
    # Test that, without a context manager scope, adding configs will do
    # nothing.
    obj_id = 1
    obj = MaybeSharedObject()
    generic_utils._shared_object_loading_scope().set(obj_id, obj)
    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))

  def test_shared_object_loading_scope_returns_shared_obj(self):
    """Inside a loading scope, `get` returns the object stored by `set`."""
    obj_id = 1
    obj = MaybeSharedObject()
    with generic_utils.SharedObjectLoadingScope() as scope:
      scope.set(obj_id, obj)
      self.assertIs(scope.get(obj_id), obj)

  def test_nested_shared_object_saving_scopes(self):
    """Nested saving scopes reuse the outer scope; it is cleared on exit."""
    my_obj = MaybeSharedObject()
    with generic_utils.SharedObjectSavingScope() as scope_1:
      scope_1.create_config({}, my_obj)
      with generic_utils.SharedObjectSavingScope() as scope_2:
        # Nesting saving scopes should return the original scope and should
        # not clear any objects we're tracking.
        self.assertIs(scope_1, scope_2)
        self.assertIsNotNone(scope_2.get_config(my_obj))
      self.assertIsNotNone(scope_1.get_config(my_obj))
    self.assertIsNone(generic_utils._shared_object_saving_scope())


if __name__ == '__main__':
  test.main()