# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for InputLayer construction."""

from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.saving import model_config
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


class TwoTensors(composite_tensor.CompositeTensor):
  """A simple value type to test TypeSpec.

  Contains two tensors (x, y) and a string (color). The color value is a
  stand-in for any extra type metadata we might need to store.

  This value type contains no single dtype. When `assign_variant_dtype` is
  True a `dtype` attribute equal to `dtypes.variant` is attached anyway;
  otherwise the instance has no `dtype` attribute at all.
  """

  def __init__(self, x, y, color='red', assign_variant_dtype=False):
    assert isinstance(color, str)
    self.x = ops.convert_to_tensor_v2_with_dispatch(x)
    self.y = ops.convert_to_tensor_v2_with_dispatch(y)
    self.color = color
    # The two components may differ in shape, so the composite reports an
    # unknown shape through both the public and the private attribute.
    self.shape = tensor_shape.TensorShape(None)
    self._shape = tensor_shape.TensorShape(None)
    if assign_variant_dtype:
      self.dtype = dtypes.variant
    self._assign_variant_dtype = assign_variant_dtype

  @property
  def _type_spec(self):
    # `CompositeTensor._type_spec` is declared as an (abstract) property, so
    # it must be overridden as a property: as a plain method,
    # `value._type_spec` would evaluate to a bound method, not a `TypeSpec`.
    return TwoTensorsSpecNoOneDtype(
        self.x.shape, self.x.dtype, self.y.shape, self.y.dtype,
        color=self.color,
        assign_variant_dtype=self._assign_variant_dtype)


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  return tensor_shape.TensorShape(shape)


@type_spec.register('tf.TwoTensorsSpec')
class TwoTensorsSpecNoOneDtype(type_spec.TypeSpec):
  """A TypeSpec for the TwoTensors value type.

  Like `TwoTensors`, this spec has no single dtype: it only gets a `dtype`
  attribute (equal to `dtypes.variant`) when `assign_variant_dtype` is True.
  """

  def __init__(
      self, x_shape, x_dtype, y_shape, y_dtype, color='red',
      assign_variant_dtype=False):
    self.x_shape = as_shape(x_shape)
    self.x_dtype = dtypes.as_dtype(x_dtype)
    self.y_shape = as_shape(y_shape)
    self.y_dtype = dtypes.as_dtype(y_dtype)
    self.color = color
    # Overall shape is unknown; set both the public and private attribute.
    self.shape = tensor_shape.TensorShape(None)
    self._shape = tensor_shape.TensorShape(None)
    if assign_variant_dtype:
      self.dtype = dtypes.variant
    self._assign_variant_dtype = assign_variant_dtype

  value_type = property(lambda self: TwoTensors)

  @property
  def _component_specs(self):
    return (tensor_spec.TensorSpec(self.x_shape, self.x_dtype),
            tensor_spec.TensorSpec(self.y_shape, self.y_dtype))

  def _to_components(self, value):
    return (value.x, value.y)

  def _from_components(self, components):
    x, y = components
    # NOTE(review): `assign_variant_dtype` is not forwarded here, so values
    # rebuilt from components never carry a `dtype` attribute. Confirm this
    # is intended.
    return TwoTensors(x, y, self.color)

  def _serialize(self):
    # NOTE(review): `assign_variant_dtype` is not part of this tuple, so with
    # `TypeSpec`'s default `_deserialize` (`cls(*serialization)`) a rebuilt
    # spec falls back to `assign_variant_dtype=False`. Confirm before relying
    # on the flag surviving serialization.
    return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype,
            self.color)

  @classmethod
  def from_value(cls, value):
    # Forward the variant-dtype flag, matching `TwoTensors._type_spec`.
    return cls(value.x.shape, value.x.dtype, value.y.shape, value.y.dtype,
               value.color,
               assign_variant_dtype=value._assign_variant_dtype)  # pylint: disable=protected-access


type_spec.register_type_spec_from_value_converter(
    TwoTensors, TwoTensorsSpecNoOneDtype.from_value)


