# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.example import example_parser_configuration_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfiguration proto.

  Dispatches on the op type to the matching extraction routine.

  Args:
    parse_example_op: A ParseExample or ParseExampleV2 `Operation`
    sess: A tf.compat.v1.Session needed to obtain some configuration values.

  Returns:
    An ExampleParserConfiguration proto.

  Raises:
    ValueError: If attributes are inconsistent, if the op type is neither
      "ParseExample" nor "ParseExampleV2", or if a ParseExampleV2 op declares
      ragged features (not representable in the proto).
  """
  if parse_example_op.type == "ParseExample":
    return _extract_from_parse_example(parse_example_op, sess)
  elif parse_example_op.type == "ParseExampleV2":
    return _extract_from_parse_example_v2(parse_example_op, sess)
  else:
    raise ValueError("Unexpected op type: %s" % parse_example_op.type)


def _add_fixed_len_feature(config, key, default_value, shape, dtype,
                           values_tensor):
  """Records one dense feature under `key` in `config.feature_map`.

  Args:
    config: The ExampleParserConfiguration proto being filled in.
    key: The feature key used to index `config.feature_map`.
    default_value: The default value fetched from the session run; converted
      to a TensorProto with `tensor_util.make_tensor_proto`.
    shape: The shape taken from the op's `dense_shapes` attribute; converted
      to a TensorShapeProto.
    dtype: The dtype taken from the op's `Tdense` attribute.
    values_tensor: The op output whose name is recorded as the values output.
  """
  fixed_config = config.feature_map[key].fixed_len_feature
  fixed_config.default_value.CopyFrom(
      tensor_util.make_tensor_proto(default_value))
  fixed_config.shape.CopyFrom(tensor_shape.TensorShape(shape).as_proto())
  fixed_config.dtype = dtype.as_datatype_enum
  fixed_config.values_output_tensor_name = values_tensor.name


def _add_var_len_feature(config, key, dtype, indices_tensor, values_tensor,
                         shapes_tensor):
  """Records one sparse feature under `key` in `config.feature_map`.

  Args:
    config: The ExampleParserConfiguration proto being filled in.
    key: The feature key used to index `config.feature_map`.
    dtype: The dtype taken from the op's `sparse_types` attribute.
    indices_tensor: The op output whose name is recorded as the indices output.
    values_tensor: The op output whose name is recorded as the values output.
    shapes_tensor: The op output whose name is recorded as the shapes output.
  """
  var_len_feature = config.feature_map[key].var_len_feature
  var_len_feature.dtype = dtype.as_datatype_enum
  var_len_feature.indices_output_tensor_name = indices_tensor.name
  var_len_feature.values_output_tensor_name = values_tensor.name
  var_len_feature.shapes_output_tensor_name = shapes_tensor.name


def _extract_from_parse_example(parse_example_op, sess):
  """Extract ExampleParserConfig from ParseExample op.

  Args:
    parse_example_op: A ParseExample `Operation`.
    sess: A session used to fetch the key names and dense default values from
      the op's inputs.

  Returns:
    An ExampleParserConfiguration proto with one entry per dense and per
    sparse feature.

  Raises:
    ValueError: If the op's attributes or inputs are inconsistent with each
      other.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")
  total_features = num_dense + num_sparse

  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_types) != num_dense:
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_types), num_dense))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Skip over the serialized input, and the names input.
  fetch_list = parse_example_op.inputs[2:]

  # Fetch total_features key names and num_dense default values.
  if len(fetch_list) != (total_features + num_dense):
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(fetch_list), (total_features + num_dense)))

  fetched = sess.run(fetch_list)

  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(fetched), len(fetch_list)))

  # Fetch indices: sparse keys first, then dense keys, then dense defaults.
  sparse_keys_start = 0
  dense_keys_start = sparse_keys_start + num_sparse
  dense_def_start = dense_keys_start + num_dense

  # Output tensor indices: sparse indices, sparse values, sparse shapes, then
  # dense values.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  outputs = parse_example_op.outputs

  # Dense features.
  for i in range(num_dense):
    _add_fixed_len_feature(
        config,
        key=fetched[dense_keys_start + i],
        default_value=fetched[dense_def_start + i],
        shape=dense_shapes[i],
        dtype=dense_types[i],
        values_tensor=outputs[dense_values_start + i])

  # Sparse features.
  for i in range(num_sparse):
    _add_var_len_feature(
        config,
        key=fetched[sparse_keys_start + i],
        dtype=sparse_types[i],
        indices_tensor=outputs[sparse_indices_start + i],
        values_tensor=outputs[sparse_values_start + i],
        shapes_tensor=outputs[sparse_shapes_start + i])

  return config


