• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for learn.estimators.tensor_signature."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.ops import array_ops
25from tensorflow.python.platform import test
26
27
class TensorSignatureTest(test.TestCase):
  """Tests for creating, comparing and materializing tensor signatures."""

  def testTensorPlaceholderNone(self):
    """Creating placeholders from a `None` signature yields `None`."""
    self.assertIsNone(
        tensor_signature.create_placeholders_from_signatures(None))

  def testTensorSignatureNone(self):
    """Creating signatures from `None` yields `None`."""
    self.assertIsNone(tensor_signature.create_signatures(None))

  def testTensorSignatureCompatible(self):
    """Checks `tensors_compatible` for single tensors and dicts of tensors."""
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    placeholder_b = array_ops.placeholder(
        name='another', shape=[256, 100], dtype=dtypes.int32)
    # Same shape as placeholder_b but a different dtype.
    placeholder_c = array_ops.placeholder(
        name='mismatch', shape=[256, 100], dtype=dtypes.float32)
    # Same dtype as placeholder_a, different (concrete) batch size.
    placeholder_d = array_ops.placeholder(
        name='other_batch', shape=[128, 100], dtype=dtypes.int32)

    # Signature built from a single tensor.
    signatures = tensor_signature.create_signatures(placeholder_a)
    # None is only compatible with None.
    self.assertTrue(tensor_signature.tensors_compatible(None, None))
    self.assertFalse(tensor_signature.tensors_compatible(None, signatures))
    self.assertFalse(tensor_signature.tensors_compatible(placeholder_a, None))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_a, signatures))
    # The signature's batch dim is None, so any batch size is accepted ...
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_b, signatures))
    # ... but the dtype must match.
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_c, signatures))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_d, signatures))

    # Signature built from a dict of tensors.
    inputs = {'a': placeholder_a}
    signatures = tensor_signature.create_signatures(inputs)
    self.assertTrue(tensor_signature.tensors_compatible(inputs, signatures))
    # A bare tensor does not match a dict signature.
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_a, signatures))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_b, signatures))
    # A dict lacking a signature key is incompatible.
    self.assertFalse(
        tensor_signature.tensors_compatible({
            'b': placeholder_b
        }, signatures))
    # Extra keys beyond those in the signature are tolerated.
    self.assertTrue(
        tensor_signature.tensors_compatible({
            'a': placeholder_b,
            'c': placeholder_c
        }, signatures))
    # A dtype mismatch under a signature key is incompatible.
    self.assertFalse(
        tensor_signature.tensors_compatible({
            'a': placeholder_c
        }, signatures))

  def testSparseTensorCompatible(self):
    """A `SparseTensor` is compatible with its own signature."""
    t = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    signatures = tensor_signature.create_signatures(t)
    self.assertTrue(tensor_signature.tensors_compatible(t, signatures))

  def testTensorSignaturePlaceholders(self):
    """Placeholders built from signatures match the original tensors."""
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)

    # Single-tensor signature round trip.
    signatures = tensor_signature.create_signatures(placeholder_a)
    placeholder_out = tensor_signature.create_placeholders_from_signatures(
        signatures)
    self.assertEqual(placeholder_out.dtype, placeholder_a.dtype)
    self.assertTrue(placeholder_out.get_shape().is_compatible_with(
        placeholder_a.get_shape()))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_out, signatures))

    # Dict signature round trip.
    inputs = {'a': placeholder_a}
    signatures = tensor_signature.create_signatures(inputs)
    placeholders_out = tensor_signature.create_placeholders_from_signatures(
        signatures)
    self.assertEqual(placeholders_out['a'].dtype, placeholder_a.dtype)
    self.assertTrue(placeholders_out['a'].get_shape().is_compatible_with(
        placeholder_a.get_shape()))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholders_out, signatures))

  def testSparseTensorSignaturePlaceholders(self):
    """A sparse signature produces a `SparseTensor` placeholder."""
    tensor = sparse_tensor.SparseTensor(
        values=[1.0, 2.0], indices=[[0, 2], [0, 3]], dense_shape=[5, 5])
    signature = tensor_signature.create_signatures(tensor)
    placeholder = tensor_signature.create_placeholders_from_signatures(
        signature)
    self.assertIsInstance(placeholder, sparse_tensor.SparseTensor)
    self.assertEqual(placeholder.values.dtype, tensor.values.dtype)

  def testTensorSignatureExampleParserSingle(self):
    """Example parser output matches a single-tensor signature."""
    examples = array_ops.placeholder(
        name='example', shape=[None], dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    signatures = tensor_signature.create_signatures(placeholder_a)
    result = tensor_signature.create_example_parser_from_signatures(signatures,
                                                                    examples)
    self.assertTrue(tensor_signature.tensors_compatible(result, signatures))
    new_signatures = tensor_signature.create_signatures(result)
    self.assertTrue(new_signatures.is_compatible_with(signatures))

  def testTensorSignatureExampleParserDict(self):
    """Example parser output matches a dict-of-tensors signature."""
    examples = array_ops.placeholder(
        name='example', shape=[None], dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='test', shape=[None, 100], dtype=dtypes.int32)
    placeholder_b = array_ops.placeholder(
        name='bb', shape=[None, 100], dtype=dtypes.float64)
    inputs = {'a': placeholder_a, 'b': placeholder_b}
    signatures = tensor_signature.create_signatures(inputs)
    result = tensor_signature.create_example_parser_from_signatures(signatures,
                                                                    examples)
    self.assertTrue(tensor_signature.tensors_compatible(result, signatures))
    new_signatures = tensor_signature.create_signatures(result)
    self.assertTrue(new_signatures['a'].is_compatible_with(signatures['a']))
    self.assertTrue(new_signatures['b'].is_compatible_with(signatures['b']))

  def testUnknownShape(self):
    """Unknown-rank shapes are compatible with any shape of the same dtype."""
    placeholder_unk = array_ops.placeholder(
        name='unk', shape=None, dtype=dtypes.string)
    placeholder_a = array_ops.placeholder(
        name='a', shape=[None], dtype=dtypes.string)
    placeholder_b = array_ops.placeholder(
        name='b', shape=[128, 2], dtype=dtypes.string)
    placeholder_c = array_ops.placeholder(
        name='c', shape=[128, 2], dtype=dtypes.int32)
    unk_signature = tensor_signature.create_signatures(placeholder_unk)
    # Tensors of same dtype match unk shape signature.
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_unk, unk_signature))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_a, unk_signature))
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_b, unk_signature))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_c, unk_signature))

    string_signature = tensor_signature.create_signatures(placeholder_a)
    int_signature = tensor_signature.create_signatures(placeholder_c)
    # An unknown-shape tensor matches signatures of the same dtype.
    self.assertTrue(
        tensor_signature.tensors_compatible(placeholder_unk, string_signature))
    self.assertFalse(
        tensor_signature.tensors_compatible(placeholder_unk, int_signature))
173
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner, which
  # discovers and runs the TestCase classes defined in this module.
  test.main()
176