• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python import keras
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import sparse_tensor
21from tensorflow.python.keras import backend
22from tensorflow.python.keras import combinations
23from tensorflow.python.keras import keras_parameterized
24from tensorflow.python.keras.engine import base_layer_utils
25from tensorflow.python.ops import lookup_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import sparse_ops
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.platform import test
31
32
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TrackableWeightHandlerTest(keras_parameterized.TestCase):
  """Tests `base_layer_utils.TrackableWeightHandler` wrapping a hash table."""

  def get_table_handler(self):
    """Return a handler around a fresh string -> int32 `MutableHashTable`.

    The table's default value is 0.
    """
    # Note: There is some repetition in these tests' setup. However, Tensorflow
    # does not play nicely with a separate setUp() call (causing errors related
    # to graph building), so we have to use a called setup instead of a setUp()
    # call.
    table = lookup_ops.MutableHashTable(
        key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
    return base_layer_utils.TrackableWeightHandler(table)

  def _read_table_data(self, table_handler):
    """Evaluate the handler's tensors and return them as a {key: value} dict.

    The first tensor holds the table's keys and the second holds the matching
    values.
    """
    weights = backend.batch_get_value(table_handler.get_tensors())
    return dict(zip(weights[0], weights[1]))

  def test_get_num_tensors(self):
    table_handler = self.get_table_handler()
    # The table is exposed as exactly two tensors: keys and values.
    self.assertEqual(2, table_handler.num_tensors)

  def test_get_and_set_weights(self):
    table_handler = self.get_table_handler()

    table_data = {b'a': 1, b'b': 2, b'c': 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    self.assertDictEqual(table_data, self._read_table_data(table_handler))

  def test_get_and_set_weights_does_not_add_ops(self):
    """Repeated set/get after graph finalization must not build new ops."""
    # Imported locally: only this test needs to inspect the execution mode.
    from tensorflow.python.eager import context  # pylint: disable=g-import-not-at-top
    if context.executing_eagerly():
      # Eager ops are never added to a graph, so finalizing one cannot detect
      # op creation; the check below is only meaningful in graph mode.
      self.skipTest('Graph finalization is only meaningful in graph mode.')

    table_handler = self.get_table_handler()
    table_data = {b'a': 1, b'b': 2, b'c': 3}
    weights = [list(table_data.keys()), list(table_data.values())]

    # The first round trip may build whatever ops the handler needs.
    table_handler.set_weights(weights)
    _ = backend.batch_get_value(table_handler.get_tensors())

    # Once the graph is finalized, adding an op raises, so this second round
    # trip only succeeds if the handler reuses the ops built above.
    backend.get_session().graph.finalize()
    table_handler.set_weights(weights)
    # Also verify the reused ops still write and read the data correctly.
    self.assertDictEqual(table_data, self._read_table_data(table_handler))
72
73
@combinations.generate(combinations.combine(mode=['eager']))
class OpLayerTest(keras_parameterized.TestCase):
  """Tests models built by applying a raw `math_ops.cast` to `keras.Input`s."""

  def _build_cast_model(self, int_values):
    """Return a model casting `int_values` to float32, compiled with MSE loss."""
    float_values = math_ops.cast(int_values, dtypes.float32)
    model = keras.Model(int_values, float_values)
    model.compile(loss='mse')
    return model

  def test_tensor_op_layer(self):
    """A dense int32 input is cast to float32 by the model."""
    int_values = keras.Input(shape=(2,), dtype=dtypes.int32)
    model = self._build_cast_model(int_values)

    input_data = np.array([[1, 2], [3, 4]], dtype=np.int32)
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertAllClose(expected, output)

  def test_ragged_op_layer_keras_tensors(self):
    """A ragged int32 input yields a `RaggedTensor` prediction."""
    int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
    model = self._build_cast_model(int_values)

    input_data = ragged_factory_ops.constant(
        [[1, 2], [3, 4]], dtype=np.int32)
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertIsInstance(output, ragged_tensor.RaggedTensor)
    self.assertAllClose(expected, output)

  def test_sparse_op_layer_keras_tensors(self):
    """A sparse int32 input yields a `SparseTensor` prediction."""
    int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
    # The model is built once; the original discarded an identical extra build.
    model = self._build_cast_model(int_values)

    input_data = sparse_ops.from_dense(
        np.array([[1, 2], [3, 4]], dtype=np.int32))
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertIsInstance(output, sparse_tensor.SparseTensor)
    # Densify before comparing so the check covers every element value.
    self.assertAllClose(expected, sparse_ops.sparse_tensor_to_dense(output))
114
115
def _run_tests():
  """Hand control to TensorFlow's test runner (`test.main`)."""
  test.main()


if __name__ == '__main__':
  _run_tests()
118