• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras generic Python utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from functools import partial
22
23import numpy as np
24
25from tensorflow.python import keras
26from tensorflow.python.keras.utils import generic_utils
27from tensorflow.python.platform import test
28
29
30class HasArgTest(test.TestCase):
31
32  def test_has_arg(self):
33
34    def f_x(x):
35      return x
36
37    def f_x_args(x, *args):
38      _ = args
39      return x
40
41    def f_x_kwargs(x, **kwargs):
42      _ = kwargs
43      return x
44
45    def f(a, b, c):
46      return a + b + c
47
48    partial_f = partial(f, b=1)
49
50    self.assertTrue(keras.utils.generic_utils.has_arg(
51        f_x, 'x', accept_all=False))
52    self.assertFalse(keras.utils.generic_utils.has_arg(
53        f_x, 'y', accept_all=False))
54    self.assertTrue(keras.utils.generic_utils.has_arg(
55        f_x_args, 'x', accept_all=False))
56    self.assertFalse(keras.utils.generic_utils.has_arg(
57        f_x_args, 'y', accept_all=False))
58    self.assertTrue(keras.utils.generic_utils.has_arg(
59        f_x_kwargs, 'x', accept_all=False))
60    self.assertFalse(keras.utils.generic_utils.has_arg(
61        f_x_kwargs, 'y', accept_all=False))
62    self.assertTrue(keras.utils.generic_utils.has_arg(
63        f_x_kwargs, 'y', accept_all=True))
64    self.assertTrue(
65        keras.utils.generic_utils.has_arg(partial_f, 'c', accept_all=True))
66
67
68class TestCustomObjectScope(test.TestCase):
69
70  def test_custom_object_scope(self):
71
72    def custom_fn():
73      pass
74
75    class CustomClass(object):
76      pass
77
78    with keras.utils.generic_utils.custom_object_scope(
79        {'CustomClass': CustomClass, 'custom_fn': custom_fn}):
80      act = keras.activations.get('custom_fn')
81      self.assertEqual(act, custom_fn)
82      cl = keras.regularizers.get('CustomClass')
83      self.assertEqual(cl.__class__, CustomClass)
84
85
86class SerializeKerasObjectTest(test.TestCase):
87
88  def test_serialize_none(self):
89    serialized = keras.utils.generic_utils.serialize_keras_object(None)
90    self.assertEqual(serialized, None)
91    deserialized = keras.utils.generic_utils.deserialize_keras_object(
92        serialized)
93    self.assertEqual(deserialized, None)
94
95  def test_serialize_custom_class_with_default_name(self):
96
97    @keras.utils.generic_utils.register_keras_serializable()
98    class TestClass(object):
99
100      def __init__(self, value):
101        self._value = value
102
103      def get_config(self):
104        return {'value': self._value}
105
106    serialized_name = 'Custom>TestClass'
107    inst = TestClass(value=10)
108    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
109    self.assertEqual(serialized_name, class_name)
110    config = keras.utils.generic_utils.serialize_keras_object(inst)
111    self.assertEqual(class_name, config['class_name'])
112    new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
113    self.assertIsNot(inst, new_inst)
114    self.assertIsInstance(new_inst, TestClass)
115    self.assertEqual(10, new_inst._value)
116
117    # Make sure registering a new class with same name will fail.
118    with self.assertRaisesRegex(ValueError, '.*has already been registered.*'):
119      @keras.utils.generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
120      class TestClass(object):
121
122        def __init__(self, value):
123          self._value = value
124
125        def get_config(self):
126          return {'value': self._value}
127
128  def test_serialize_custom_class_with_custom_name(self):
129
130    @keras.utils.generic_utils.register_keras_serializable(
131        'TestPackage', 'CustomName')
132    class OtherTestClass(object):
133
134      def __init__(self, val):
135        self._val = val
136
137      def get_config(self):
138        return {'val': self._val}
139
140    serialized_name = 'TestPackage>CustomName'
141    inst = OtherTestClass(val=5)
142    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
143    self.assertEqual(serialized_name, class_name)
144    fn_class_name = keras.utils.generic_utils.get_registered_name(
145        OtherTestClass)
146    self.assertEqual(fn_class_name, class_name)
147
148    cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
149    self.assertEqual(OtherTestClass, cls)
150
151    config = keras.utils.generic_utils.serialize_keras_object(inst)
152    self.assertEqual(class_name, config['class_name'])
153    new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
154    self.assertIsNot(inst, new_inst)
155    self.assertIsInstance(new_inst, OtherTestClass)
156    self.assertEqual(5, new_inst._val)
157
158  def test_serialize_custom_function(self):
159
160    @keras.utils.generic_utils.register_keras_serializable()
161    def my_fn():
162      return 42
163
164    serialized_name = 'Custom>my_fn'
165    class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
166    self.assertEqual(serialized_name, class_name)
167    fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
168    self.assertEqual(fn_class_name, class_name)
169
170    config = keras.utils.generic_utils.serialize_keras_object(my_fn)
171    self.assertEqual(class_name, config)
172    fn = keras.utils.generic_utils.deserialize_keras_object(config)
173    self.assertEqual(42, fn())
174
175    fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
176    self.assertEqual(42, fn_2())
177
178  def test_serialize_custom_class_without_get_config_fails(self):
179
180    with self.assertRaisesRegex(
181        ValueError, 'Cannot register a class that does '
182        'not have a get_config.*'):
183
184      @keras.utils.generic_utils.register_keras_serializable(  # pylint: disable=unused-variable
185          'TestPackage', 'TestClass')
186      class TestClass(object):
187
188        def __init__(self, value):
189          self._value = value
190
191  def test_serializable_object(self):
192
193    class SerializableInt(int):
194      """A serializable object to pass out of a test layer's config."""
195
196      def __new__(cls, value):
197        return int.__new__(cls, value)
198
199      def get_config(self):
200        return {'value': int(self)}
201
202      @classmethod
203      def from_config(cls, config):
204        return cls(**config)
205
206    layer = keras.layers.Dense(
207        SerializableInt(3),
208        activation='relu',
209        kernel_initializer='ones',
210        bias_regularizer='l2')
211    config = keras.layers.serialize(layer)
212    new_layer = keras.layers.deserialize(
213        config, custom_objects={'SerializableInt': SerializableInt})
214    self.assertEqual(new_layer.activation, keras.activations.relu)
215    self.assertEqual(new_layer.bias_regularizer.__class__,
216                     keras.regularizers.L2)
217    self.assertEqual(new_layer.units.__class__, SerializableInt)
218    self.assertEqual(new_layer.units, 3)
219
220  def test_nested_serializable_object(self):
221    class SerializableInt(int):
222      """A serializable object to pass out of a test layer's config."""
223
224      def __new__(cls, value):
225        return int.__new__(cls, value)
226
227      def get_config(self):
228        return {'value': int(self)}
229
230      @classmethod
231      def from_config(cls, config):
232        return cls(**config)
233
234    class SerializableNestedInt(int):
235      """A serializable object containing another serializable object."""
236
237      def __new__(cls, value, int_obj):
238        obj = int.__new__(cls, value)
239        obj.int_obj = int_obj
240        return obj
241
242      def get_config(self):
243        return {'value': int(self), 'int_obj': self.int_obj}
244
245      @classmethod
246      def from_config(cls, config):
247        return cls(**config)
248
249    nested_int = SerializableInt(4)
250    layer = keras.layers.Dense(
251        SerializableNestedInt(3, nested_int),
252        name='SerializableNestedInt',
253        activation='relu',
254        kernel_initializer='ones',
255        bias_regularizer='l2')
256    config = keras.layers.serialize(layer)
257    new_layer = keras.layers.deserialize(
258        config,
259        custom_objects={
260            'SerializableInt': SerializableInt,
261            'SerializableNestedInt': SerializableNestedInt
262        })
263    # Make sure the string field doesn't get convert to custom object, even
264    # they have same value.
265    self.assertEqual(new_layer.name, 'SerializableNestedInt')
266    self.assertEqual(new_layer.activation, keras.activations.relu)
267    self.assertEqual(new_layer.bias_regularizer.__class__,
268                     keras.regularizers.L2)
269    self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
270    self.assertEqual(new_layer.units, 3)
271    self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
272    self.assertEqual(new_layer.units.int_obj, 4)
273
274  def test_nested_serializable_fn(self):
275
276    def serializable_fn(x):
277      """A serializable function to pass out of a test layer's config."""
278      return x
279
280    class SerializableNestedInt(int):
281      """A serializable object containing a serializable function."""
282
283      def __new__(cls, value, fn):
284        obj = int.__new__(cls, value)
285        obj.fn = fn
286        return obj
287
288      def get_config(self):
289        return {'value': int(self), 'fn': self.fn}
290
291      @classmethod
292      def from_config(cls, config):
293        return cls(**config)
294
295    layer = keras.layers.Dense(
296        SerializableNestedInt(3, serializable_fn),
297        activation='relu',
298        kernel_initializer='ones',
299        bias_regularizer='l2')
300    config = keras.layers.serialize(layer)
301    new_layer = keras.layers.deserialize(
302        config,
303        custom_objects={
304            'serializable_fn': serializable_fn,
305            'SerializableNestedInt': SerializableNestedInt
306        })
307    self.assertEqual(new_layer.activation, keras.activations.relu)
308    self.assertIsInstance(new_layer.bias_regularizer, keras.regularizers.L2)
309    self.assertIsInstance(new_layer.units, SerializableNestedInt)
310    self.assertEqual(new_layer.units, 3)
311    self.assertIs(new_layer.units.fn, serializable_fn)
312
313  def test_serializable_with_old_config(self):
314    # model config generated by tf-1.2.1
315    old_model_config = {
316        'class_name':
317            'Sequential',
318        'config': [{
319            'class_name': 'Dense',
320            'config': {
321                'name': 'dense_1',
322                'trainable': True,
323                'batch_input_shape': [None, 784],
324                'dtype': 'float32',
325                'units': 32,
326                'activation': 'linear',
327                'use_bias': True,
328                'kernel_initializer': {
329                    'class_name': 'Ones',
330                    'config': {
331                        'dtype': 'float32'
332                    }
333                },
334                'bias_initializer': {
335                    'class_name': 'Zeros',
336                    'config': {
337                        'dtype': 'float32'
338                    }
339                },
340                'kernel_regularizer': None,
341                'bias_regularizer': None,
342                'activity_regularizer': None,
343                'kernel_constraint': None,
344                'bias_constraint': None
345            }
346        }]
347    }
348    old_model = keras.utils.generic_utils.deserialize_keras_object(
349        old_model_config, module_objects={'Sequential': keras.Sequential})
350    new_model = keras.Sequential([
351        keras.layers.Dense(32, input_dim=784, kernel_initializer='Ones'),
352    ])
353    input_data = np.random.normal(2, 1, (5, 784))
354    output = old_model.predict(input_data)
355    expected_output = new_model.predict(input_data)
356    self.assertAllEqual(output, expected_output)
357
358  def test_deserialize_unknown_object(self):
359
360    class CustomLayer(keras.layers.Layer):
361      pass
362
363    layer = CustomLayer()
364    config = keras.utils.generic_utils.serialize_keras_object(layer)
365    with self.assertRaisesRegexp(ValueError,
366                                 'passed to the `custom_objects` arg'):
367      keras.utils.generic_utils.deserialize_keras_object(config)
368    restored = keras.utils.generic_utils.deserialize_keras_object(
369        config, custom_objects={'CustomLayer': CustomLayer})
370    self.assertIsInstance(restored, CustomLayer)
371
372
373class SliceArraysTest(test.TestCase):
374
375  def test_slice_arrays(self):
376    input_a = list([1, 2, 3])
377    self.assertEqual(
378        keras.utils.generic_utils.slice_arrays(input_a, start=0),
379        [None, None, None])
380    self.assertEqual(
381        keras.utils.generic_utils.slice_arrays(input_a, stop=3),
382        [None, None, None])
383    self.assertEqual(
384        keras.utils.generic_utils.slice_arrays(input_a, start=0, stop=1),
385        [None, None, None])
386
387
388# object() alone isn't compatible with WeakKeyDictionary, which we use to
389# track shared configs.
390class MaybeSharedObject(object):
391  pass
392
393
394class SharedObjectScopeTest(test.TestCase):
395
396  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
397    with generic_utils.SharedObjectSavingScope() as scope:
398      single_object = MaybeSharedObject()
399      self.assertIsNone(scope.get_config(single_object))
400      single_object_config = scope.create_config({}, single_object)
401      self.assertIsNotNone(single_object_config)
402      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
403                       single_object_config)
404
405  def test_shared_object_saving_scope_shared_object_exports_id(self):
406    with generic_utils.SharedObjectSavingScope() as scope:
407      shared_object = MaybeSharedObject()
408      self.assertIsNone(scope.get_config(shared_object))
409      scope.create_config({}, shared_object)
410      first_object_config = scope.get_config(shared_object)
411      second_object_config = scope.get_config(shared_object)
412      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
413                    first_object_config)
414      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
415                    second_object_config)
416      self.assertIs(first_object_config, second_object_config)
417
418  def test_shared_object_loading_scope_noop(self):
419    # Test that, without a context manager scope, adding configs will do
420    # nothing.
421    obj_id = 1
422    obj = MaybeSharedObject()
423    generic_utils._shared_object_loading_scope().set(obj_id, obj)
424    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
425
426  def test_shared_object_loading_scope_returns_shared_obj(self):
427    obj_id = 1
428    obj = MaybeSharedObject()
429    with generic_utils.SharedObjectLoadingScope() as scope:
430      scope.set(obj_id, obj)
431      self.assertIs(scope.get(obj_id), obj)
432
433  def test_nested_shared_object_saving_scopes(self):
434    my_obj = MaybeSharedObject()
435    with generic_utils.SharedObjectSavingScope() as scope_1:
436      scope_1.create_config({}, my_obj)
437      with generic_utils.SharedObjectSavingScope() as scope_2:
438        # Nesting saving scopes should return the original scope and should
439        # not clear any objects we're tracking.
440        self.assertIs(scope_1, scope_2)
441        self.assertIsNotNone(scope_2.get_config(my_obj))
442      self.assertIsNotNone(scope_1.get_config(my_obj))
443    self.assertIsNone(generic_utils._shared_object_saving_scope())
444
445
446if __name__ == '__main__':
447  test.main()
448