• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras subclassed layers utilizing desired user syntax."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import keras
22from tensorflow.python.eager import context
23from tensorflow.python.framework import ops
24from tensorflow.python.keras import keras_parameterized
25from tensorflow.python.keras import testing_utils
26from tensorflow.python.keras.utils import tf_utils
27from tensorflow.python.ops import variables
28from tensorflow.python.platform import test
29
30
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_with_all_model_types
class SubclassedLayersTest(keras_parameterized.TestCase):
  """Tests layers that create tensors or variables inside `build`.

  Each test wraps a custom layer, followed by a `Dense(1)`, in a model of the
  current parameterized model type. It then checks three things:

  - calling the model on a concrete tensor returns a symbolic tensor exactly
    when not executing eagerly;
  - calling the bare layer does the same;
  - the layer computes the expected value.
  """

  def _check_model_and_layer(self, model, layer, expected):
    """Calls `model` and `layer` on the input [[3.0]] and checks the results.

    Args:
      model: The model that wraps `layer`.
      layer: The custom layer under test.
      expected: Expected value of `layer` applied to [[3.0]].
    """
    inputs = ops.convert_to_tensor([[3.0]])
    # Outside eager execution, calls on concrete tensors yield symbolic ones.
    expect_symbolic = not context.executing_eagerly()
    self.assertEqual(
        tf_utils.is_symbolic_tensor(model(inputs)), expect_symbolic)
    self.assertEqual(
        tf_utils.is_symbolic_tensor(layer(inputs)), expect_symbolic)
    self.assertAllClose(keras.backend.get_value(layer(inputs)), expected)

  def test_simple_build_with_constant(self):
    """A layer may keep a plain constant tensor created in `build`."""

    class BuildConstantLayer(keras.layers.Layer):

      def build(self, input_shape):
        # Layer state is an ordinary tensor, not a variable.
        self.b = ops.convert_to_tensor(2.0)

      def call(self, inputs):
        return self.b * inputs

    custom_layer = BuildConstantLayer()
    wrapped_model = testing_utils.get_model_from_layers(
        [custom_layer, keras.layers.Dense(1)], input_shape=(1,))
    # 2.0 * 3.0
    self._check_model_and_layer(wrapped_model, custom_layer, [[6.0]])

  def test_build_with_derived_constant(self):
    """A layer may derive a variable and a tensor from math done in `build`."""

    class BuildDerivedConstantLayer(keras.layers.Layer):

      def build(self, input_shape):
        one = ops.convert_to_tensor(1.0)
        self.variable = variables.Variable(2.0 * one)
        # Tensor obtained by converting the variable created just above.
        self.constant = ops.convert_to_tensor(self.variable)

      def call(self, inputs):
        return self.variable * self.constant * inputs

    custom_layer = BuildDerivedConstantLayer()
    wrapped_model = testing_utils.get_model_from_layers(
        [custom_layer, keras.layers.Dense(1)], input_shape=(1,))
    # 2.0 * 2.0 * 3.0
    self._check_model_and_layer(wrapped_model, custom_layer, [[12.0]])
79
80
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
83