• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.example import example_parser_configuration_pb2
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24
25
def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfiguration proto for a parse-example op.

  Dispatches on the op type to the matching extraction routine.

  Args:
    parse_example_op: A ParseExample or ParseExampleV2 `Operation`.
    sess: A tf.compat.v1.Session needed to obtain some configuration values.

  Returns:
    An ExampleParserConfiguration proto.

  Raises:
    ValueError: If the op type is neither "ParseExample" nor
      "ParseExampleV2", or if attributes are inconsistent.
  """
  op_type = parse_example_op.type
  if op_type == "ParseExample":
    return _extract_from_parse_example(parse_example_op, sess)
  if op_type == "ParseExampleV2":
    return _extract_from_parse_example_v2(parse_example_op, sess)
  raise ValueError("Unexpected op type: %s" % op_type)
44
45
def _extract_from_parse_example(parse_example_op, sess):
  """Extract ExampleParserConfiguration from a ParseExample op.

  Args:
    parse_example_op: A ParseExample `Operation`. After the serialized and
      names inputs, its remaining inputs are read as Nsparse sparse keys,
      then Ndense dense keys, then Ndense dense default values.
    sess: Session whose `run` is used to evaluate the key and default-value
      inputs.

  Returns:
    An ExampleParserConfiguration proto whose feature_map has one
    fixed_len_feature entry per dense key and one var_len_feature entry per
    sparse key.

  Raises:
    ValueError: If the attribute lengths, the number of op inputs, or the
      number of fetched values are inconsistent.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")

  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_types) != num_dense:
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_types), num_dense))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Skip over the serialized input, and the names input.
  fetch_list = parse_example_op.inputs[2:]

  # One key per feature (sparse and dense) plus one default per dense feature.
  expected_fetches = num_sparse + 2 * num_dense
  if len(fetch_list) != expected_fetches:
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(fetch_list), expected_fetches))

  fetched = sess.run(fetch_list)

  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(fetched), len(fetch_list)))

  # Split the fetched values by role: sparse keys, dense keys, dense defaults.
  dense_keys_end = num_sparse + num_dense
  sparse_keys = fetched[:num_sparse]
  dense_keys = fetched[num_sparse:dense_keys_end]
  dense_defaults = fetched[dense_keys_end:]

  # Output tensor indices: the op's outputs are indexed as num_sparse sparse
  # indices, num_sparse sparse values, num_sparse sparse shapes, and then the
  # dense values.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  # Dense features.
  for i, key in enumerate(dense_keys):
    fixed_config = config.feature_map[key].fixed_len_feature

    # Convert the default value numpy array fetched from the session run
    # into a TensorProto.
    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(dense_defaults[i]))
    # Convert the shape from the attributes into a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())

    fixed_config.dtype = dense_types[i].as_datatype_enum
    # Get the output tensor name.
    fixed_config.values_output_tensor_name = parse_example_op.outputs[
        dense_values_start + i].name

  # Sparse features.
  for i, key in enumerate(sparse_keys):
    var_len_feature = config.feature_map[key].var_len_feature
    var_len_feature.dtype = sparse_types[i].as_datatype_enum
    var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
        sparse_indices_start + i].name
    var_len_feature.values_output_tensor_name = parse_example_op.outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
        sparse_shapes_start + i].name

  return config
133
134
def _extract_from_parse_example_v2(parse_example_op, sess):
  """Extract ExampleParserConfiguration from a ParseExampleV2 op.

  Args:
    parse_example_op: A ParseExampleV2 `Operation`. After the serialized and
      names inputs, its remaining inputs are read as the sparse keys, the
      dense keys, the ragged keys, and then one default value per dense
      feature.
    sess: Session whose `run` is used to evaluate the key and default-value
      inputs.

  Returns:
    An ExampleParserConfiguration proto whose feature_map has one
    fixed_len_feature entry per dense key and one var_len_feature entry per
    sparse key.

  Raises:
    ValueError: If the attribute lengths, the number of op inputs, or the
      number of fetched keys are inconsistent, or if the op declares any
      ragged features (which the proto cannot represent yet).
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  dense_types = parse_example_op.get_attr("Tdense")
  num_sparse = parse_example_op.get_attr("num_sparse")
  sparse_types = parse_example_op.get_attr("sparse_types")
  ragged_value_types = parse_example_op.get_attr("ragged_value_types")
  ragged_split_types = parse_example_op.get_attr("ragged_split_types")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  num_dense = len(dense_types)
  num_ragged = len(ragged_value_types)

  # Validate with explicit exceptions rather than `assert`: asserts are
  # stripped under `python -O`, and ValueError matches the ParseExample path.
  if len(ragged_split_types) != num_ragged:
    raise ValueError("len(ragged_split_types) attribute does not match "
                     "len(ragged_value_types) attribute (%d vs %d)" %
                     (len(ragged_split_types), num_ragged))

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "num_sparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "len(Tdense) attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # serialized, names, sparse_keys, dense_keys, ragged_keys, plus one default
  # value input per dense feature.
  expected_inputs = 5 + num_dense
  if len(parse_example_op.inputs) != expected_inputs:
    raise ValueError("len(inputs) does not match 5 + num_dense "
                     "(%d vs %d)" %
                     (len(parse_example_op.inputs), expected_inputs))

  # Skip over the serialized input, and the names input.
  fetched = sess.run(parse_example_op.inputs[2:])
  sparse_keys = fetched[0].tolist()
  dense_keys = fetched[1].tolist()
  ragged_keys = fetched[2].tolist()
  dense_defaults = fetched[3:]

  if len(sparse_keys) != num_sparse:
    raise ValueError("len(sparse_keys) does not match num_sparse attribute "
                     "(%d vs %d)" % (len(sparse_keys), num_sparse))

  if len(dense_keys) != num_dense:
    raise ValueError("len(dense_keys) does not match len(Tdense) attribute "
                     "(%d vs %d)" % (len(dense_keys), num_dense))

  if len(ragged_keys) != num_ragged:
    raise ValueError("len(ragged_keys) does not match "
                     "len(ragged_value_types) attribute (%d vs %d)" %
                     (len(ragged_keys), num_ragged))

  # Output tensor indices: the op's outputs are indexed as num_sparse sparse
  # indices, num_sparse sparse values, num_sparse sparse shapes, and then the
  # dense values.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  # Dense features.
  for i, key in enumerate(dense_keys):
    fixed_config = config.feature_map[key].fixed_len_feature

    # Convert the default value numpy array fetched from the session run
    # into a TensorProto.
    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(dense_defaults[i]))
    # Convert the shape from the attributes into a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())

    fixed_config.dtype = dense_types[i].as_datatype_enum
    # Get the output tensor name.
    fixed_config.values_output_tensor_name = parse_example_op.outputs[
        dense_values_start + i].name

  # Sparse features.
  for i, key in enumerate(sparse_keys):
    var_len_feature = config.feature_map[key].var_len_feature
    var_len_feature.dtype = sparse_types[i].as_datatype_enum
    var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
        sparse_indices_start + i].name
    var_len_feature.values_output_tensor_name = parse_example_op.outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
        sparse_shapes_start + i].name

  # Checked last, after the dense and sparse entries have been built, so any
  # error raised while building them still takes precedence.
  if num_ragged != 0:
    raise ValueError("Ragged features are not yet supported by "
                     "example_parser_configuration.proto")

  return config
209