# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TrackableWeightHandler and for models built from raw TF ops."""

import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TrackableWeightHandlerTest(keras_parameterized.TestCase):
  """Exercises `TrackableWeightHandler` wrapped around a `MutableHashTable`."""

  def get_table_handler(self):
    """Returns a handler wrapping a fresh string -> int32 hash table."""
    # Note: There is some repetition in these tests' setup. However, Tensorflow
    # does not play nicely with a separate setUp() call (causing errors related
    # to graph building), so we have to use a called setup instead of a setUp()
    # call.
    table = lookup_ops.MutableHashTable(
        key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
    return base_layer_utils.TrackableWeightHandler(table)

  def test_get_num_tensors(self):
    """The table-backed handler reports two tensors."""
    table_handler = self.get_table_handler()
    self.assertEqual(2, table_handler.num_tensors)

  def test_get_and_set_weights(self):
    """Weights written as [keys, values] are read back as the same mapping."""
    table_handler = self.get_table_handler()

    table_data = {b'a': 1, b'b': 2, b'c': 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    weights = backend.batch_get_value(table_handler.get_tensors())
    # weights[0] holds the keys and weights[1] the matching values; pair them
    # up so the comparison does not depend on the order they come back in.
    weight_data = dict(zip(weights[0], weights[1]))
    self.assertDictEqual(table_data, weight_data)

  def test_get_and_set_weights_does_not_add_ops(self):
    """A second set/get round trip still works once the graph is finalized."""
    table_handler = self.get_table_handler()
    table_data = {b'a': 1, b'b': 2, b'c': 3}

    # First round trip, performed before the graph is finalized.
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())

    # Finalize the session graph, then repeat the same round trip. Per the
    # test name, the repeat is expected not to need any new ops.
    backend.get_session().graph.finalize()
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())


@combinations.generate(combinations.combine(mode=['eager']))
class OpLayerTest(keras_parameterized.TestCase):
  """Builds models from `math_ops.cast` applied directly to Keras inputs."""

  def test_tensor_op_layer(self):
    """Casting a dense int32 input yields the same values as floats."""
    int_values = keras.Input(shape=(2,), dtype=dtypes.int32)
    float_values = math_ops.cast(int_values, dtypes.float32)
    model = keras.Model(int_values, float_values)
    model.compile(loss='mse')

    input_data = np.array([[1, 2], [3, 4]], dtype=np.int32)
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertAllClose(expected, output)

  def test_ragged_op_layer_keras_tensors(self):
    """Casting a ragged int32 input yields a float RaggedTensor."""
    int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
    float_values = math_ops.cast(int_values, dtypes.float32)
    model = keras.Model(int_values, float_values)
    model.compile(loss='mse')

    input_data = ragged_factory_ops.constant(
        [[1, 2], [3, 4]], dtype=np.int32)
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertIsInstance(output, ragged_tensor.RaggedTensor)
    self.assertAllClose(expected, output)

  def test_sparse_op_layer_keras_tensors(self):
    """Casting a sparse int32 input yields a float SparseTensor."""
    int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
    float_values = math_ops.cast(int_values, dtypes.float32)
    # NOTE(review): this first Model is built and immediately discarded, which
    # looks redundant with the next line. It is kept because constructing two
    # models from the same tensors may be deliberate -- confirm before removing.
    _ = keras.Model(int_values, float_values)
    model = keras.Model(int_values, float_values)
    model.compile(loss='mse')

    input_data = sparse_ops.from_dense(
        np.array([[1, 2], [3, 4]], dtype=np.int32))
    expected = [[1.0, 2.0], [3.0, 4.0]]
    output = model.predict(input_data)
    self.assertIsInstance(output, sparse_tensor.SparseTensor)
    # Densify before comparing values against the expected nested list.
    self.assertAllClose(expected, sparse_ops.sparse_tensor_to_dense(output))


if __name__ == '__main__':
  test.main()