1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 // This is a surrogate for using a proto, since it doesn't seem to be possible 16 // to use protos in a dynamically-loaded/shared-linkage library, which is 17 // what is used for custom ops in tensorflow/contrib. 18 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_ 19 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_ 20 #include <unordered_map> 21 22 #include "tensorflow/core/lib/strings/numbers.h" 23 #include "tensorflow/core/lib/strings/str_util.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace tensorflow { 27 namespace tensorforest { 28 29 using tensorflow::strings::safe_strto32; 30 31 // DataColumn holds information about one feature of the original data. 32 // A feature could be dense or sparse, and be of any size. 33 class DataColumn { 34 public: DataColumn()35 DataColumn() {} 36 37 // Parses a serialized DataColumn produced from the SerializeToString() 38 // function of a python data_ops.DataColumn object. 39 // It should look like a proto ASCII format, i.e. 40 // name: <name> original_type: <type> size: <size> ParseFromString(const string & serialized)41 void ParseFromString(const string& serialized) { 42 std::vector<string> tokens = tensorflow::str_util::Split(serialized, ' '); 43 CHECK_EQ(tokens.size(), 6); 44 name_ = tokens[1]; 45 safe_strto32(tokens[3], &original_type_); 46 safe_strto32(tokens[5], &size_); 47 } 48 name()49 const string& name() const { return name_; } 50 original_type()51 int original_type() const { return original_type_; } 52 size()53 int size() const { return size_; } 54 set_name(const string & n)55 void set_name(const string& n) { name_ = n; } 56 set_original_type(int o)57 void set_original_type(int o) { original_type_ = o; } 58 set_size(int s)59 void set_size(int s) { size_ = s; } 60 61 private: 62 string name_; 63 int original_type_; 64 int size_; 65 }; 66 67 // TensorForestDataSpec holds information about the original features of the 68 // data set, which were flattened to a single dense float tensor and/or a 69 // single sparse float tensor. 70 class TensorForestDataSpec { 71 public: TensorForestDataSpec()72 TensorForestDataSpec() {} 73 74 // Parses a serialized DataColumn produced from the SerializeToString() 75 // function of a python data_ops.TensorForestDataSpec object. 76 // It should look something like: 77 // dense_features_size: <size> dense: [{<col1>}{<col2>}] sparse: [{<col3>}] ParseFromString(const string & serialized)78 void ParseFromString(const string& serialized) { 79 std::vector<string> tokens = tensorflow::str_util::Split(serialized, "[]"); 80 std::vector<string> first_part = 81 tensorflow::str_util::Split(tokens[0], ' '); 82 safe_strto32(first_part[1], &dense_features_size_); 83 ParseColumns(tokens[1], &dense_); 84 ParseColumns(tokens[3], &sparse_); 85 86 int total = 0; 87 for (const DataColumn& col : dense_) { 88 for (int i = 0; i < col.size(); ++i) { 89 feature_to_type_.push_back(col.original_type()); 90 ++total; 91 } 92 } 93 } 94 dense(int i)95 const DataColumn& dense(int i) const { return dense_.at(i); } 96 sparse(int i)97 const DataColumn& sparse(int i) const { return sparse_.at(i); } 98 mutable_sparse(int i)99 DataColumn* mutable_sparse(int i) { return &sparse_[i]; } 100 dense_size()101 int dense_size() const { return dense_.size(); } 102 sparse_size()103 int sparse_size() const { return sparse_.size(); } 104 dense_features_size()105 int dense_features_size() const { return dense_features_size_; } 106 set_dense_features_size(int s)107 void set_dense_features_size(int s) { dense_features_size_ = s; } 108 add_dense()109 DataColumn* add_dense() { 110 dense_.push_back(DataColumn()); 111 return &dense_[dense_.size() - 1]; 112 } 113 GetDenseFeatureType(int feature)114 int GetDenseFeatureType(int feature) const { 115 return feature_to_type_[feature]; 116 } 117 118 private: ParseColumns(const string & cols,std::vector<DataColumn> * vec)119 void ParseColumns(const string& cols, std::vector<DataColumn>* vec) { 120 std::vector<string> tokens = tensorflow::str_util::Split(cols, "{}"); 121 for (const string& tok : tokens) { 122 if (!tok.empty()) { 123 DataColumn col; 124 col.ParseFromString(tok); 125 vec->push_back(col); 126 } 127 } 128 } 129 130 std::vector<DataColumn> dense_; 131 std::vector<DataColumn> sparse_; 132 int dense_features_size_; 133 134 // This map tracks features in the total dense feature space to their 135 // original type for fast lookup. 136 std::vector<int> feature_to_type_; 137 }; 138 139 } // namespace tensorforest 140 } // namespace tensorflow 141 142 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_ 143