# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for learn.estimators.tensor_signature."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class TensorSignatureTest(test.TestCase):

  def testTensorPlaceholderNone(self):
    """A `None` signature yields `None` placeholders."""
    self.assertIsNone(
        tensor_signature.create_placeholders_from_signatures(None))

  def testTensorSignatureNone(self):
    """Creating signatures from `None` yields `None`."""
    self.assertIsNone(tensor_signature.create_signatures(None))

  def testTensorSignatureCompatible(self):
    """Checks `tensors_compatible` for single tensors and dicts of tensors."""
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    placeholder_b = array_ops.placeholder(
        name='another', shape=[256, 100], dtype=dtypes.int32)
    placeholder_c = array_ops.placeholder(
        name='mismatch', shape=[256, 100], dtype=dtypes.float32)
    placeholder_d = array_ops.placeholder(
        name='mismatch', shape=[128, 100], dtype=dtypes.int32)
    signatures = tensor_signature.create_signatures(placeholder_a)
    # `None` only matches `None`; a tensor never matches a missing signature.
    self.assertTrue(tensor_signature.tensors_compatible(None, None))
    self.assertFalse(tensor_signature.tensors_compatible(None, signatures))
    self.assertFalse(tensor_signature.tensors_compatible(placeholder_a, None))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_a, signatures))
    # The signature's first dimension is `None`, so any batch size matches.
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_b, signatures))
    # Same shape but a different dtype must not match.
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_c, signatures))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_d, signatures))

    inputs = {'a': placeholder_a}
    signatures = tensor_signature.create_signatures(inputs)
    self.assertTrue(tensor_signature.tensors_compatible(inputs, signatures))
    # A bare tensor does not match a dict signature.
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_a, signatures))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_b, signatures))
    # A dict missing the signature's key does not match.
    self.assertFalse(
        tensor_signature.tensors_compatible({
            'b': placeholder_b
        }, signatures))
    # Extra keys beyond those in the signature are tolerated.
    self.assertTrue(
        tensor_signature.tensors_compatible({
            'a': placeholder_b,
            'c': placeholder_c
        }, signatures))
    # A matching key with the wrong dtype does not match.
    self.assertFalse(
        tensor_signature.tensors_compatible({
            'a': placeholder_c
        }, signatures))

  def testSparseTensorCompatible(self):
    """A SparseTensor is compatible with its own signature."""
    t = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    signatures = tensor_signature.create_signatures(t)
    self.assertTrue(tensor_signature.tensors_compatible(t, signatures))

  def testTensorSignaturePlaceholders(self):
    """Placeholders rebuilt from signatures keep dtype and shape."""
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    signatures = tensor_signature.create_signatures(placeholder_a)
    placeholder_out = tensor_signature.create_placeholders_from_signatures(
        signatures)
    self.assertEqual(placeholder_out.dtype, placeholder_a.dtype)
    self.assertTrue(placeholder_out.get_shape().is_compatible_with(
        placeholder_a.get_shape()))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_out, signatures))

    inputs = {'a': placeholder_a}
    signatures = tensor_signature.create_signatures(inputs)
    placeholders_out = tensor_signature.create_placeholders_from_signatures(
        signatures)
    self.assertEqual(placeholders_out['a'].dtype, placeholder_a.dtype)
    self.assertTrue(placeholders_out['a'].get_shape().is_compatible_with(
        placeholder_a.get_shape()))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholders_out, signatures))

  def testSparseTensorSignaturePlaceholders(self):
    """A sparse signature yields a SparseTensor placeholder of the same dtype."""
    tensor = sparse_tensor.SparseTensor(
        values=[1.0, 2.0], indices=[[0, 2], [0, 3]], dense_shape=[5, 5])
    signature = tensor_signature.create_signatures(tensor)
    placeholder = tensor_signature.create_placeholders_from_signatures(
        signature)
    self.assertIsInstance(placeholder, sparse_tensor.SparseTensor)
    self.assertEqual(placeholder.values.dtype, tensor.values.dtype)

  def testTensorSignatureExampleParserSingle(self):
    """Parsing serialized examples with a single-tensor signature round-trips."""
    examples = array_ops.placeholder(
        name='example', shape=[None], dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    signatures = tensor_signature.create_signatures(placeholder_a)
    result = tensor_signature.create_example_parser_from_signatures(signatures,
                                                                    examples)
    self.assertTrue(tensor_signature.tensors_compatible(result, signatures))
    new_signatures = tensor_signature.create_signatures(result)
    self.assertTrue(new_signatures.is_compatible_with(signatures))

  def testTensorSignatureExampleParserDict(self):
    """Parsing serialized examples with a dict signature round-trips per key."""
    examples = array_ops.placeholder(
        name='example', shape=[None], dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    placeholder_b = array_ops.placeholder(
        name='bb', shape=[None, 100], dtype=dtypes.float64)
    inputs = {'a': placeholder_a, 'b': placeholder_b}
    signatures = tensor_signature.create_signatures(inputs)
    result = tensor_signature.create_example_parser_from_signatures(signatures,
                                                                    examples)
    self.assertTrue(tensor_signature.tensors_compatible(result, signatures))
    new_signatures = tensor_signature.create_signatures(result)
    self.assertTrue(new_signatures['a'].is_compatible_with(signatures['a']))
    self.assertTrue(new_signatures['b'].is_compatible_with(signatures['b']))

  def testUnknownShape(self):
    """Unknown shapes match on dtype alone, in either direction."""
    placeholder_unk = array_ops.placeholder(
        name='unk', shape=None, dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='a', shape=[None], dtype=dtypes.string)
    placeholder_b = array_ops.placeholder(
        name='b', shape=[128, 2], dtype=dtypes.string)
    placeholder_c = array_ops.placeholder(
        name='c', shape=[128, 2], dtype=dtypes.int32)
    unk_signature = tensor_signature.create_signatures(placeholder_unk)
    # Tensors of the same dtype match an unknown-shape signature.
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_unk, unk_signature))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_a, unk_signature))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_b, unk_signature))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_c, unk_signature))

    string_signature = tensor_signature.create_signatures(placeholder_a)
    int_signature = tensor_signature.create_signatures(placeholder_c)
    # A tensor of unknown shape matches any signature of the same dtype.
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_unk, string_signature))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_unk, int_signature))


if __name__ == '__main__':
  test.main()