1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/example/example_parser_configuration.h"
16
17 #include <vector>
18
19 #include "tensorflow/core/example/feature.pb_text.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/numeric_op.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/protobuf.h"
29
30 namespace tensorflow {
31
FindNodeIndexByName(const tensorflow::GraphDef & graph,const string & node_name,int * node_idx)32 Status FindNodeIndexByName(const tensorflow::GraphDef& graph,
33 const string& node_name, int* node_idx) {
34 for (int i = 0; i < graph.node_size(); ++i) {
35 const auto& node = graph.node(i);
36 if (node.name() == node_name) {
37 *node_idx = i;
38 return Status::OK();
39 }
40 }
41 return errors::InvalidArgument(node_name, " not found in GraphDef");
42 }
43
ExtractExampleParserConfiguration(const tensorflow::GraphDef & graph,const string & node_name,tensorflow::Session * session,std::vector<FixedLenFeature> * fixed_len_features,std::vector<VarLenFeature> * var_len_features)44 Status ExtractExampleParserConfiguration(
45 const tensorflow::GraphDef& graph, const string& node_name,
46 tensorflow::Session* session,
47 std::vector<FixedLenFeature>* fixed_len_features,
48 std::vector<VarLenFeature>* var_len_features) {
49 int node_idx;
50 TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx));
51
52 const auto& node = graph.node(node_idx);
53 if (node.op() != "ParseExample") {
54 return errors::InvalidArgument(node_name, " node is not a ParseExample op");
55 }
56
57 auto& attr_map = node.attr();
58 auto num_sparse = attr_map.at("Nsparse").i();
59 auto num_dense = attr_map.at("Ndense").i();
60 fixed_len_features->resize(num_dense);
61 var_len_features->resize(num_sparse);
62
63 auto tdense = attr_map.at("Tdense");
64 auto dense_shapes = attr_map.at("dense_shapes");
65 auto sparse_types = attr_map.at("sparse_types");
66
67 // Consistency check attributes.
68 if (tdense.list().type_size() != num_dense) {
69 return errors::InvalidArgument("Node attr Tdense has ",
70 tdense.list().type_size(),
71 " elements != Ndense attr: ", num_dense);
72 }
73
74 if (dense_shapes.list().shape_size() != num_dense) {
75 return errors::InvalidArgument("Node attr dense_shapes has ",
76 dense_shapes.list().shape_size(),
77 " elements != Ndense attr: ", num_dense);
78 }
79
80 if (sparse_types.list().type_size() != num_sparse) {
81 return errors::InvalidArgument("Node attr sparse_types has ",
82 sparse_types.list().type_size(),
83 " elements != NSparse attr: ", num_sparse);
84 }
85
86 for (int i = 0; i < tdense.list().type_size(); ++i) {
87 (*fixed_len_features)[i].dtype = tdense.list().type(i);
88 // Convert TensorShapeProto to TensorShape.
89 (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i));
90 }
91
92 for (int i = 0; i < sparse_types.list().type_size(); ++i) {
93 (*var_len_features)[i].dtype = sparse_types.list().type(i);
94 }
95
96 // We must fetch the configuration input tensors to the ParseExample op.
97 // Skipping index = 0, which is the serialized proto input.
98 std::vector<string> fetch_names(node.input_size() - 1);
99 for (int i = 1; i < node.input_size(); ++i) {
100 fetch_names[i - 1] = node.input(i);
101 }
102
103 std::vector<Tensor> op_input_tensors;
104
105 TF_RETURN_IF_ERROR(session->Run({}, // no_inputs,
106 fetch_names, {}, // no target_node_names,
107 &op_input_tensors));
108
109 // The input tensors are laid out sequentially in a flat manner.
110 // Here are the various start offsets.
111 int sparse_keys_start = 1;
112 int dense_keys_start = sparse_keys_start + num_sparse;
113 int dense_defaults_start = dense_keys_start + num_dense;
114
115 for (int i = 0; i < num_sparse; ++i) {
116 int input_idx = sparse_keys_start + i;
117 (*var_len_features)[i].key = op_input_tensors[input_idx].scalar<string>()();
118 }
119
120 for (int i = 0; i < num_dense; ++i) {
121 FixedLenFeature& config = (*fixed_len_features)[i];
122 int dense_keys_offset = dense_keys_start + i;
123 config.key = op_input_tensors[dense_keys_offset].scalar<string>()();
124
125 int defaults_offset = dense_defaults_start + i;
126 config.default_value = op_input_tensors[defaults_offset];
127 }
128
129 // The output tensors are laid out sequentially in a flat manner.
130 // Here are the various start offsets.
131 int sparse_indices_output_start = 0;
132 int sparse_values_output_start = sparse_indices_output_start + num_sparse;
133 int sparse_shapes_output_start = sparse_values_output_start + num_sparse;
134 int dense_values_output_start = sparse_shapes_output_start + num_sparse;
135
136 string node_output_prefix = strings::StrCat(node_name, ":");
137
138 for (int i = 0; i < num_sparse; ++i) {
139 VarLenFeature& config = (*var_len_features)[i];
140
141 int indices_offset = sparse_indices_output_start + i;
142 config.indices_output_tensor_name =
143 strings::StrCat(node_output_prefix, indices_offset);
144
145 int values_offset = sparse_values_output_start + i;
146 config.values_output_tensor_name =
147 strings::StrCat(node_output_prefix, values_offset);
148
149 int shapes_offset = sparse_shapes_output_start + i;
150 config.shapes_output_tensor_name =
151 strings::StrCat(node_output_prefix, shapes_offset);
152 }
153
154 for (int i = 0; i < num_dense; ++i) {
155 int output_idx = dense_values_output_start + i;
156 (*fixed_len_features)[i].values_output_tensor_name =
157 strings::StrCat(node_output_prefix, output_idx);
158 }
159 return Status::OK();
160 }
161
ExampleParserConfigurationProtoToFeatureVectors(const ExampleParserConfiguration & config_proto,std::vector<FixedLenFeature> * fixed_len_features,std::vector<VarLenFeature> * var_len_features)162 Status ExampleParserConfigurationProtoToFeatureVectors(
163 const ExampleParserConfiguration& config_proto,
164 std::vector<FixedLenFeature>* fixed_len_features,
165 std::vector<VarLenFeature>* var_len_features) {
166 const auto& feature_map = config_proto.feature_map();
167 for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) {
168 string key = it->first;
169 const auto& config = it->second;
170 if (config.has_fixed_len_feature()) {
171 const auto& fixed_config = config.fixed_len_feature();
172 FixedLenFeature f;
173 f.key = key;
174 f.dtype = fixed_config.dtype();
175 f.shape = TensorShape(fixed_config.shape());
176 Tensor default_value(f.dtype, f.shape);
177 if (!default_value.FromProto(fixed_config.default_value())) {
178 return errors::InvalidArgument(
179 "Invalid default_value in config proto ",
180 fixed_config.default_value().DebugString());
181 }
182 f.default_value = default_value;
183 f.values_output_tensor_name = fixed_config.values_output_tensor_name();
184 fixed_len_features->push_back(f);
185 } else {
186 const auto& var_len_config = config.var_len_feature();
187 VarLenFeature v;
188 v.key = key;
189 v.dtype = var_len_config.dtype();
190 v.values_output_tensor_name = var_len_config.values_output_tensor_name();
191 v.indices_output_tensor_name =
192 var_len_config.indices_output_tensor_name();
193 v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
194 var_len_features->push_back(v);
195 }
196 }
197 return Status::OK();
198 }
199
200 } // namespace tensorflow
201