1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras weights constraints.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22 23import numpy as np 24 25from tensorflow.python import keras 26from tensorflow.python.framework import test_util 27from tensorflow.python.platform import test 28 29 30def get_test_values(): 31 return [0.1, 0.5, 3, 8, 1e-7] 32 33 34def get_example_array(): 35 np.random.seed(3537) 36 example_array = np.random.random((100, 100)) * 100. - 50. 37 example_array[0, 0] = 0. # 0 could possibly cause trouble 38 return example_array 39 40 41def get_example_kernel(width): 42 np.random.seed(3537) 43 example_array = np.random.rand(width, width, 2, 2) 44 return example_array 45 46 47@test_util.run_all_in_graph_and_eager_modes 48class KerasConstraintsTest(test.TestCase): 49 50 def test_serialization(self): 51 all_activations = ['max_norm', 'non_neg', 52 'unit_norm', 'min_max_norm'] 53 for name in all_activations: 54 fn = keras.constraints.get(name) 55 ref_fn = getattr(keras.constraints, name)() 56 assert fn.__class__ == ref_fn.__class__ 57 config = keras.constraints.serialize(fn) 58 fn = keras.constraints.deserialize(config) 59 assert fn.__class__ == ref_fn.__class__ 60 61 def test_max_norm(self): 62 array = get_example_array() 63 for m in get_test_values(): 64 norm_instance = keras.constraints.max_norm(m) 65 normed = norm_instance(keras.backend.variable(array)) 66 assert np.all(keras.backend.eval(normed) < m) 67 68 # a more explicit example 69 norm_instance = keras.constraints.max_norm(2.0) 70 x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T 71 x_normed_target = np.array( 72 [[0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], 73 [2. / np.sqrt(3), 2. / np.sqrt(3), 2. / np.sqrt(3)]]).T 74 x_normed_actual = keras.backend.eval( 75 norm_instance(keras.backend.variable(x))) 76 self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05) 77 78 def test_non_neg(self): 79 non_neg_instance = keras.constraints.non_neg() 80 normed = non_neg_instance(keras.backend.variable(get_example_array())) 81 assert np.all(np.min(keras.backend.eval(normed), axis=1) == 0.) 82 83 def test_unit_norm(self): 84 unit_norm_instance = keras.constraints.unit_norm() 85 normalized = unit_norm_instance(keras.backend.variable(get_example_array())) 86 norm_of_normalized = np.sqrt( 87 np.sum(keras.backend.eval(normalized)**2, axis=0)) 88 # In the unit norm constraint, it should be equal to 1. 89 difference = norm_of_normalized - 1. 90 largest_difference = np.max(np.abs(difference)) 91 assert np.abs(largest_difference) < 10e-5 92 93 def test_min_max_norm(self): 94 array = get_example_array() 95 for m in get_test_values(): 96 norm_instance = keras.constraints.min_max_norm( 97 min_value=m, max_value=m * 2) 98 normed = norm_instance(keras.backend.variable(array)) 99 value = keras.backend.eval(normed) 100 l2 = np.sqrt(np.sum(np.square(value), axis=0)) 101 assert not l2[l2 < m] 102 assert not l2[l2 > m * 2 + 1e-5] 103 104 def test_conv2d_radial_constraint(self): 105 for width in (3, 4, 5, 6): 106 array = get_example_kernel(width) 107 norm_instance = keras.constraints.radial_constraint() 108 normed = norm_instance(keras.backend.variable(array)) 109 value = keras.backend.eval(normed) 110 assert np.all(value.shape == array.shape) 111 assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0]) 112 assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2) 113 114 115if __name__ == '__main__': 116 test.main() 117