• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extract parse_example op configuration to a proto."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.example import example_parser_configuration_pb2
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24
25
def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfiguration proto describing a parse op.

  Args:
    parse_example_op: A ParseExample or ParseExampleV2 `Operation`.
    sess: A tf.compat.v1.Session needed to obtain some configuration values.

  Returns:
    An ExampleParserConfiguration proto.

  Raises:
    ValueError: If `parse_example_op` is neither a ParseExample nor a
      ParseExampleV2 op, or if its attributes are inconsistent.
  """
  op_type = parse_example_op.type
  if op_type == "ParseExample":
    return _extract_from_parse_example(parse_example_op, sess)
  if op_type == "ParseExampleV2":
    return _extract_from_parse_example_v2(parse_example_op, sess)
  # Name every accepted op type so the caller knows what was expected.
  raise ValueError(
      "Found unexpected type when parsing example. Expected `ParseExample` "
      f"or `ParseExampleV2` object. Received type: {op_type}")
46
47
def _extract_from_parse_example(parse_example_op, sess):
  """Builds an ExampleParserConfiguration from a ParseExample op.

  Feature keys and dense default values are inputs of the op, so they are
  obtained by evaluating those input tensors with `sess`. Dtypes and dense
  shapes are read from the op's attributes, and output tensor names from the
  op's outputs.

  Args:
    parse_example_op: A `ParseExample` `Operation`.
    sess: A session used to evaluate the op's key and default-value inputs.

  Returns:
    The populated ExampleParserConfiguration proto.

  Raises:
    ValueError: If the op's attributes, its inputs, or the fetched values
      disagree in length.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()
  op = parse_example_op

  n_sparse = op.get_attr("Nsparse")
  n_dense = op.get_attr("Ndense")
  sparse_dtypes = op.get_attr("sparse_types")
  dense_dtypes = op.get_attr("Tdense")
  dense_shape_attrs = op.get_attr("dense_shapes")

  # Every list-valued attribute needs exactly one entry per feature of its
  # kind; checked in the order sparse_types, Tdense, dense_shapes.
  length_checks = (
      ("len(sparse_types) attribute does not match Nsparse",
       len(sparse_dtypes), n_sparse),
      ("len(dense_types) attribute does not match Ndense",
       len(dense_dtypes), n_dense),
      ("len(dense_shapes) attribute does not match Ndense",
       len(dense_shape_attrs), n_dense),
  )
  for description, actual, expected in length_checks:
    if actual != expected:
      raise ValueError(
          "%s attribute (%d vs %d)" % (description, actual, expected))

  # Inputs 0 and 1 (the serialized input and the names input) are skipped;
  # the remainder is the sparse keys, the dense keys, then the dense
  # default values.
  to_fetch = op.inputs[2:]
  want = n_sparse + 2 * n_dense
  if len(to_fetch) != want:
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" % (len(to_fetch), want))

  results = sess.run(to_fetch)
  if len(results) != len(to_fetch):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(results), len(to_fetch)))

  sparse_keys = results[:n_sparse]
  dense_keys = results[n_sparse:n_sparse + n_dense]
  dense_defaults = results[n_sparse + n_dense:]

  # Outputs: n_sparse indices, n_sparse values and n_sparse shapes tensors,
  # followed by the dense values.
  outputs = op.outputs
  dense_out_offset = 3 * n_sparse

  # Dense features are populated first, then sparse features.
  for offset, (key, default, shape, dtype) in enumerate(
      zip(dense_keys, dense_defaults, dense_shape_attrs, dense_dtypes)):
    fixed = config.feature_map[key].fixed_len_feature
    # The default value is the numpy array fetched by the session run;
    # store it as a TensorProto.
    fixed.default_value.CopyFrom(tensor_util.make_tensor_proto(default))
    # The shape attribute is stored as a TensorShapeProto.
    fixed.shape.CopyFrom(tensor_shape.TensorShape(shape).as_proto())
    fixed.dtype = dtype.as_datatype_enum
    fixed.values_output_tensor_name = outputs[dense_out_offset + offset].name

  for offset, (key, dtype) in enumerate(zip(sparse_keys, sparse_dtypes)):
    var_len = config.feature_map[key].var_len_feature
    var_len.dtype = dtype.as_datatype_enum
    var_len.indices_output_tensor_name = outputs[offset].name
    var_len.values_output_tensor_name = outputs[n_sparse + offset].name
    var_len.shapes_output_tensor_name = outputs[2 * n_sparse + offset].name

  return config
