# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration test for sequence feature columns with SequenceExamples."""

import os
import string
import tempfile

from google.protobuf import text_format

from tensorflow.core.example import example_pb2
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class SequenceExampleParsingTest(test.TestCase):
  """Checks that sequence categorical columns parse a SequenceExample.

  Every test builds one sequence column, derives its parsing spec, parses the
  shared `_SEQ_EX_PROTO` example with it and verifies the resulting
  SparseTensor (see `_test_parsed_sequence_example`).
  """

  def test_seq_ex_in_sequence_categorical_column_with_identity(self):
    self._test_parsed_sequence_example(
        'int_list', sfc.sequence_categorical_column_with_identity,
        10, [3, 6], [2, 4, 6])

  def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self):
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_hash_bucket,
        10, [3, 4], [compat.as_bytes(x) for x in 'acg'])

  def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self):
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list,
        list(string.ascii_lowercase), [3, 4],
        [compat.as_bytes(x) for x in 'acg'])

  def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self):
    # NamedTemporaryFile hands back a Python file object, so the `with` block
    # closes the underlying descriptor (a bare `mkstemp()` call leaves the
    # returned OS-level fd open). `delete=False` keeps the file on disk after
    # the handle is closed so it can be re-opened by name.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.txt', delete=False) as f:
      fname = f.name
      # Remove the file once the test finishes, whether it passes or fails.
      self.addCleanup(os.remove, fname)
      # One token per line, mirroring the 26 single-letter entries used by
      # the vocabulary_list test above.
      f.write('\n'.join(string.ascii_lowercase))
    self._test_parsed_sequence_example(
        'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file,
        fname, [3, 4], [compat.as_bytes(x) for x in 'acg'])

  def _test_parsed_sequence_example(
      self, col_name, col_fn, col_arg, shape, values):
    """Helper function to check that each FeatureColumn parses correctly.

    Builds two context columns ('int_ctx', 'float_ctx') plus the sequence
    column under test, parses the example from `_make_sequence_example()` with
    the specs derived from those columns, and asserts on the parsed results.

    Args:
      col_name: string, name to give to the feature column. Should match
        the name that the column will parse out of the features dict.
      col_fn: function used to create the feature column. For example,
        sequence_numeric_column. It is called as `col_fn(col_name, col_arg)`.
      col_arg: second arg that the target feature column is expecting.
      shape: the expected dense_shape of the feature after parsing into
        a SparseTensor.
      values: the expected values at index [0, 2, 6] of the feature
        after parsing into a SparseTensor.
    """
    example = _make_sequence_example()
    context_columns = [
        fc.categorical_column_with_identity('int_ctx', num_buckets=100),
        fc.numeric_column('float_ctx'),
    ]
    sequence_columns = [col_fn(col_name, col_arg)]
    context, seq_features = parsing_ops.parse_single_sequence_example(
        example.SerializeToString(),
        context_features=fc.make_parse_example_spec_v2(context_columns),
        sequence_features=fc.make_parse_example_spec_v2(sequence_columns))

    with self.cached_session() as sess:
      ctx_result, seq_result = sess.run([context, seq_features])
      # Sequence feature: check the overall shape and a sample of the
      # flattened values (positions 0, 2 and 6).
      self.assertEqual(list(seq_result[col_name].dense_shape), shape)
      self.assertEqual(
          list(seq_result[col_name].values[[0, 2, 6]]), values)
      # Context features: a single int64 value and a single float value.
      self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1])
      self.assertEqual(ctx_result['int_ctx'].values[0], 5)
      self.assertEqual(list(ctx_result['float_ctx'].shape), [1])
      self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1)


# Text-format SequenceExample shared by all tests: two context features and
# three feature lists ("bytes_list", "float_list", "int_list"), each with three
# steps of differing lengths.
_SEQ_EX_PROTO = """
context {
  feature {
    key: "float_ctx"
    value {
      float_list {
        value: 123.6
      }
    }
  }
  feature {
    key: "int_ctx"
    value {
      int64_list {
        value: 5
      }
    }
  }
}
feature_lists {
  feature_list {
    key: "bytes_list"
    value {
      feature {
        bytes_list {
          value: "a"
        }
      }
      feature {
        bytes_list {
          value: "b"
          value: "c"
        }
      }
      feature {
        bytes_list {
          value: "d"
          value: "e"
          value: "f"
          value: "g"
        }
      }
    }
  }
  feature_list {
    key: "float_list"
    value {
      feature {
        float_list {
          value: 1.0
        }
      }
      feature {
        float_list {
          value: 3.0
          value: 3.0
          value: 3.0
        }
      }
      feature {
        float_list {
          value: 5.0
          value: 5.0
          value: 5.0
          value: 5.0
          value: 5.0
        }
      }
    }
  }
  feature_list {
    key: "int_list"
    value {
      feature {
        int64_list {
          value: 2
          value: 2
        }
      }
      feature {
        int64_list {
          value: 4
          value: 4
          value: 4
          value: 4
        }
      }
      feature {
        int64_list {
          value: 6
          value: 6
          value: 6
          value: 6
          value: 6
          value: 6
        }
      }
    }
  }
}
"""


def _make_sequence_example():
  """Returns a SequenceExample proto parsed from `_SEQ_EX_PROTO`."""
  example = example_pb2.SequenceExample()
  return text_format.Parse(_SEQ_EX_PROTO, example)


if __name__ == '__main__':
  test.main()