# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.example import example_parser_configuration_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfig proto.

  Dispatches on `parse_example_op.type` to the extractor for that op version.

  Args:
    parse_example_op: A ParseExample or ParseExampleV2 `Operation`
    sess: A tf.compat.v1.Session needed to obtain some configuration values.

  Returns:
    A ExampleParserConfig proto.

  Raises:
    ValueError: If the op is neither a ParseExample nor a ParseExampleV2, if
      its attributes are inconsistent, or if it uses ragged features.
  """
  if parse_example_op.type == "ParseExample":
    return _extract_from_parse_example(parse_example_op, sess)
  elif parse_example_op.type == "ParseExampleV2":
    return _extract_from_parse_example_v2(parse_example_op, sess)
  else:
    raise ValueError(
        "Found unexpected type when parsing example. Expected `ParseExample` "
        f"object. Received type: {parse_example_op.type}")


def _populate_feature_map(config, parse_example_op, sparse_keys, sparse_types,
                          dense_keys, dense_defaults, dense_types,
                          dense_shapes):
  """Fills `config.feature_map` with the dense and sparse feature entries.

  Output tensor names are looked up in `parse_example_op.outputs`, which is
  indexed as: `num_sparse` sparse indices, `num_sparse` sparse values,
  `num_sparse` sparse shapes, then the dense values.

  Args:
    config: The ExampleParserConfiguration proto to fill in place.
    parse_example_op: The parse op whose output tensor names are recorded.
    sparse_keys: Feature keys of the sparse features, one per sparse feature.
    sparse_types: Dtypes of the sparse features, parallel to `sparse_keys`.
    dense_keys: Feature keys of the dense features, one per dense feature.
    dense_defaults: Fetched default values, parallel to `dense_keys`.
    dense_types: Dtypes of the dense features, parallel to `dense_keys`.
    dense_shapes: Shapes of the dense features, parallel to `dense_keys`.
  """
  outputs = parse_example_op.outputs
  num_sparse = len(sparse_keys)

  # Output tensor indices.
  sparse_indices_start = 0
  sparse_values_start = sparse_indices_start + num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  # Dense features.
  for i, key in enumerate(dense_keys):
    fixed_config = config.feature_map[key].fixed_len_feature
    # Convert the default value numpy array fetched from the session run
    # into a TensorProto.
    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(dense_defaults[i]))
    # Convert the shape from the attributes into a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())
    fixed_config.dtype = dense_types[i].as_datatype_enum
    # Get the output tensor name.
    fixed_config.values_output_tensor_name = outputs[
        dense_values_start + i].name

  # Sparse features.
  for i, key in enumerate(sparse_keys):
    var_len_feature = config.feature_map[key].var_len_feature
    var_len_feature.dtype = sparse_types[i].as_datatype_enum
    var_len_feature.indices_output_tensor_name = outputs[
        sparse_indices_start + i].name
    var_len_feature.values_output_tensor_name = outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = outputs[
        sparse_shapes_start + i].name


def _extract_from_parse_example(parse_example_op, sess):
  """Extract ExampleParserConfig from ParseExample op.

  Args:
    parse_example_op: A ParseExample `Operation`.
    sess: Session used to fetch the key and dense-default input tensors.

  Returns:
    An ExampleParserConfiguration proto describing the op's features.

  Raises:
    ValueError: If the op's attributes, inputs, or the fetched values
      disagree with each other in length.
  """
  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")
  total_features = num_dense + num_sparse

  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_types) != num_dense:
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_types), num_dense))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Skip over the serialized input, and the names input.
  fetch_list = parse_example_op.inputs[2:]

  # Fetch total_features key names and num_dense default values.
  if len(fetch_list) != (total_features + num_dense):
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(fetch_list), (total_features + num_dense)))

  fetched = sess.run(fetch_list)

  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(fetched), len(fetch_list)))

  # Fetch indices: sparse keys first, then dense keys, then dense defaults.
  sparse_keys_start = 0
  dense_keys_start = sparse_keys_start + num_sparse
  dense_def_start = dense_keys_start + num_dense

  config = example_parser_configuration_pb2.ExampleParserConfiguration()
  _populate_feature_map(
      config,
      parse_example_op,
      sparse_keys=fetched[sparse_keys_start:dense_keys_start],
      sparse_types=sparse_types,
      dense_keys=fetched[dense_keys_start:dense_def_start],
      dense_defaults=fetched[dense_def_start:],
      dense_types=dense_types,
      dense_shapes=dense_shapes)
  return config


def _extract_from_parse_example_v2(parse_example_op, sess):
  """Extract ExampleParserConfig from ParseExampleV2 op.

  Args:
    parse_example_op: A ParseExampleV2 `Operation`.
    sess: Session used to fetch the key and dense-default input tensors.

  Returns:
    An ExampleParserConfiguration proto describing the op's features.

  Raises:
    ValueError: If the op's attributes, inputs, or the fetched keys disagree
      with each other in length, or if the op has any ragged features (not
      representable in example_parser_configuration.proto).
  """
  dense_types = parse_example_op.get_attr("Tdense")
  num_sparse = parse_example_op.get_attr("num_sparse")
  sparse_types = parse_example_op.get_attr("sparse_types")
  ragged_value_types = parse_example_op.get_attr("ragged_value_types")
  ragged_split_types = parse_example_op.get_attr("ragged_split_types")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  num_dense = len(dense_types)
  num_ragged = len(ragged_value_types)

  # Validate with explicit exceptions rather than `assert`: assertions are
  # stripped under `python -O`, and the public entry point documents
  # ValueError for inconsistent attributes.
  if len(ragged_value_types) != len(ragged_split_types):
    raise ValueError("len(ragged_value_types) attribute does not match "
                     "len(ragged_split_types) attribute (%d vs %d)" %
                     (len(ragged_value_types), len(ragged_split_types)))

  # Inputs are: serialized, names, sparse keys, dense keys, ragged keys,
  # followed by one default value per dense feature.
  if len(parse_example_op.inputs) != 5 + num_dense:
    raise ValueError("len(inputs) does not match 5 + num_dense (%d vs %d)" %
                     (len(parse_example_op.inputs), 5 + num_dense))

  # Reject ragged features before running the session: the result could not
  # be returned anyway, so there is no point in fetching anything.
  if num_ragged != 0:
    raise ValueError("Ragged features are not yet supported by "
                     "example_parser_configuration.proto")

  # Skip over the serialized input, and the names input.
  fetched = sess.run(parse_example_op.inputs[2:])
  sparse_keys = fetched[0].tolist()
  dense_keys = fetched[1].tolist()
  ragged_keys = fetched[2].tolist()
  dense_defaults = fetched[3:]

  if len(sparse_keys) != num_sparse:
    raise ValueError("len(sparse_keys) does not match num_sparse attribute "
                     "(%d vs %d)" % (len(sparse_keys), num_sparse))

  if len(dense_keys) != num_dense:
    raise ValueError("len(dense_keys) does not match len(Tdense) attribute "
                     "(%d vs %d)" % (len(dense_keys), num_dense))

  if len(ragged_keys) != num_ragged:
    raise ValueError("len(ragged_keys) does not match "
                     "len(ragged_value_types) attribute (%d vs %d)" %
                     (len(ragged_keys), num_ragged))

  config = example_parser_configuration_pb2.ExampleParserConfiguration()
  _populate_feature_map(
      config,
      parse_example_op,
      sparse_keys=sparse_keys,
      sparse_types=sparse_types,
      dense_keys=dense_keys,
      dense_defaults=dense_defaults,
      dense_types=dense_types,
      dense_shapes=dense_shapes)
  return config