class InputLayerTest(keras_parameterized.TestCase):
  """Tests building `Input`s and using them as inputs of functional models."""

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeNoBatchSize(self):
    # Create a Keras Input
    x = input_layer_lib.Input(shape=(32,), name='input_a')
    self.assertAllEqual(x.shape.as_list(), [None, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones((3, 32))),
                        array_ops.ones((3, 32)) * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeWithBatchSize(self):
    # Create a Keras Input
    x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b')
    self.assertAllEqual(x.shape.as_list(), [6, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testBasicOutputShapeNoBatchSizeInTFFunction(self):
    # The model is created inside the traced function on the first trace
    # and reused afterwards.
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(shape=(8,), name='input_a')
        self.assertAllEqual(x.shape.as_list(), [None, 8])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 2.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 8))),
                        array_ops.ones((10, 8)) * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInputTensorArg(self):
    # Create a Keras Input
    x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32)))
    self.assertAllEqual(x.shape.as_list(), [7, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testInputTensorArgInTFFunction(self):
    # The model is created inside the traced function on the first trace
    # and reused afterwards.
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16)))
        self.assertAllEqual(x.shape.as_list(), [10, 16])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 16))),
                        array_ops.ones((10, 16)) * 3.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArg(self):
    # Create a Keras Input
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    x = input_layer_lib.Input(tensor=rt)

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2)

    # And that the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(model(rt), rt * 2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArgInTFFunction(self):
    # The model is created inside the traced function on the first trace
    # and reused afterwards.
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        rt = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        x = input_layer_lib.Input(tensor=rt)

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3)
      return model(inp)

    # And verify the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(run_model(rt), rt * 3)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testNoMixingArgsWithTypeSpecArg(self):
    # `type_spec` may only be combined with `name`; every other argument
    # must be left as None.
    with self.assertRaisesRegex(
        ValueError, 'all other args except `name` must be None'):
      input_layer_lib.Input(
          shape=(4, 7),
          type_spec=tensor_spec.TensorSpec((2, 7, 32), dtypes.float32))
    with self.assertRaisesRegex(
        ValueError, 'all other args except `name` must be None'):
      input_layer_lib.Input(
          batch_size=4,
          type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
    with self.assertRaisesRegex(
        ValueError, 'all other args except `name` must be None'):
      input_layer_lib.Input(
          dtype=dtypes.int64,
          type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
    with self.assertRaisesRegex(
        ValueError, 'all other args except `name` must be None'):
      input_layer_lib.Input(
          sparse=True,
          type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
    with self.assertRaisesRegex(
        ValueError, 'all other args except `name` must be None'):
      input_layer_lib.Input(
          ragged=True,
          type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testTypeSpecArg(self):
    # Create a Keras Input
    x = input_layer_lib.Input(
        type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
    self.assertAllEqual(x.shape.as_list(), [7, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

    # Test serialization / deserialization
    model = functional.Functional.from_config(model.get_config())
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

    model = model_config.model_from_json(model.to_json())
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testTypeSpecArgInTFFunction(self):
    # The model is created inside the traced function on the first trace
    # and reused afterwards.
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(
            type_spec=tensor_spec.TensorSpec((10, 16), dtypes.float32))
        self.assertAllEqual(x.shape.as_list(), [10, 16])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 16))),
                        array_ops.ones((10, 16)) * 3.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArg(self):
    # Create a Keras Input
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    x = input_layer_lib.Input(type_spec=rt._type_spec)

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2)

    # And that the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(model(rt), rt * 2)

    # Test serialization / deserialization
    model = functional.Functional.from_config(model.get_config())
    self.assertAllEqual(model(rt), rt * 2)
    model = model_config.model_from_json(model.to_json())
    self.assertAllEqual(model(rt), rt * 2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArgInTFFunction(self):
    # The model is created inside the traced function on the first trace
    # and reused afterwards.
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        rt = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        x = input_layer_lib.Input(type_spec=rt._type_spec)

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3)
      return model(inp)

    # And verify the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(run_model(rt), rt * 3)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArgWithoutDtype(self):
    # Exercise a spec with no `dtype` attribute and one with a variant dtype.
    for assign_variant_dtype in [False, True]:
      # Create a Keras Input
      spec = TwoTensorsSpecNoOneDtype(
          (1, 2, 3), dtypes.float32, (1, 2, 3), dtypes.int64,
          assign_variant_dtype=assign_variant_dtype)
      x = input_layer_lib.Input(type_spec=spec)

      def lambda_fn(tensors):
        return (math_ops.cast(tensors.x, dtypes.float64)
                + math_ops.cast(tensors.y, dtypes.float64))
      # Verify you can construct and use a model w/ this input
      model = functional.Functional(x, core.Lambda(lambda_fn)(x))

      # And that the model works. The value matches `spec`: `x` is a float32
      # (1, 2, 3) tensor and `y` an int64 (1, 2, 3) tensor. (The shape must
      # be passed as one tuple: `ones(1, 2, 3)` would mean shape=1, dtype=2,
      # name=3.)
      two_tensors = TwoTensors(array_ops.ones((1, 2, 3)) * 2.0,
                               array_ops.ones((1, 2, 3), dtype=dtypes.int64))
      self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))

      # Test serialization / deserialization
      model = functional.Functional.from_config(model.get_config())
      self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
      model = model_config.model_from_json(model.to_json())
      self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))

  def test_serialize_with_unknown_rank(self):
    inp = backend.placeholder(shape=None, dtype=dtypes.string)
    x = input_layer_lib.InputLayer(input_tensor=inp, dtype=dtypes.string)
    loaded = input_layer_lib.InputLayer.from_config(x.get_config())
    self.assertIsNone(loaded._batch_input_shape)


if __name__ == '__main__':
  test.main()