• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 // This is a surrogate for using a proto, since it doesn't seem to be possible
16 // to use protos in a dynamically-loaded/shared-linkage library, which is
17 // what is used for custom ops in tensorflow/contrib.
18 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
19 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
20 #include <unordered_map>
21 
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace tensorflow {
27 namespace tensorforest {
28 
29 using tensorflow::strings::safe_strto32;
30 
31 // DataColumn holds information about one feature of the original data.
32 // A feature could be dense or sparse, and be of any size.
33 class DataColumn {
34  public:
DataColumn()35   DataColumn() {}
36 
37   // Parses a serialized DataColumn produced from the SerializeToString()
38   // function of a python data_ops.DataColumn object.
39   // It should look like a proto ASCII format, i.e.
40   // name: <name> original_type: <type> size: <size>
ParseFromString(const string & serialized)41   void ParseFromString(const string& serialized) {
42     std::vector<string> tokens = tensorflow::str_util::Split(serialized, ' ');
43     CHECK_EQ(tokens.size(), 6);
44     name_ = tokens[1];
45     safe_strto32(tokens[3], &original_type_);
46     safe_strto32(tokens[5], &size_);
47   }
48 
name()49   const string& name() const { return name_; }
50 
original_type()51   int original_type() const { return original_type_; }
52 
size()53   int size() const { return size_; }
54 
set_name(const string & n)55   void set_name(const string& n) { name_ = n; }
56 
set_original_type(int o)57   void set_original_type(int o) { original_type_ = o; }
58 
set_size(int s)59   void set_size(int s) { size_ = s; }
60 
61  private:
62   string name_;
63   int original_type_;
64   int size_;
65 };
66 
67 // TensorForestDataSpec holds information about the original features of the
68 // data set, which were flattened to a single dense float tensor and/or a
69 // single sparse float tensor.
70 class TensorForestDataSpec {
71  public:
TensorForestDataSpec()72   TensorForestDataSpec() {}
73 
74   // Parses a serialized DataColumn produced from the SerializeToString()
75   // function of a python data_ops.TensorForestDataSpec object.
76   // It should look something like:
77   // dense_features_size: <size> dense: [{<col1>}{<col2>}] sparse: [{<col3>}]
ParseFromString(const string & serialized)78   void ParseFromString(const string& serialized) {
79     std::vector<string> tokens = tensorflow::str_util::Split(serialized, "[]");
80     std::vector<string> first_part =
81         tensorflow::str_util::Split(tokens[0], ' ');
82     safe_strto32(first_part[1], &dense_features_size_);
83     ParseColumns(tokens[1], &dense_);
84     ParseColumns(tokens[3], &sparse_);
85 
86     int total = 0;
87     for (const DataColumn& col : dense_) {
88       for (int i = 0; i < col.size(); ++i) {
89         feature_to_type_.push_back(col.original_type());
90         ++total;
91       }
92     }
93   }
94 
dense(int i)95   const DataColumn& dense(int i) const { return dense_.at(i); }
96 
sparse(int i)97   const DataColumn& sparse(int i) const { return sparse_.at(i); }
98 
mutable_sparse(int i)99   DataColumn* mutable_sparse(int i) { return &sparse_[i]; }
100 
dense_size()101   int dense_size() const { return dense_.size(); }
102 
sparse_size()103   int sparse_size() const { return sparse_.size(); }
104 
dense_features_size()105   int dense_features_size() const { return dense_features_size_; }
106 
set_dense_features_size(int s)107   void set_dense_features_size(int s) { dense_features_size_ = s; }
108 
add_dense()109   DataColumn* add_dense() {
110     dense_.push_back(DataColumn());
111     return &dense_[dense_.size() - 1];
112   }
113 
GetDenseFeatureType(int feature)114   int GetDenseFeatureType(int feature) const {
115     return feature_to_type_[feature];
116   }
117 
118  private:
ParseColumns(const string & cols,std::vector<DataColumn> * vec)119   void ParseColumns(const string& cols, std::vector<DataColumn>* vec) {
120     std::vector<string> tokens = tensorflow::str_util::Split(cols, "{}");
121     for (const string& tok : tokens) {
122       if (!tok.empty()) {
123         DataColumn col;
124         col.ParseFromString(tok);
125         vec->push_back(col);
126       }
127     }
128   }
129 
130   std::vector<DataColumn> dense_;
131   std::vector<DataColumn> sparse_;
132   int dense_features_size_;
133 
134   // This map tracks features in the total dense feature space to their
135   // original type for fast lookup.
136   std::vector<int> feature_to_type_;
137 };
138 
139 }  // namespace tensorforest
140 }  // namespace tensorflow
141 
142 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
143