• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Integration test for sequence feature columns with SequenceExamples."""
16
17import string
18import tempfile
19
20from google.protobuf import text_format
21
22from tensorflow.core.example import example_pb2
23from tensorflow.python.feature_column import feature_column_v2 as fc
24from tensorflow.python.feature_column import sequence_feature_column as sfc
25from tensorflow.python.ops import parsing_ops
26from tensorflow.python.platform import test
27from tensorflow.python.util import compat
28
29
class SequenceExampleParsingTest(test.TestCase):
  """Checks that sequence categorical columns parse a SequenceExample.

  Each test builds one sequence feature column, derives parse specs from it
  (plus two fixed context columns) and verifies the parsed sparse output.
  """

  def test_seq_ex_in_sequence_categorical_column_with_identity(self):
    self._test_parsed_sequence_example(
        'int_list', sfc.sequence_categorical_column_with_identity,
        10, [3, 6], [2, 4, 6])

  def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self):
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_hash_bucket,
        10, [3, 4], [compat.as_bytes(x) for x in 'acg'])

  def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self):
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list,
        list(string.ascii_lowercase), [3, 4],
        [compat.as_bytes(x) for x in 'acg'])

  def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self):
    # A vocabulary file holds one token per line, so write each letter on
    # its own line; this matches the 26-entry vocabulary used by the
    # vocabulary_list test. Writing `string.ascii_lowercase` as-is would
    # produce a single-token vocabulary.
    # NamedTemporaryFile is closed by the `with` block, so no OS file
    # descriptor is leaked (tempfile.mkstemp returns an open descriptor that
    # must be closed explicitly). The file lives under self.get_temp_dir()
    # rather than the system-wide temp directory.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.txt', dir=self.get_temp_dir(),
        delete=False) as vocab_file:
      vocab_file.write('\n'.join(string.ascii_lowercase))
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file,
        vocab_file.name, [3, 4], [compat.as_bytes(x) for x in 'acg'])

  def _test_parsed_sequence_example(
      self, col_name, col_fn, col_arg, shape, values):
    """Helper function to check that each FeatureColumn parses correctly.

    Builds two context columns ('int_ctx', 'float_ctx') plus the sequence
    column under test, derives parse specs from them, parses the example
    returned by `_make_sequence_example()`, and checks both the sequence
    feature and the context features.

    Args:
      col_name: string, name to give to the feature column. Must match a
        feature_list key in the parsed SequenceExample.
      col_fn: function used to create the feature column. For example,
        sfc.sequence_categorical_column_with_identity.
      col_arg: second positional arg passed to `col_fn` (in these tests: a
        bucket count, a vocabulary list, or a vocabulary file path).
      shape: the expected dense_shape of the feature after parsing into
        a SparseTensor.
      values: the expected entries at flat positions [0, 2, 6] of the
        parsed SparseTensor's `values`.
    """
    example = _make_sequence_example()
    columns = [
        fc.categorical_column_with_identity('int_ctx', num_buckets=100),
        fc.numeric_column('float_ctx'),
        col_fn(col_name, col_arg)
    ]
    # The first two columns describe context features; the last one is the
    # sequence (feature_lists) column under test.
    context, seq_features = parsing_ops.parse_single_sequence_example(
        example.SerializeToString(),
        context_features=fc.make_parse_example_spec_v2(columns[:2]),
        sequence_features=fc.make_parse_example_spec_v2(columns[2:]))

    with self.cached_session() as sess:
      ctx_result, seq_result = sess.run([context, seq_features])
      self.assertEqual(list(seq_result[col_name].dense_shape), shape)
      self.assertEqual(
          list(seq_result[col_name].values[[0, 2, 6]]), values)
      # 'int_ctx' is read back as sparse (it exposes dense_shape/values);
      # 'float_ctx' is read back as a dense array of shape [1].
      self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1])
      self.assertEqual(ctx_result['int_ctx'].values[0], 5)
      self.assertEqual(list(ctx_result['float_ctx'].shape), [1])
      # 123.6 is stored as float32, so compare approximately.
      self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1)
91
92
# Text-format SequenceExample fixture, parsed by _make_sequence_example().
#
# Context features:
#   float_ctx: float_list [123.6]
#   int_ctx:   int64_list [5]
# Feature lists (three steps each, with a different number of values per step):
#   bytes_list: ["a"], ["b", "c"], ["d", "e", "f", "g"]
#     -> parses to a SparseTensor with dense_shape [3, 4]; flat values at
#        positions [0, 2, 6] are b"a", b"c", b"g".
#   float_list: [1.0], [3.0] * 3, [5.0] * 5
#   int_list:   [2] * 2, [4] * 4, [6] * 6
#     -> parses to a SparseTensor with dense_shape [3, 6]; flat values at
#        positions [0, 2, 6] are 2, 4, 6.
# NOTE(review): no test in this view parses float_list.
_SEQ_EX_PROTO = """
context {
  feature {
    key: "float_ctx"
    value {
      float_list {
        value: 123.6
      }
    }
  }
  feature {
    key: "int_ctx"
    value {
      int64_list {
        value: 5
      }
    }
  }
}
feature_lists {
  feature_list {
    key: "bytes_list"
    value {
      feature {
        bytes_list {
          value: "a"
        }
      }
      feature {
        bytes_list {
          value: "b"
          value: "c"
        }
      }
      feature {
        bytes_list {
          value: "d"
          value: "e"
          value: "f"
          value: "g"
        }
      }
    }
  }
  feature_list {
    key: "float_list"
    value {
      feature {
        float_list {
          value: 1.0
        }
      }
      feature {
        float_list {
          value: 3.0
          value: 3.0
          value: 3.0
        }
      }
      feature {
        float_list {
          value: 5.0
          value: 5.0
          value: 5.0
          value: 5.0
          value: 5.0
        }
      }
    }
  }
  feature_list {
    key: "int_list"
    value {
      feature {
        int64_list {
          value: 2
          value: 2
        }
      }
      feature {
        int64_list {
          value: 4
          value: 4
          value: 4
          value: 4
        }
      }
      feature {
        int64_list {
          value: 6
          value: 6
          value: 6
          value: 6
          value: 6
          value: 6
        }
      }
    }
  }
}
"""
194
195
def _make_sequence_example():
  """Builds the SequenceExample test fixture.

  Returns:
    A `SequenceExample` proto populated by parsing the text-format
    `_SEQ_EX_PROTO` string.
  """
  # text_format.Parse fills the given message in place and returns it.
  return text_format.Parse(_SEQ_EX_PROTO, example_pb2.SequenceExample())
199
200
if __name__ == '__main__':
  # Run this module's tests through the TensorFlow test entry point.
  test.main()
203