• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/example/example_parser_configuration.h"
16 
17 #include <vector>
18 
19 #include "tensorflow/core/example/feature.pb.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/numeric_op.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 
30 namespace tensorflow {
31 
FindNodeIndexByName(const tensorflow::GraphDef & graph,const string & node_name,int * node_idx)32 Status FindNodeIndexByName(const tensorflow::GraphDef& graph,
33                            const string& node_name, int* node_idx) {
34   for (int i = 0; i < graph.node_size(); ++i) {
35     const auto& node = graph.node(i);
36     if (node.name() == node_name) {
37       *node_idx = i;
38       return Status::OK();
39     }
40   }
41   return errors::InvalidArgument(node_name, " not found in GraphDef");
42 }
43 
ExtractExampleParserConfiguration(const tensorflow::GraphDef & graph,const string & node_name,tensorflow::Session * session,std::vector<FixedLenFeature> * fixed_len_features,std::vector<VarLenFeature> * var_len_features)44 Status ExtractExampleParserConfiguration(
45     const tensorflow::GraphDef& graph, const string& node_name,
46     tensorflow::Session* session,
47     std::vector<FixedLenFeature>* fixed_len_features,
48     std::vector<VarLenFeature>* var_len_features) {
49   int node_idx;
50   TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx));
51 
52   const auto& node = graph.node(node_idx);
53   if (node.op() != "ParseExample") {
54     return errors::InvalidArgument(node_name, " node is not a ParseExample op");
55   }
56 
57   auto& attr_map = node.attr();
58   auto num_sparse = attr_map.at("Nsparse").i();
59   auto num_dense = attr_map.at("Ndense").i();
60   fixed_len_features->resize(num_dense);
61   var_len_features->resize(num_sparse);
62 
63   auto tdense = attr_map.at("Tdense");
64   auto dense_shapes = attr_map.at("dense_shapes");
65   auto sparse_types = attr_map.at("sparse_types");
66 
67   // Consistency check attributes.
68   if (tdense.list().type_size() != num_dense) {
69     return errors::InvalidArgument("Node attr Tdense has ",
70                                    tdense.list().type_size(),
71                                    " elements != Ndense attr: ", num_dense);
72   }
73 
74   if (dense_shapes.list().shape_size() != num_dense) {
75     return errors::InvalidArgument("Node attr dense_shapes has ",
76                                    dense_shapes.list().shape_size(),
77                                    " elements != Ndense attr: ", num_dense);
78   }
79 
80   if (sparse_types.list().type_size() != num_sparse) {
81     return errors::InvalidArgument("Node attr sparse_types has ",
82                                    sparse_types.list().type_size(),
83                                    " elements != NSparse attr: ", num_sparse);
84   }
85 
86   for (int i = 0; i < tdense.list().type_size(); ++i) {
87     (*fixed_len_features)[i].dtype = tdense.list().type(i);
88     // Convert TensorShapeProto to TensorShape.
89     (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i));
90   }
91 
92   for (int i = 0; i < sparse_types.list().type_size(); ++i) {
93     (*var_len_features)[i].dtype = sparse_types.list().type(i);
94   }
95 
96   // We must fetch the configuration input tensors to the ParseExample op.
97   // Skipping index = 0, which is the serialized proto input.
98   std::vector<string> fetch_names(node.input_size() - 1);
99   for (int i = 1; i < node.input_size(); ++i) {
100     fetch_names[i - 1] = node.input(i);
101   }
102 
103   std::vector<Tensor> op_input_tensors;
104 
105   TF_RETURN_IF_ERROR(session->Run({},               // no_inputs,
106                                   fetch_names, {},  // no target_node_names,
107                                   &op_input_tensors));
108 
109   // The input tensors are laid out sequentially in a flat manner.
110   // Here are the various start offsets.
111   int sparse_keys_start = 1;
112   int dense_keys_start = sparse_keys_start + num_sparse;
113   int dense_defaults_start = dense_keys_start + num_dense;
114 
115   for (int i = 0; i < num_sparse; ++i) {
116     int input_idx = sparse_keys_start + i;
117     (*var_len_features)[i].key =
118         op_input_tensors[input_idx].scalar<tstring>()();
119   }
120 
121   for (int i = 0; i < num_dense; ++i) {
122     FixedLenFeature& config = (*fixed_len_features)[i];
123     int dense_keys_offset = dense_keys_start + i;
124     config.key = op_input_tensors[dense_keys_offset].scalar<tstring>()();
125 
126     int defaults_offset = dense_defaults_start + i;
127     config.default_value = op_input_tensors[defaults_offset];
128   }
129 
130   // The output tensors are laid out sequentially in a flat manner.
131   // Here are the various start offsets.
132   int sparse_indices_output_start = 0;
133   int sparse_values_output_start = sparse_indices_output_start + num_sparse;
134   int sparse_shapes_output_start = sparse_values_output_start + num_sparse;
135   int dense_values_output_start = sparse_shapes_output_start + num_sparse;
136 
137   string node_output_prefix = strings::StrCat(node_name, ":");
138 
139   for (int i = 0; i < num_sparse; ++i) {
140     VarLenFeature& config = (*var_len_features)[i];
141 
142     int indices_offset = sparse_indices_output_start + i;
143     config.indices_output_tensor_name =
144         strings::StrCat(node_output_prefix, indices_offset);
145 
146     int values_offset = sparse_values_output_start + i;
147     config.values_output_tensor_name =
148         strings::StrCat(node_output_prefix, values_offset);
149 
150     int shapes_offset = sparse_shapes_output_start + i;
151     config.shapes_output_tensor_name =
152         strings::StrCat(node_output_prefix, shapes_offset);
153   }
154 
155   for (int i = 0; i < num_dense; ++i) {
156     int output_idx = dense_values_output_start + i;
157     (*fixed_len_features)[i].values_output_tensor_name =
158         strings::StrCat(node_output_prefix, output_idx);
159   }
160   return Status::OK();
161 }
162 
ExampleParserConfigurationProtoToFeatureVectors(const ExampleParserConfiguration & config_proto,std::vector<FixedLenFeature> * fixed_len_features,std::vector<VarLenFeature> * var_len_features)163 Status ExampleParserConfigurationProtoToFeatureVectors(
164     const ExampleParserConfiguration& config_proto,
165     std::vector<FixedLenFeature>* fixed_len_features,
166     std::vector<VarLenFeature>* var_len_features) {
167   const auto& feature_map = config_proto.feature_map();
168   for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) {
169     string key = it->first;
170     const auto& config = it->second;
171     if (config.has_fixed_len_feature()) {
172       const auto& fixed_config = config.fixed_len_feature();
173       FixedLenFeature f;
174       f.key = key;
175       f.dtype = fixed_config.dtype();
176       f.shape = TensorShape(fixed_config.shape());
177       Tensor default_value(f.dtype, f.shape);
178       if (!default_value.FromProto(fixed_config.default_value())) {
179         return errors::InvalidArgument(
180             "Invalid default_value in config proto ",
181             fixed_config.default_value().DebugString());
182       }
183       f.default_value = default_value;
184       f.values_output_tensor_name = fixed_config.values_output_tensor_name();
185       fixed_len_features->push_back(f);
186     } else {
187       const auto& var_len_config = config.var_len_feature();
188       VarLenFeature v;
189       v.key = key;
190       v.dtype = var_len_config.dtype();
191       v.values_output_tensor_name = var_len_config.values_output_tensor_name();
192       v.indices_output_tensor_name =
193           var_len_config.indices_output_tensor_name();
194       v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
195       var_len_features->push_back(v);
196     }
197   }
198   return Status::OK();
199 }
200 
201 }  // namespace tensorflow
202