def _extract_from_parse_example_v2(parse_example_op, sess):
  """Extract ExampleParserConfig from ParseExampleV2 op.

  Args:
    parse_example_op: A ParseExampleV2 `Operation`.
    sess: A session used to fetch the key vectors and dense default values
      from the op's inputs.

  Returns:
    An ExampleParserConfiguration proto with one entry per dense and per
    sparse feature.

  Raises:
    ValueError: If the op's attributes or inputs are inconsistent with each
      other, or if the op declares any ragged features.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  dense_types = parse_example_op.get_attr("Tdense")
  num_sparse = parse_example_op.get_attr("num_sparse")
  sparse_types = parse_example_op.get_attr("sparse_types")
  ragged_value_types = parse_example_op.get_attr("ragged_value_types")
  ragged_split_types = parse_example_op.get_attr("ragged_split_types")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  num_dense = len(dense_types)
  num_ragged = len(ragged_value_types)

  # Validation raises ValueError rather than using `assert`: asserts are
  # stripped under `python -O`, and ValueError is what the public entry point
  # documents and what the ParseExample path raises.
  if len(ragged_split_types) != num_ragged:
    raise ValueError("len(ragged_split_types) attribute does not match "
                     "len(ragged_value_types) attribute (%d vs %d)" %
                     (len(ragged_split_types), num_ragged))

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "num_sparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "len(Tdense) attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Fail before running the session: the result can never be returned when
  # ragged features are present.
  if num_ragged != 0:
    raise ValueError("Ragged features are not yet supported by "
                     "example_parser_configuration.proto")

  # Inputs are: serialized, names, sparse keys, dense keys, ragged keys, and
  # one default value per dense feature.
  if len(parse_example_op.inputs) != 5 + num_dense:
    raise ValueError("len(inputs) does not match 5 + num_dense "
                     "(%d vs %d)" %
                     (len(parse_example_op.inputs), 5 + num_dense))

  # Skip over the serialized input, and the names input.
  fetched = sess.run(parse_example_op.inputs[2:])
  sparse_keys = fetched[0].tolist()
  dense_keys = fetched[1].tolist()
  ragged_keys = fetched[2].tolist()
  dense_defaults = fetched[3:]

  if len(sparse_keys) != num_sparse:
    raise ValueError("len(sparse_keys) does not match num_sparse attribute "
                     "(%d vs %d)" % (len(sparse_keys), num_sparse))

  if len(dense_keys) != num_dense:
    raise ValueError("len(dense_keys) does not match len(Tdense) attribute "
                     "(%d vs %d)" % (len(dense_keys), num_dense))

  if len(ragged_keys) != num_ragged:
    raise ValueError("len(ragged_keys) does not match "
                     "len(ragged_value_types) attribute (%d vs %d)" %
                     (len(ragged_keys), num_ragged))

  # Output tensor indices: sparse indices, sparse values, sparse shapes, then
  # dense values.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  outputs = parse_example_op.outputs

  # Dense features.
  for i in range(num_dense):
    _add_fixed_len_feature(
        config,
        key=dense_keys[i],
        default_value=dense_defaults[i],
        shape=dense_shapes[i],
        dtype=dense_types[i],
        values_tensor=outputs[dense_values_start + i])

  # Sparse features.
  for i in range(num_sparse):
    _add_var_len_feature(
        config,
        key=sparse_keys[i],
        dtype=sparse_types[i],
        indices_tensor=outputs[sparse_indices_start + i],
        values_tensor=outputs[sparse_values_start + i],
        shapes_tensor=outputs[sparse_shapes_start + i])

  return config