135
136
def _extract_from_parse_example_v2(parse_example_op, sess):
  """Builds an ExampleParserConfiguration from a ParseExampleV2 op.

  Feature keys and dense default values are inputs of the op, so they are
  obtained by evaluating those input tensors with `sess`. Dtypes and dense
  shapes come from the op's attributes, and output tensor names from its
  outputs.

  Args:
    parse_example_op: A `ParseExampleV2` `Operation`.
    sess: A session used to evaluate the op's key and default-value inputs.

  Returns:
    The populated ExampleParserConfiguration proto.

  Raises:
    ValueError: If the op's attributes, its inputs, or the fetched key lists
      are inconsistent, or if the op parses ragged features (which the proto
      cannot describe). The attribute/input checks and the ragged check run
      before `sess` is used.
  """
  dense_types = parse_example_op.get_attr("Tdense")
  num_sparse = parse_example_op.get_attr("num_sparse")
  sparse_types = parse_example_op.get_attr("sparse_types")
  ragged_value_types = parse_example_op.get_attr("ragged_value_types")
  ragged_split_types = parse_example_op.get_attr("ragged_split_types")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  num_dense = len(dense_types)
  num_ragged = len(ragged_value_types)

  # Consistency checks raise ValueError (as the ParseExample path does)
  # rather than using `assert`, which is stripped under `python -O`.
  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "num_sparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))
  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "len(Tdense) (%d vs %d)" %
                     (len(dense_shapes), num_dense))
  if len(ragged_split_types) != num_ragged:
    raise ValueError("len(ragged_split_types) attribute does not match "
                     "len(ragged_value_types) (%d vs %d)" %
                     (len(ragged_split_types), num_ragged))
  # Inputs: serialized, names, sparse_keys, dense_keys, ragged_keys, followed
  # by one default-value tensor per dense feature.
  num_inputs = len(parse_example_op.inputs)
  if num_inputs != 5 + num_dense:
    raise ValueError("len(inputs) does not match 5 + len(Tdense) "
                     "(%d vs %d)" % (num_inputs, 5 + num_dense))

  # Ragged features cannot be expressed in the proto; fail before spending a
  # session run on values that would be discarded.
  if num_ragged:
    raise ValueError("Ragged features are not yet supported by "
                     "example_parser_configuration.proto")

  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  # Skip over the serialized input and the names input.
  fetched = sess.run(parse_example_op.inputs[2:])
  sparse_keys = fetched[0].tolist()
  dense_keys = fetched[1].tolist()
  ragged_keys = fetched[2].tolist()
  dense_defaults = fetched[3:]
  for label, keys, expected in (("sparse_keys", sparse_keys, num_sparse),
                                ("dense_keys", dense_keys, num_dense),
                                ("ragged_keys", ragged_keys, num_ragged)):
    if len(keys) != expected:
      raise ValueError("len(%s) does not match the op attributes "
                       "(%d vs %d)" % (label, len(keys), expected))

  # Output layout: num_sparse indices, num_sparse values and num_sparse
  # shapes tensors, followed by the dense values.
  outputs = parse_example_op.outputs
  sparse_values_start = num_sparse
  sparse_shapes_start = 2 * num_sparse
  dense_values_start = 3 * num_sparse

  # Dense features.
  for i, key in enumerate(dense_keys):
    fixed_config = config.feature_map[key].fixed_len_feature
    # The default value is the numpy array fetched by the session run;
    # store it as a TensorProto.
    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(dense_defaults[i]))
    # The shape attribute is stored as a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())
    fixed_config.dtype = dense_types[i].as_datatype_enum
    fixed_config.values_output_tensor_name = outputs[
        dense_values_start + i].name

  # Sparse features.
  for i, key in enumerate(sparse_keys):
    var_len_feature = config.feature_map[key].var_len_feature
    var_len_feature.dtype = sparse_types[i].as_datatype_enum
    var_len_feature.indices_output_tensor_name = outputs[i].name
    var_len_feature.values_output_tensor_name = outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = outputs[
        sparse_shapes_start + i].name

  return config
211