# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras subclassed layers utilizing desired user syntax."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_with_all_model_types
class SubclassedLayersTest(keras_parameterized.TestCase):
  """Layers that create tensors/variables inside `build`.

  Each test is parameterized by `run_all_keras_modes` and
  `run_with_all_model_types`.
  """

  def _check_outputs(self, layer, expected):
    """Wraps `layer` in a model and checks its outputs for the input [[3.0]].

    Not prefixed with `test`, so the parameterizing class decorators and the
    test loader leave it alone; it is only called from the tests below.

    Args:
      layer: the layer under test; it is placed ahead of a `Dense(1)` layer.
      expected: the value `layer` should produce for the input [[3.0]].
    """
    model = testing_utils.get_model_from_layers(
        [layer, keras.layers.Dense(1)], input_shape=(1,))

    inputs = ops.convert_to_tensor([[3.0]])

    # Outputs are expected to be symbolic exactly when not executing eagerly.
    model_is_symbolic = tf_utils.is_symbolic_tensor(model(inputs))
    self.assertEqual(model_is_symbolic, not context.executing_eagerly())
    layer_is_symbolic = tf_utils.is_symbolic_tensor(layer(inputs))
    self.assertEqual(layer_is_symbolic, not context.executing_eagerly())

    self.assertAllClose(keras.backend.get_value(layer(inputs)), expected)

  def test_simple_build_with_constant(self):
    """A constant tensor created in `build` can be used in `call`."""

    class BuildConstantLayer(keras.layers.Layer):
      """Multiplies its inputs by a constant tensor (2.0) made in `build`."""

      def build(self, input_shape):
        self.b = ops.convert_to_tensor(2.0)

      def call(self, inputs):
        return self.b * inputs

    # 2.0 * 3.0
    self._check_outputs(BuildConstantLayer(), [[6.0]])

  def test_build_with_derived_constant(self):
    """A variable built from a derived tensor, plus a tensor read from it."""

    class BuildDerivedConstantLayer(keras.layers.Layer):
      """Multiplies its inputs by a variable and a tensor converted from it."""

      def build(self, input_shape):
        one = ops.convert_to_tensor(1.0)
        two = 2.0 * one  # Derived tensor used as the variable's initial value.
        self.variable = variables.Variable(two)
        self.constant = ops.convert_to_tensor(self.variable)

      def call(self, inputs):
        return self.variable * self.constant * inputs

    # variable (2.0) * constant (2.0) * 3.0
    self._check_outputs(BuildDerivedConstantLayer(), [[12.0]])


if __name__ == '__main__':
  test.main()